xref: /aosp_15_r20/external/pytorch/test/test_sparse.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: sparse"]
2
3import torch
4import itertools
5import functools
6import operator
7import random
8import unittest
9from torch.testing import make_tensor
10from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
11    load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
12    DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM, skipIfTorchDynamo, \
13    parametrize, subtest, is_coalesced_indices, suppress_warnings, instantiate_parametrized_tests, \
14    skipIfCrossRef
15from torch.testing._internal.common_cuda import TEST_CUDA
16from numbers import Number
17from typing import Dict, Any
18from packaging import version
19from torch.testing._internal.common_cuda import \
20    (SM53OrLater, SM80OrLater, TEST_MULTIGPU)
21from torch.testing._internal.common_device_type import \
22    (instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
23     deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes)
24from torch.testing._internal.common_methods_invocations import \
25    (op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
26from torch.testing._internal.common_dtype import (
27    all_types, all_types_and_complex, all_types_and_complex_and, floating_and_complex_types,
28    floating_and_complex_types_and, integral_types, floating_types_and,
29)
30from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
31from torch.testing._internal.opinfo.refs import (
32    ElementwiseBinaryPythonRefInfo,
33    ReductionPythonRefInfo
34)
35
36def _op_supports_any_sparse(op):
37    return (op.supports_sparse
38            or op.supports_sparse_csr
39            or op.supports_sparse_csc
40            or op.supports_sparse_bsr
41            or op.supports_sparse_bsc)
42
43
44
45reduction_ops_with_sparse_support = [
46    op for op in reduction_ops if 'masked.' not in op.name and
47    _op_supports_any_sparse(op) and not isinstance(op, ReductionPythonRefInfo)]
48
49binary_ufuncs_with_sparse_support = [
50    op for op in binary_ufuncs if _op_supports_any_sparse(op) and
51    not isinstance(op, ElementwiseBinaryPythonRefInfo)]
52
53like_fns_with_sparse_support = [op for op in op_db if _op_supports_any_sparse(op) and '_like' in op.name]
54
55if TEST_SCIPY:
56    import scipy.sparse
57
58# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
59# sharding on sandcastle. This line silences flake warnings
60load_tests = load_tests
61
62# batched grad doesn't support sparse
63gradcheck = functools.partial(gradcheck, check_batched_grad=False)
64
65CUSPARSE_SPMM_COMPLEX128_SUPPORTED = (
66    IS_WINDOWS and torch.version.cuda and version.parse(torch.version.cuda) > version.parse("11.2")
67) or (not IS_WINDOWS and not TEST_WITH_ROCM)
68
69HIPSPARSE_SPMM_COMPLEX128_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("6.0")
70
71def all_sparse_layouts(test_name='layout', include_strided=False):
72    return parametrize(test_name, [
73        subtest(torch.strided, name='Strided'),
74        subtest(torch.sparse_coo, name='SparseCOO'),
75        subtest(torch.sparse_csr, name='SparseCSR'),
76        subtest(torch.sparse_csc, name='SparseCSC'),
77        subtest(torch.sparse_bsr, name='SparseBSR'),
78        subtest(torch.sparse_bsc, name='SparseBSC'),
79    ][(0 if include_strided else 1):])
80
81def gradcheck_semantics(test_name='gradcheck'):
82    gradcheck_sparse = functools.partial(gradcheck, masked=False)
83    gradcheck_masked = functools.partial(gradcheck, masked=True)
84    gradcheck_sparse.masked = False
85    gradcheck_masked.masked = True
86    return parametrize(test_name, [
87        subtest(gradcheck_sparse, name='sparse'),
88        subtest(gradcheck_masked, name='masked')])
89
90
91class CrossRefSparseFakeMode(torch._subclasses.CrossRefFakeMode):
92    def __init__(self) -> None:
93        super().__init__(
94            self.ignore_op, check_strides=False,
95            check_aliasing=False,
96        )  # TODO: enable stride/alias checking
97
98    # empty_like excluded for now due to sparse complex
99    # aten._to_dense.default this one is getting called with csc
100    @staticmethod
101    def ignore_op(func):
102        return func in (
103            torch.ops.aten.empty_like.default,
104            torch.ops.aten.set_.source_Storage_storage_offset,
105            torch.ops.aten.sspaddmm.out,
106            torch.ops.aten._spdiags.default,
107            torch.ops.aten._to_dense.default,
108            torch.ops.aten.indices.default,
109            torch.ops.aten._indices.default,
110            torch.ops.aten.values.default,
111            torch.ops.aten._values.default,
112        )
113
114class TestSparseLegacyAndDeprecation(TestCase):
115
116    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
117    def test_legacy_warnings(self):
118
119        def f1():
120            "torch.sparse.SparseTensor() is deprecated."\
121                "  Please use torch.sparse_coo_tensor((0,), dtype=)"
122            x_ref = torch.sparse_coo_tensor((0,), dtype=torch.float64)
123            x = torch.sparse.DoubleTensor()
124            self.assertEqual(x, x_ref)
125
126        def f2():
127            "torch.sparse.SparseTensor(cdata=x._cdata) is deprecated."\
128                "  Please use torch.sparse_coo_tensor(x._indices(), x._values(), x.shape)"
129            x_ref = torch.tensor([[1, 2], [3, 4]], dtype=torch.float64).to_sparse()
130            x = torch.sparse.DoubleTensor(cdata=x_ref._cdata)
131            y = torch.sparse_coo_tensor(x._indices(), x._values(), x.shape)
132            self.assertEqual(x, x_ref)
133            self.assertEqual(y, x_ref)
134
135        def f3():
136            "torch.sparse.SparseTensor(indices, values, *, device=) is deprecated."\
137                "  Please use torch.sparse_coo_tensor(indices, values, dtype=, device=)"
138            x_ref = torch.sparse_coo_tensor([[0, 0, 1, 1], [0, 1, 0, 1]], [1, 2, 3, 4], dtype=torch.float64)
139            x = torch.sparse.DoubleTensor(torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]]),
140                                          torch.tensor([1, 2, 3, 4], dtype=torch.float64))
141            self.assertEqual(x, x_ref)
142
143        def f4():
144            "torch.sparse.SparseTensor(indices, values, shape, *, device=) is deprecated."\
145                "  Please use torch.sparse_coo_tensor(indices, values, shape, dtype=, device=)"
146            x_ref = torch.sparse_coo_tensor([[0, 0, 1, 1], [0, 1, 0, 1]], [1, 2, 3, 4], (2, 3), dtype=torch.float64)
147            x = torch.sparse.DoubleTensor(torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]]),
148                                          torch.tensor([1, 2, 3, 4], dtype=torch.float64), (2, 3))
149            self.assertEqual(x, x_ref)
150
151        def f5():
152            "torch.sparse.SparseTensor(shape, *, device=) is deprecated."\
153                "  Please use torch.sparse_coo_tensor(shape, dtype=, device=)"
154            x_ref = torch.sparse_coo_tensor((2, 3), dtype=torch.float64)
155            x = torch.sparse.DoubleTensor(2, 3)
156            self.assertEqual(x, x_ref)
157
158        for test_f in [f1, f2, f3, f4, f5]:
159
160            with self.assertWarns(UserWarning, msg=test_f.__doc__) as cm:
161                test_f()
162                test_f()
163
164            # Check warn-once:
165            self.assertEqual(len(cm.warnings), 1)
166
167
168class TestSparseBase(TestCase):
169    def run(self, result=None):
170        if TEST_WITH_CROSSREF:
171            with CrossRefSparseFakeMode():
172                return super().run(result)
173        else:
174            return super().run(result)
175
176class TestSparse(TestSparseBase):
177
178    def setUp(self):
179        TestCase.setUp(self)
180
181        self.index_tensor = lambda *args, **kwargs: torch.tensor(*args, **kwargs, dtype=torch.int64)
182
183        def sparse_empty_factory(*args, **kwargs):
184            kwargs['layout'] = kwargs.get('layout', torch.sparse_coo)
185            return torch.empty(*args, **kwargs)
186        self.sparse_empty = sparse_empty_factory
187
188        def sparse_tensor_factory(*args, **kwargs):
189            return torch.sparse_coo_tensor(*args, **kwargs)
190        self.sparse_tensor = sparse_tensor_factory
191
192    def _gen_sparse(self, sparse_dim, nnz, with_size, dtype, device, coalesced):
193        if isinstance(with_size, Number):
194            with_size = [with_size] * sparse_dim
195
196        x, i, v = self.genSparseTensor(with_size, sparse_dim, nnz, not coalesced, dtype=dtype, device=device)
197
198        if not coalesced:
199            self.assert_uncoalesced(x)
200
201        return x, i, v
202
203    def assert_uncoalesced(self, x):
204        """
205        Test if a CPU tensor is uncoalesced.  This is used to ensure
206        correctness of the uncoalesced tensor generation algorithm.
207        """
208        assert not x.is_coalesced()
209        existing_indices = set()
210        indices = x._indices()
211        for i in range(x._nnz()):
212            index = str(indices[:, i])
213            if index in existing_indices:
214                return True
215            else:
216                existing_indices.add(index)
217
218    def randn(self, *args, **kwargs):
219        """
220        Variant of torch.randn that also works in the TEST_CUDA case.
221        """
222        # TODO: Put this in torch.cuda.randn
223        return torch.empty(*args, **kwargs).normal_()
224
225    @dtypes(torch.double)
226    def test_print_coalesced(self, device, dtype):
227        self._test_print(device, dtype, True)
228
229    @dtypes(torch.double)
230    def test_print_uncoalesced(self, device, dtype):
231        self._test_print(device, dtype, False)
232
233    def _test_print(self, device, dtype, coalesced):
234        shape_sparse_dim_nnz = [
235            ((), 0, 2),
236            ((0,), 0, 10),
237            ((2,), 0, 3),
238            ((100, 3), 1, 3),
239            ((100, 20, 3), 2, 0),
240            ((10, 0, 3), 0, 3),
241            ((10, 0, 3), 0, 0),
242        ]
243        printed = []
244        for shape, sparse_dim, nnz in shape_sparse_dim_nnz:
245            indices_shape = torch.Size((sparse_dim, nnz))
246            values_shape = torch.Size((nnz,) + shape[sparse_dim:])
247            printed.append(f"# shape: {torch.Size(shape)}")
248            printed.append(f"# nnz: {nnz}")
249            printed.append(f"# sparse_dim: {sparse_dim}")
250            printed.append(f"# indices shape: {indices_shape}")
251            printed.append(f"# values shape: {values_shape}")
252
253            indices = torch.arange(indices_shape.numel(), dtype=self.index_tensor(0).dtype,
254                                   device=device).view(indices_shape)
255            for d in range(sparse_dim):
256                indices[d].clamp_(max=(shape[d] - 1))  # make it valid index
257            if not coalesced and indices.numel() > 0:
258                indices[:, -1] = indices[:, 0]  # make it uncoalesced
259            values_numel = values_shape.numel()
260            values = torch.arange(values_numel, dtype=dtype,
261                                  device=device).view(values_shape).div_(values_numel / 2.)
262            sp_tensor = self.sparse_tensor(indices, values, shape, dtype=dtype, device=device)
263
264            dtypes = [torch.int32]
265            if values.dtype == torch.double:
266                dtypes.append(torch.float)
267            else:
268                dtypes.append(torch.double)
269            for dtype in dtypes:
270                printed.append(f"########## {dtype} ##########")
271                x = sp_tensor.detach().to(dtype)
272                printed.append("# sparse tensor")
273                printed.append(str(x))
274                if x.dtype.is_floating_point:
275                    printed.append("# after requires_grad_")
276                    printed.append(str(x.requires_grad_()))
277                    printed.append("# after addition")
278                    printed.append(str(x + x))
279                printed.append("# _indices")
280                printed.append(str(x._indices()))
281                printed.append("# _values")
282                printed.append(str(x._values()))
283            printed.append('')
284        self.assertExpected('\n'.join(printed))
285
286    @coalescedonoff
287    @dtypes(torch.double, torch.cdouble)
288    def test_basic(self, device, dtype, coalesced):
289        def test_shape(sparse_dims, nnz, with_size):
290            if isinstance(with_size, Number):
291                with_size = [with_size] * sparse_dims
292            x, i, v = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)
293            self.assertEqual(i, x._indices())
294            self.assertEqual(v, x._values())
295            self.assertEqual(x.ndimension(), len(with_size))
296            self.assertEqual(x.coalesce()._nnz(), nnz if x.is_coalesced() else nnz // 2)
297            self.assertEqual(list(x.size()), with_size)
298
299            # Test .indices() and .values()
300            if not coalesced:
301                with self.assertRaisesRegex(RuntimeError, "Cannot get indices on an uncoalesced tensor"):
302                    x.indices()
303                with self.assertRaisesRegex(RuntimeError, "Cannot get values on an uncoalesced tensor"):
304                    x.values()
305            else:
306                self.assertEqual(x.indices(), x._indices())
307                self.assertEqual(x.values(), x._values())
308
309        test_shape(3, 10, 100)
310        test_shape(3, 10, [100, 100, 100])
311        test_shape(3, 10, [100, 100, 100, 5, 5, 5, 0])
312        test_shape(3, 0, [0, 0, 100, 5, 5, 5, 0])
313
314        # Make sure that coalesce handles duplicate indices correctly
315        i = self.index_tensor([[9, 0, 0, 0, 8, 1, 1, 1, 2, 7, 2, 2, 3, 4, 6, 9]], device=device)
316        v = torch.tensor([[idx**2, idx] for idx in range(i.size(1))], dtype=dtype, device=device)
317        x = self.sparse_tensor(i, v, torch.Size([10, 2]), dtype=dtype, device=device)
318        self.assertEqual(x.coalesce()._nnz(), 9)
319
320    @coalescedonoff
321    @dtypes(torch.double, torch.cdouble, torch.bfloat16)
322    @precisionOverride({torch.bfloat16: 1e-2})
323    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
324    def test_coalesce(self, device, dtype, coalesced):
325
326        def _test_coalesce(t):
327            tc = t.coalesce()
328            self.assertEqual(tc.to_dense(), t.to_dense())
329            self.assertTrue(tc.is_coalesced())
330            # Our code below doesn't work when nnz is 0, because
331            # then it's a 0D tensor, not a 2D tensor.
332            if t._nnz() == 0:
333                self.assertEqual(t._indices(), tc._indices())
334                self.assertEqual(t._values(), tc._values())
335                return tc
336
337            value_map: Dict[Any, Any] = {}
338            for idx, val in zip(t._indices().t(), t._values()):
339                idx_tup = tuple(idx.tolist())
340                if idx_tup in value_map:
341                    value_map[idx_tup] += val
342                else:
343                    value_map[idx_tup] = val.clone() if isinstance(val, torch.Tensor) else val
344
345            new_indices = sorted(value_map.keys())
346            _new_values = [value_map[idx] for idx in new_indices]
347            if t._values().ndimension() < 2:
348                new_values = t._values().new(_new_values)
349            else:
350                new_values = torch.stack(_new_values)
351
352            new_indices = t._indices().new(new_indices).t()
353            tg = t.new(new_indices, new_values, t.size())
354
355            self.assertEqual(tc._indices(), tg._indices())
356            self.assertEqual(tc._values(), tg._values())
357
358            if t.is_coalesced():
359                self.assertEqual(tc._indices(), t._indices())
360                self.assertEqual(tc._values(), t._values())
361
362        for empty_i, empty_v, empty_nnz in itertools.product([True, False], repeat=3):
363            sparse_size = [] if empty_i else [2, 1]
364            dense_size = [1, 0, 2] if empty_v else [1, 2]
365            nnz = 0 if empty_nnz else 5
366
367            t, _, _ = self._gen_sparse(len(sparse_size), nnz, sparse_size + dense_size, dtype, device, coalesced)
368            _test_coalesce(t)  # this tests correctness
369
370    @dtypes(torch.double)
371    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/89395")
372    def test_coalesce_reference_cycle(self, device, dtype):
373        # Test coalesce doesn't create autograd graph cycles (gh-52253)
374
375        # Sanity check that the helper class works as expected
376        t = torch.rand(2)
377        t_ref = torch._C._WeakTensorRef(t)
378        self.assertFalse(t_ref.expired())
379
380        del t
381        self.assertTrue(t_ref.expired())
382
383        def test_sparse_sum():
384            i = torch.tensor([[0], [4]], dtype=torch.long, device=device)
385            v = torch.tensor([[[-0.4567, -1.8797, 0.0380, 1.4316]]],
386                             dtype=dtype, device=device)
387            S = torch.sparse_coo_tensor(i, v)
388            S = S.coalesce()
389            S.requires_grad_(True)
390            S2 = S.coalesce()
391            self.assertTrue(S2.is_coalesced())
392            return torch._C._WeakTensorRef(S2)
393
394        ref = test_sparse_sum()
395        self.assertTrue(ref.expired())
396
397    @dtypes(torch.double)
398    def test_ctor_large_sizes(self, device, dtype):
399        # Test that integer overflow is detected when computing numel
400        # of a sparse tensor with large dimensions (gh-57416). Notice
401        # that numel is computed internally when constructing a
402        # tensor, hence the overflow may appear during the tensor
403        # construction step.
404        N = 100000
405        indices = torch.tensor([[N, N - 1]] * 4, dtype=torch.int64, device=device)
406        values = torch.tensor([1, 2], dtype=dtype, device=device)
407        self.assertRaises(RuntimeError,
408                          lambda: torch.sparse_coo_tensor(
409                              indices, values, (N + 1,) * 4, device=device))
410
411    @dtypes(torch.double, torch.cdouble)
412    def test_ctor_size_checks(self, device, dtype):
413        indices = self.index_tensor([
414            [0, 0, 0],
415            [0, 3, 0],
416            [0, 0, 0],
417            [0, 0, 0],
418        ], device=device)
419        values = torch.tensor([2, 1, 3, 4], dtype=dtype, device=device)
420
421        # indices inconsistent with size
422        self.assertRaises(
423            RuntimeError,
424            lambda: self.sparse_tensor(indices, values, torch.Size([2, 1, 1])))
425
426        # values inconsistent with size
427        values = torch.tensor([
428            [2, 1, 2, 1],
429            [1, 0, 5, 2],
430        ], dtype=dtype, device=device)
431        self.assertRaises(
432            RuntimeError,
433            lambda: self.sparse_tensor(indices, values, torch.Size([2, 4, 2, 1])))
434
435    @coalescedonoff
436    @dtypes(torch.double)
437    def test_ctor_is_coalesced_with_gradcheck(self, device, dtype, coalesced):
438        for sparse_size, nnz in (((3, 3), 5), ((2, 3, 1, 5), 11)):
439            t, _, _ = self._gen_sparse(len(sparse_size), nnz, sparse_size, dtype, device, coalesced)
440            self.assertEqual(t.is_coalesced(), coalesced)
441
442            def func(indices, values, shape, is_coalesced):
443                s = torch.sparse_coo_tensor(indices, values, shape, check_invariants=True, is_coalesced=is_coalesced)
444                self.assertEqual(s.is_coalesced(), is_coalesced)
445                return s.to_dense(masked_grad=False)
446
447            if coalesced:
448                torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, False))
449                torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, True))
450            else:
451                torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, False))
452                with self.assertRaisesRegex(RuntimeError,
453                                            "cannot set is_coalesced to true if indices correspond to uncoalesced COO tensor"):
454                    torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, True))
455
456    @dtypes(*floating_and_complex_types_and(torch.float16, torch.bfloat16))
457    @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
458    @gradcheck_semantics()
459    def test_to_dense_with_gradcheck(self, device, dtype, gradcheck):
460
461        def test_tensor(x, res):
462            x.to_dense()  # Tests triple to_dense for memory corruption
463            x.to_dense()
464            x.to_dense()
465            dense_x = x.to_dense()
466            safe_dense_x = self.safeToDense(x)
467            dense_x = dense_x.to(res.dtype)
468            safe_dense_x = safe_dense_x.to(res.dtype)
469            self.assertEqual(res, dense_x)
470            self.assertEqual(res, safe_dense_x)
471
472            # Only run autograd test for float64
473            if x.dtype != torch.float64:
474                return
475
476            def fn(x):
477                return x.to_dense(masked_grad=gradcheck.masked)
478            x.requires_grad_(True)
479            gradcheck(fn, (x,))
480
481        for value_type in [torch.double, torch.cdouble]:
482            i = self.index_tensor([
483                [0, 1, 2, 2],
484                [0, 0, 0, 3],
485                [0, 0, 1, 4],
486            ], device=device)
487            # we don't have to_dense for half types on CPU because it is implemented
488            # with a slower add_ operation
489            v = torch.tensor([2, 1, 3, 4], dtype=dtype, device=device)
490            x = self.sparse_tensor(i, v, torch.Size([3, 4, 5]), dtype=value_type, device=device)
491            res = torch.tensor([
492                [[2, 0, 0, 0, 0],
493                 [0, 0, 0, 0, 0],
494                 [0, 0, 0, 0, 0],
495                 [0, 0, 0, 0, 0]],
496                [[1, 0, 0, 0, 0],
497                 [0, 0, 0, 0, 0],
498                 [0, 0, 0, 0, 0],
499                 [0, 0, 0, 0, 0]],
500                [[0, 3, 0, 0, 0],
501                 [0, 0, 0, 0, 0],
502                 [0, 0, 0, 0, 0],
503                 [0, 0, 0, 0, 4]],
504            ], dtype=dtype, device=device)
505
506            test_tensor(x, res)
507            test_tensor(res, res)
508
509            i = self.index_tensor([
510                [0, 1, 2, 2],
511                [0, 0, 0, 3],
512                [0, 0, 1, 4],
513            ], device=device)
514            v = torch.empty(4, 0, dtype=dtype, device=device)
515            x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 0]), dtype=value_type, device=device)
516            res = torch.empty((3, 4, 5, 0), dtype=dtype, device=device)
517            test_tensor(x, res)
518
519    @coalescedonoff
520    @dtypes(torch.float16, torch.bfloat16, torch.float64, torch.int, torch.cfloat, torch.cdouble)
521    def test_to_sparse(self, device, dtype, coalesced):
522        shape = [5, 2, 10, 4]
523        max_nnz = 1
524        for value_type in [torch.double, torch.cdouble]:
525            for dim, dim_sz in enumerate(shape, 1):
526                max_nnz *= dim_sz
527                rnnz = torch.randint(2, max_nnz, (1,)).item()
528                for nnz in [0, 1, rnnz]:
529                    expected, _, _ = self._gen_sparse(dim, nnz, shape, dtype=value_type, device=device,
530                                                      coalesced=coalesced)
531                    expected = expected.to(dtype)
532
533                    d = expected.to_dense()
534                    result = d.to_sparse(dim)
535                    self.assertEqual(d, result.to_dense())
536                    self.assertEqual(expected.size(), result.size())
537                    self.assertEqual(dim, result.sparse_dim())
538
539    @dtypes(torch.double, torch.cdouble)
540    def test_sparse_bool(self, device, dtype):
541        a = torch.tensor([True, False], dtype=dtype, device=device).to(torch.bool)
542        b = a.to_sparse().to_dense()
543        self.assertEqual(a, b)
544
545    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/108667")
546    @dtypes(torch.double, torch.cdouble)
547    def test_scalar(self, device, dtype):
548        # tensor with value
549        a = self.sparse_tensor(self.index_tensor([], device=device).unsqueeze(1), 12.3, [], dtype=dtype, device=device)
550        self.assertEqual(1, a._values().numel())
551        self.assertEqual(a, a.clone())
552        a_coalesced = a.coalesce()
553        self.assertTrue(a_coalesced.is_coalesced())
554        self.assertEqual(torch.tensor(12.3, dtype=dtype, device=device), a.to_dense())
555        self.assertEqual(a, a.to_dense().to_sparse())
556
557        # tensor with multiple values
558        a = self.sparse_tensor(self.index_tensor([], device=device).unsqueeze(1).expand(0, 2),
559                               [12.3, 12.3], [], dtype=dtype, device=device)
560        self.assertEqual(2, a._values().numel())
561        self.assertEqual(a, a.clone())
562        a_coalesced = a.coalesce()
563        self.assertTrue(a_coalesced.is_coalesced())
564        self.assertEqual(torch.tensor(12.3 * 2, dtype=dtype, device=device), a.to_dense())
565        self.assertEqual(a.coalesce(), a.coalesce().to_dense().to_sparse())
566
567        # tensor without value
568        a = self.sparse_empty((), dtype=dtype, device=device)
569        self.assertEqual(0, a._values().numel())
570        self.assertEqual(a, a.clone())
571        a_coalesced = a.coalesce()
572        self.assertTrue(a_coalesced.is_coalesced())
573        self.assertEqual(torch.tensor(0, dtype=dtype, device=device), a.to_dense())
574        self.assertEqual(a, a.to_dense().to_sparse())
575
576    @dtypes(torch.double, torch.cdouble)
577    def test_shared(self, device, dtype):
578        i = self.index_tensor([[2]], device=device)
579        v = torch.tensor([5], dtype=dtype, device=device)
580        x = self.sparse_tensor(i, v, torch.Size([3]))
581        v[0] = 6
582        self.assertEqual(torch.tensor([0, 0, 6], dtype=dtype, device=device), self.safeToDense(x))
583        i[0][0] = 0
584        self.assertEqual(torch.tensor([6, 0, 0], dtype=dtype, device=device), self.safeToDense(x))
585
586        i = self.index_tensor([[2]], device=device)
587        v = torch.empty((1, 0), dtype=dtype, device=device)
588        x = self.sparse_tensor(i, v, torch.Size([3, 0]))
589        i[0][0] = 0
590        self.assertEqual(torch.empty((3, 0), dtype=dtype, device=device), self.safeToDense(x))
591
592    @dtypes(torch.double, torch.cdouble)
593    @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
594    @gradcheck_semantics()
595    def test_to_dense_hybrid(self, device, dtype, gradcheck):
596
597        def test_tensor(x, res):
598            x.to_dense()  # Tests double to_dense for memory corruption
599            x.to_dense()
600            x.to_dense()
601            self.assertEqual(res, x.to_dense())
602            self.assertEqual(res, self.safeToDense(x))
603
604            def fn(x):
605                return x.to_dense(masked_grad=gradcheck.masked)
606            x.requires_grad_(True)
607            gradcheck(fn, (x,))
608
609        i = self.index_tensor([
610            [0, 1, 2, 2],
611            [0, 0, 0, 3],
612        ], device=device)
613        v = torch.tensor([[2, 3], [1, 2], [3, 4], [4, 5]], dtype=dtype, device=device)
614        x = self.sparse_tensor(i, v, torch.Size([3, 4, 2]))
615        res = torch.tensor([
616            [[2, 3],
617             [0, 0],
618             [0, 0],
619             [0, 0]],
620            [[1, 2],
621             [0, 0],
622             [0, 0],
623             [0, 0]],
624            [[3, 4],
625             [0, 0],
626             [0, 0],
627             [4, 5]],
628        ], dtype=dtype, device=device)
629        test_tensor(x, res)
630
631        i = self.index_tensor([
632            [0, 1, 2, 2],
633            [0, 0, 0, 3],
634        ], device=device)
635        v = torch.empty((4, 2, 0), dtype=dtype, device=device)
636        x = self.sparse_tensor(i, v, torch.Size([3, 4, 2, 0]))
637        res = torch.empty((3, 4, 2, 0), dtype=dtype, device=device)
638        test_tensor(x, res)
639
640    @dtypes(torch.double, torch.cdouble)
641    def test_contig(self, device, dtype):
642        def test_tensor(x, exp_i, exp_v):
643            x = x.coalesce()
644            self.assertEqual(exp_i, x._indices())
645            self.assertEqual(exp_v, x._values())
646
647        i = self.index_tensor([
648            [1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
649            [92, 31, 62, 50, 22, 65, 89, 74, 56, 34],
650        ], device=device)
651        v = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype, device=device)
652        x = self.sparse_tensor(i, v, torch.Size([100, 100]))
653        exp_i = self.index_tensor([
654            [0, 1, 6, 14, 27, 35, 39, 40, 66, 71],
655            [31, 92, 65, 50, 34, 62, 22, 56, 74, 89],
656        ], device=device)
657        exp_v = torch.tensor([2, 1, 6, 4, 10, 3, 5, 9, 8, 7], dtype=dtype, device=device)
658        test_tensor(x, exp_i, exp_v)
659
660        i = self.index_tensor([
661            [2, 0, 2, 1],
662            [0, 0, 3, 0],
663            [1, 0, 4, 0],
664        ], device=device)
665        v = torch.tensor([3, 2, 4, 1], dtype=dtype, device=device)
666        x = self.sparse_tensor(i, v, torch.Size([3, 4, 5]))
667        exp_i = self.index_tensor([
668            [0, 1, 2, 2],
669            [0, 0, 0, 3],
670            [0, 0, 1, 4],
671        ], device=device)
672        exp_v = torch.tensor([2, 1, 3, 4], dtype=dtype, device=device)
673        test_tensor(x, exp_i, exp_v)
674
675        i = self.index_tensor([
676            [2, 0, 2, 1],
677            [0, 0, 3, 0],
678            [1, 0, 4, 0],
679        ], device=device)
680        v = torch.empty([4, 0], dtype=dtype, device=device)
681        x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 0]))
682        exp_i = self.index_tensor([
683            [0, 1, 2, 2],
684            [0, 0, 0, 3],
685            [0, 0, 1, 4],
686        ], device=device)
687        exp_v = torch.empty([4, 0], dtype=dtype, device=device)
688        test_tensor(x, exp_i, exp_v)
689
690        # Duplicate indices
691        i = self.index_tensor([
692            [0, 0, 2, 0],
693            [0, 0, 3, 0],
694            [0, 0, 4, 0],
695        ], device=device)
696        v = torch.tensor([3, 2, 4, 1], dtype=dtype, device=device)
697        x = self.sparse_tensor(i, v, torch.Size([3, 4, 5]))
698        exp_i = self.index_tensor([
699            [0, 2],
700            [0, 3],
701            [0, 4],
702        ], device=device)
703        exp_v = torch.tensor([6, 4], dtype=dtype, device=device)
704        test_tensor(x, exp_i, exp_v)
705
706        i = self.index_tensor([
707            [0, 0, 2, 0],
708            [0, 0, 3, 0],
709            [0, 0, 4, 0],
710        ], device=device)
711        v = torch.empty([4, 0], dtype=dtype, device=device)
712        x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 0]))
713        exp_i = self.index_tensor([
714            [0, 2],
715            [0, 3],
716            [0, 4],
717        ], device=device)
718        exp_v = torch.empty([2, 0], dtype=dtype, device=device)
719        test_tensor(x, exp_i, exp_v)
720
721    @dtypes(torch.double, torch.cdouble)
722    def test_contig_hybrid(self, device, dtype):
723        def test_tensor(x, exp_i, exp_v):
724            x = x.coalesce()
725            self.assertEqual(exp_i, x._indices())
726            self.assertEqual(exp_v, x._values())
727
728        i = self.index_tensor([
729            [1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
730            [92, 31, 62, 50, 22, 65, 89, 74, 56, 34],
731        ], device=device)
732        v = torch.tensor([
733            [1, 2], [2, 3], [3, 4], [4, 5], [5, 6],
734            [6, 7], [7, 8], [8, 9], [9, 10], [10, 11],
735        ], dtype=dtype, device=device)
736        x = self.sparse_tensor(i, v, torch.Size([100, 100, 2]))
737        exp_i = self.index_tensor([
738            [0, 1, 6, 14, 27, 35, 39, 40, 66, 71],
739            [31, 92, 65, 50, 34, 62, 22, 56, 74, 89],
740        ], device=device)
741        exp_v = torch.tensor([
742            [2, 3], [1, 2], [6, 7], [4, 5], [10, 11],
743            [3, 4], [5, 6], [9, 10], [8, 9], [7, 8],
744        ], dtype=dtype, device=device)
745        test_tensor(x, exp_i, exp_v)
746
747        i = self.index_tensor([
748            [2, 0, 2, 1],
749            [0, 0, 3, 0],
750            [1, 0, 4, 0],
751        ], device=device)
752        v = torch.tensor([[3, 3, 3], [2, 2, 2], [4, 4, 4], [1, 1, 1]], dtype=dtype, device=device)
753        x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 3]))
754        exp_i = self.index_tensor([
755            [0, 1, 2, 2],
756            [0, 0, 0, 3],
757            [0, 0, 1, 4],
758        ], device=device)
759        exp_v = torch.tensor([[2, 2, 2], [1, 1, 1], [3, 3, 3], [4, 4, 4]], dtype=dtype, device=device)
760        test_tensor(x, exp_i, exp_v)
761
762        i = self.index_tensor([
763            [2, 0, 2, 1],
764            [0, 0, 3, 0],
765            [1, 0, 4, 0],
766        ], device=device)
767        v = torch.empty([4, 3, 0], dtype=dtype, device=device)
768        x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 3, 0]))
769        exp_i = self.index_tensor([
770            [0, 1, 2, 2],
771            [0, 0, 0, 3],
772            [0, 0, 1, 4],
773        ], device=device)
774        exp_v = torch.empty([4, 3, 0], dtype=dtype, device=device)
775        test_tensor(x, exp_i, exp_v)
776
777        # Duplicate indices
778        i = self.index_tensor([
779            [0, 0, 2, 0],
780            [0, 0, 3, 0],
781            [0, 0, 4, 0],
782        ], device=device)
783        v = torch.tensor([[3, 2, 3], [2, 1, 1], [4, 3, 4], [1, 1, 1]], dtype=dtype, device=device)
784        x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 3]))
785        exp_i = self.index_tensor([
786            [0, 2],
787            [0, 3],
788            [0, 4],
789        ], device=device)
790        exp_v = torch.tensor([[6, 4, 5], [4, 3, 4]], dtype=dtype, device=device)
791        test_tensor(x, exp_i, exp_v)
792
793        i = self.index_tensor([
794            [0, 0, 2, 0],
795            [0, 0, 3, 0],
796            [0, 0, 4, 0],
797        ], device=device)
798        v = torch.empty([4, 3, 0], dtype=dtype, device=device)
799        x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 3, 0]))
800        exp_i = self.index_tensor([
801            [0, 2],
802            [0, 3],
803            [0, 4],
804        ], device=device)
805        exp_v = torch.empty([2, 3, 0], dtype=dtype, device=device)
806        test_tensor(x, exp_i, exp_v)
807
808    @coalescedonoff
809    @dtypes(torch.double, torch.cdouble)
810    def test_clone(self, device, dtype, coalesced):
811        def test_shape(sparse_dims, nnz, with_size):
812            x = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
813            if not coalesced:
814                self.assertFalse(x.is_coalesced())
815                y = x.clone()
816                self.assertFalse(y.is_coalesced())
817            x = x.coalesce()
818            self.assertTrue(x.is_coalesced())
819            y = x.clone()
820            self.assertTrue(y.is_coalesced())
821
822        test_shape(4, 20, 5)
823        test_shape(3, 10, [100, 100, 100, 5, 5, 5, 0])
824        test_shape(3, 0, [0, 0, 100, 5, 5, 5, 0])
825
826    @coalescedonoff
827    @dtypes(torch.double, torch.cdouble, torch.bfloat16)
828    @precisionOverride({torch.bfloat16: 2e-2})
829    def test_Sparse_to_Sparse_copy_(self, device, dtype, coalesced):
830        # This is for testing torch.copy_(SparseTensor, SparseTensor)
831        sparse_dims = 3
832        nnz = 10
833        sizes = [2, 3, 4, 5]  # hybrid sparse
834        x1, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
835        x2, _, _ = self._gen_sparse(sparse_dims, nnz + 10, sizes, dtype, device, coalesced)
836
837        # test copy
838        x2_dense = x2.to_dense()
839        x1.copy_(x2)
840        self.assertEqual(x2_dense, x1.to_dense())
841
842        # test type conversion (when x1.copy_(x2), x1.dtype should stay the same)
843        x1 = x1.to(torch.float32)
844
845        x2 = x2.to(torch.float16)
846        x1_dtype = x1.dtype
847        x1.copy_(x2)
848        self.assertEqual(x1_dtype, x1.dtype)
849
850        x2 = x2.to(torch.float64)
851        x1_dtype = x1.dtype
852        x1.copy_(x2)
853        self.assertEqual(x1_dtype, x1.dtype)
854
855        # test no broadcast
856        self.assertRaises(RuntimeError, lambda: x1.copy_(x2.narrow_copy(0, 0, 1)))
857
858        # test raise error on copy_() between dense and sparse Tensors
859        self.assertRaises(RuntimeError, lambda: x1.copy_(torch.randn(5, 5)))
860
861        # test autograd
862        x1, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
863        x2, _, _ = self._gen_sparse(sparse_dims, nnz + 10, sizes, dtype, device, coalesced)
864        x2.requires_grad_(True)
865        x1.copy_(x2)
866        y = x1 * 2
867        x2_clone = x2.clone()
868        y.backward(x2_clone)
869        expected_grad = x2_clone * 2
870        self.assertEqual(expected_grad.to_dense(), x2.grad.to_dense())
871        self.assertEqual(None, x1.grad)
872
873    @coalescedonoff
874    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
875    @dtypes(torch.double, torch.cdouble)
876    def test_Sparse_to_Sparse_copy_multi_gpu(self, device, dtype, coalesced):
877        # This is for testing torch.copy_(SparseTensor, SparseTensor) across GPU devices
878        sparse_dims = 3
879        nnz = 10
880        sizes = [2, 3, 4, 5]  # hybrid sparse
881        x1, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
882        x2, _, _ = self._gen_sparse(sparse_dims, nnz + 10, sizes, dtype, device, coalesced)
883        x1 = x1.to('cuda:0')
884
885        def test_cross_device(x1, x2):
886            x1_device = x1.device
887            x1.copy_(x2)
888            self.assertEqual(x2.to('cuda:0').to_dense(), x1.to_dense())
889            self.assertEqual(x1_device, x1.device)
890
891        test_cross_device(x1, x2.to('cuda:1'))  # test across gpu devices
892        test_cross_device(x1, x2.to('cpu'))  # test between cpu and gpu
893
894        # test autograd
895        x2 = x2.to('cuda:1')
896        x2.requires_grad_(True)
897        x1.copy_(x2)
898        y = x1 * 2
899        x2_clone = x2.clone().to('cuda:0')
900        y.backward(x2_clone)
901        expected_grad = x2_clone * 2
902        self.assertEqual(expected_grad.to_dense(), x2.grad.to('cuda:0').to_dense())
903        self.assertEqual(None, x1.grad)
904
905    @onlyCUDA
906    def test_cuda_empty(self, device):
907        def test_tensor(x):
908            y = x.to(device)
909            self.assertEqual(x.sparse_dim(), y.sparse_dim())
910            self.assertEqual(x.dense_dim(), y.dense_dim())
911            x = y.cpu()
912            self.assertEqual(y.sparse_dim(), x.sparse_dim())
913            self.assertEqual(y.dense_dim(), x.dense_dim())
914
915        x = torch.sparse_coo_tensor((2, 3, 4), dtype=torch.float32)
916        test_tensor(x)
917
918        x = torch.sparse_coo_tensor((2, 3, 4), dtype=torch.float16)
919        test_tensor(x)
920
921        x = torch.sparse_coo_tensor((2, 3, 4), dtype=torch.float16)
922        test_tensor(x)
923
924        x = torch.sparse_coo_tensor((2, 3, 4, 0), dtype=torch.float32)
925        test_tensor(x)
926
927    @coalescedonoff
928    @dtypes(torch.double, torch.cdouble)
929    def test_transpose(self, device, dtype, coalesced):
930        def test_shape(sparse_dims, nnz, with_size):
931            x = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
932            y = self.safeToDense(x)
933
934            for i, j in itertools.combinations(range(4), 2):
935                x = x.transpose_(i, j)
936                y = y.transpose(i, j)
937                self.assertEqual(self.safeToDense(x), y)
938
939                x = x.transpose(i, j)
940                y = y.transpose(i, j)
941                self.assertEqual(self.safeToDense(x), y)
942
943        test_shape(4, 6, 3)
944        test_shape(4, 3, [7, 7, 7, 3, 3, 3, 0])
945        test_shape(4, 0, [0, 0, 7, 3, 3, 3, 0])
946
947    @coalescedonoff
948    @dtypes(torch.double, torch.cdouble)
949    @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
950    @gradcheck_semantics()
951    def test_permute(self, device, dtype, coalesced, gradcheck):
952        # trivial checks
953        s = torch.rand(3, 3, 3, device=device, dtype=dtype).to_sparse()
954        with self.assertRaisesRegex(RuntimeError, "does not match the length"):
955            s.permute(dims=(1, 0))
956        with self.assertRaisesRegex(RuntimeError, "duplicate dims"):
957            s.permute(dims=(1, 1, 1))
958        # Calling permute on a sparse tensor with an empty tuple used to segfault,
959        # see https://github.com/pytorch/pytorch/issues/116325
960        x = torch.rand((), device=device, dtype=dtype).to_sparse()
961        x.permute(())
962        self.assertEqual(len(x.values()), 1)
963
964        def test_shape(sparse_dims, nnz, with_size):
965            ndim = len(with_size)
966            valid_sparse_dims = torch.arange(-ndim, -ndim + sparse_dims)
967            valid_dense_dims = torch.arange(-ndim + sparse_dims, 0)
968
969            for dims in itertools.permutations(range(-ndim, 0)):
970                s = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
971                d = self.safeToDense(s)
972
973                dims_sparse, _ = torch.tensor(dims[:sparse_dims]).sort()
974                dims_dense, _ = torch.tensor(dims[sparse_dims:]).sort()
975
976                if (valid_sparse_dims == dims_sparse).all() and (valid_dense_dims == dims_dense).all():
977                    # if valid permutation, test for correctness
978                    s_permuted = s.permute(dims)
979                    self.assertEqual(s_permuted, d.permute(dims))
980
981                    # if s is coalesced, and perm does not touch 0-dim,
982                    # the result has to be coalesced as well
983                    if dims[0] == 0:
984                        self.assertEqual(s_permuted.is_coalesced(), s.is_coalesced())
985                    else:
986                        self.assertFalse(s_permuted.is_coalesced())
987
988                    gradcheck(lambda t: t.permute(dims).to_dense(masked_grad=gradcheck.masked), s.requires_grad_())
989                else:
990                    # otherwise check if exception is thrown
991                    fail_message = "transpositions between sparse and dense dimensions are not allowed"
992                    with self.assertRaisesRegex(RuntimeError, fail_message):
993                        s.permute(dims)
994
995        test_shape(2, 3, [2, 3, 4, 5])
996        test_shape(2, 3, [2, 2, 0])
997        # if nnz=0, it is not true that t == t.to_dense().to_sparse()
998        # unless t.sparse_dim == t.dim (i.e. t is not hybrid)
999        test_shape(3, 0, [0, 0, 2])
1000
1001    @coalescedonoff
1002    @onlyCPU
1003    @dtypes(torch.double)
1004    def test_coalesce_transpose_mm(self, device, dtype, coalesced):
1005        def test_shape(di, dj, dk, nnz):
1006            x, _, _ = self._gen_sparse(2, nnz, [dj, di], dtype, device, coalesced)
1007            y = torch.randn(dj, dk, dtype=dtype, device=device)
1008
1009            x_coalesced = x.coalesce()
1010            self.assertTrue(x_coalesced.is_coalesced())
1011
1012            x_coalesced_t = x_coalesced.t()
1013            # Transpose is `colasced`-preserving if the indices tensor is empty.
1014            self.assertEqual(x_coalesced_t.is_coalesced(), di * nnz == 0)
1015
1016            res = torch.mm(x_coalesced_t, y)
1017            expected = torch.mm(self.safeToDense(x_coalesced_t), y)
1018            self.assertEqual(res, expected)
1019
1020        test_shape(10, 20, 30, 20)
1021        test_shape(0, 20, 30, 0)
1022        test_shape(10, 0, 30, 0)
1023        test_shape(10, 20, 0, 0)
1024        test_shape(10, 20, 0, 20)
1025
1026    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1166")
1027    @dtypes(torch.double, torch.cdouble)
1028    def test_t_empty(self, device, dtype):
1029        def test_in_place(x):
1030            shape_original = x.shape
1031            x.t_()
1032            self.assertEqual(torch.Size([shape_original[1], shape_original[0]]), x.size())
1033            self.assertEqual(0, x._indices().numel())
1034            self.assertEqual(0, x._values().numel())
1035            self.assertEqual(x.sparse_dim(), 2)
1036            self.assertEqual(x.dense_dim(), 0)
1037
1038        def test_not_in_place(x):
1039            shape_original = x.shape
1040            y = x.t()
1041            self.assertEqual(torch.Size([shape_original[1], shape_original[0]]), y.size())
1042            self.assertEqual(0, y._indices().numel())
1043            self.assertEqual(0, y._values().numel())
1044            self.assertEqual(x.sparse_dim(), 2)
1045            self.assertEqual(x.dense_dim(), 0)
1046
1047        x = self.sparse_empty(2, 3, dtype=dtype, device=device)
1048        test_in_place(x)
1049        test_not_in_place(x)
1050
1051        x = self.sparse_empty(2, 0, dtype=dtype, device=device)
1052        test_in_place(x)
1053        test_not_in_place(x)
1054
1055    @coalescedonoff
1056    @dtypes(torch.double, torch.cdouble)
1057    def test_add_zeros(self, device, dtype, coalesced):
1058        def test_shape(sparse_dims, nnz, sizes):
1059            x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
1060            zeros = torch.sparse_coo_tensor(sizes, device=x.device)
1061            r1 = zeros + x
1062            r2 = x + zeros
1063            self.assertEqual(r1, x)
1064            self.assertEqual(r2, x)
1065
1066        test_shape(1, 20, [1])
1067        test_shape(4, 20, [3, 17, 19, 5])
1068        test_shape(2, 20, [3, 17, 19, 5])
1069        test_shape(2, 20, [3, 17, 19, 0])
1070
1071    @dtypes(torch.double, torch.cdouble)
1072    def test_add_sub_nnz(self, device, dtype):
1073        # nnz should not grow unbounded (gh-34964)
1074        x = torch.randn(10, dtype=dtype, device=device).to_sparse()
1075        x.add_(x)
1076        x.add_(x)
1077        self.assertLessEqual(x._nnz(), 10)
1078
1079        x.sub_(2 * x)
1080        x.sub_(2 * x)
1081        self.assertLessEqual(x._nnz(), 10)
1082
1083    @coalescedonoff
1084    @dtypes(torch.double, torch.cdouble)
1085    def test_cat(self, device, dtype, coalesced):
1086        # shapes: list of tuples (sparse_dims, nnz, sizes)
1087        def test_shapes(shapes, dim, fail_message=None):
1088            inputs = [self._gen_sparse(shape[0], shape[1], shape[2], dtype, device, coalesced)[0]
1089                      for shape in shapes]
1090            if fail_message:
1091                with self.assertRaisesRegex(RuntimeError, fail_message):
1092                    torch.cat(inputs, dim)
1093            else:
1094                result = torch.cat(inputs, dim)
1095                dense_result = torch.cat([t.to_dense() for t in inputs], dim)
1096                self.assertEqual(dense_result, result.to_dense())
1097
1098        test_shapes(
1099            [(3, 10, [2, 3, 4]), (3, 10, [2, 1, 4]), (3, 10, [2, 4, 4])], 1)
1100
1101        # mismatched sizes
1102        test_shapes([(3, 10, [2, 3, 4]), (3, 10, [2, 1, 4])], 0,
1103                    "All tensors must have the same shape: \\[2, 3, 4].*\\[2, 1, 4]")
1104        # hybrid sparse/dense
1105        test_shapes(
1106            [(2, 10, [2, 3, 4]), (2, 10, [2, 1, 4]), (2, 10, [2, 4, 4])], 1)
1107        # cat along dense dim
1108        test_shapes([(2, 10, [2, 3, 4]), (2, 10, [2, 3, 7])], 2)
1109        test_shapes([(1, 10, [2, 3, 4]), (1, 10, [2, 3, 4])], 1)
1110        test_shapes([(1, 10, [2, 3, 4]), (1, 10, [2, 3, 4])], 2)
1111        # mismatched dimensions
1112        test_shapes([(2, 10, [2, 3, 4]), (3, 10, [2, 3, 4])], 0,
1113                    "All tensors must have the same.*2, 1, but tensor at position 1 has 3, 0.")
1114        # wrapped dimension
1115        test_shapes(
1116            [(3, 10, [2, 3, 4]), (3, 10, [2, 1, 4]), (3, 10, [2, 4, 4])], -2)
1117
1118        # sparse with dense
1119        sp = self._gen_sparse(3, 10, [2, 3, 4], dtype, device, coalesced)[0]
1120        dn = sp.to_dense()
1121        with self.assertRaisesRegex(RuntimeError,
1122                                    "Concatenating sparse tensors, but a dense tensor was found at position 1."):
1123            torch.cat((sp, dn))
1124
1125    @coalescedonoff
1126    @dtypes(torch.double, torch.cdouble)
1127    def test_unsqueeze(self, device, dtype, coalesced):
1128        def test_shape(sparse_dims, nnz, sizes, unsqueeze_dim, fail_message=None):
1129            x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
1130            if fail_message:
1131                with self.assertRaisesRegex(IndexError, fail_message):
1132                    torch.unsqueeze(x, unsqueeze_dim)
1133            else:
1134                result = torch.unsqueeze(x, unsqueeze_dim)
1135                dense_result = torch.unsqueeze(x.to_dense(), unsqueeze_dim)
1136                self.assertEqual(dense_result, result.to_dense())
1137
1138        # basic case
1139        test_shape(3, 10, [5, 7, 11], 0)
1140
1141        # hybrid sparse/dense, unsqueeze along sparse dim
1142        test_shape(3, 10, [5, 7, 11, 13, 17], 0)
1143        test_shape(3, 10, [5, 7, 11, 13, 17], 3)
1144
1145        # unsqueeze along dense dimensions
1146        test_shape(3, 10, [5, 7, 11, 13, 17], 4)
1147        test_shape(3, 10, [5, 7, 11, 13, 17], 5)
1148
1149        # wrapped dimensions
1150        test_shape(3, 10, [5, 7, 11, 13, 17], -1)
1151        test_shape(3, 10, [5, 7, 11, 13, 17], -6)
1152
1153        # bounds
1154        test_shape(3, 10, [5, 7, 11, 13, 17], -7, "Dimension out of range")
1155        test_shape(3, 10, [5, 7, 11, 13, 17], 6, "Dimension out of range")
1156
1157    @coalescedonoff
1158    @dtypes(torch.double, torch.cdouble)
1159    def test_select(self, device, dtype, coalesced):
1160        def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=None):
1161            x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
1162            if fail_message:
1163                with self.assertRaisesRegex(IndexError, fail_message):
1164                    torch.select(x, select_dim, select_index)
1165            else:
1166                result = torch.select(x, select_dim, select_index)
1167                if result.is_sparse:
1168                    result = result.to_dense()
1169                dense_result = torch.select(x.to_dense(), select_dim, select_index)
1170                self.assertEqual(dense_result, result)
1171
1172
1173        sizes = [5, 7, 11, 13, 17]
1174        # hybrid sparse/dense, select sparse dim, result is dense
1175        for i in range(sizes[0]):
1176            test_shape(1, 10, sizes, 0, i)
1177        test_shape(1, 10, sizes, 0, sizes[0] + 1, r'select[(][)][:] index \d out of range.*')
1178
1179        # hybrid sparse/dense, select sparse dim, result is sparse
1180        for d in range(3):
1181            for i in range(sizes[d]):
1182                test_shape(3, 10, sizes, d, i)
1183
1184        # hybrid sparse/dense, select dense dim, result is sparse
1185        for d in range(1, 3):
1186            for i in range(sizes[d]):
1187                test_shape(1, 10, sizes, d, i)
1188
1189    @dtypes(*integral_types())
1190    def test_select_no_type_promotion(self, device, dtype):
1191        # see https://github.com/pytorch/pytorch/issues/82150
1192        idx = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 1]])
1193        val = torch.ones(6, dtype=dtype)
1194        s = torch.sparse_coo_tensor(idx, val, size=(3, 3))
1195
1196        for t in (s, s * torch.tensor(0, dtype=dtype)):
1197            # empty checks
1198            self.assertEqual(t.dtype, t[2].dtype)
1199            self.assertEqual(t.dtype, t[0, 1].dtype)
1200            # sum should not promote
1201            self.assertEqual(t.dtype, t[0, 0].dtype)
1202            self.assertEqual(t.dtype, t[1, 1].dtype)
1203
1204    @coalescedonoff
1205    @dtypes(torch.double, torch.cdouble)
1206    def test_index_select(self, device, dtype, coalesced):
1207        def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=None):
1208            if isinstance(select_index, int):
1209                select_index = [select_index]
1210            if isinstance(select_index, list):
1211                select_index = torch.tensor(select_index, device=device, dtype=torch.long)
1212            x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
1213            if fail_message:
1214                with self.assertRaisesRegex(IndexError, fail_message):
1215                    torch.index_select(x, select_dim, select_index)
1216            else:
1217                result = torch.index_select(x, select_dim, select_index)
1218                if result.is_sparse:
1219                    result = result.to_dense()
1220                dense_result = torch.index_select(x.to_dense(), select_dim, select_index)
1221                self.assertEqual(dense_result, result)
1222
1223        sizes = [5, 7, 11, 13, 17]
1224        for d in range(len(sizes)):
1225            for index in [0, sizes[d] - 1, [0, sizes[d] // 2, sizes[d] - 1]]:
1226                test_shape(1, 10, sizes, d, index)
1227                test_shape(len(sizes) // 2, 10, sizes, d, index)
1228                test_shape(len(sizes), 10, sizes, d, index)
1229
1230    def _test_index_select_exhaustive_index(self, sizes, dims, device, dtype, coalesced):
1231        t = make_tensor(sizes, dtype=dtype, device=device)
1232        t_sparse = t.to_sparse().coalesce() if coalesced else t.to_sparse()
1233        t_small_sparse, _, _ = self._gen_sparse(len(sizes), 2, sizes, dtype, device, coalesced)
1234        t_small = t_small_sparse.to_dense()
1235        for d in dims:
1236            # NOTE: indices are negative
1237            idx_dim_d_range = list(range(-sizes[d], 0))
1238            for idx_len in range(sizes[d], sizes[d] + 1):
1239                # creates all possible valid indices into dim d of lenght idx_len
1240                for idx in itertools.product(*itertools.repeat(idx_dim_d_range, idx_len)):
1241                    t_idx = torch.tensor(idx, dtype=torch.long, device=device)
1242
1243                    # NOTE: index_select for dense does not support negative indices,
1244                    # hence + sizes[d]. See https://github.com/pytorch/pytorch/issues/76347
1245
1246                    # tests the nnz > sizes[d] branch
1247                    dense_result = t.index_select(d, t_idx + sizes[d])
1248                    sparse_result = t_sparse.index_select(d, t_idx)
1249                    self.assertEqual(dense_result, sparse_result)
1250
1251                    # tests the nnz <= sizes[d] branch
1252                    small_dense_result = t_small.index_select(d, t_idx + sizes[d])
1253                    small_sparse_result = t_small_sparse.index_select(d, t_idx)
1254                    self.assertEqual(small_dense_result, small_sparse_result)
1255
1256    @coalescedonoff
1257    @dtypes(torch.double, torch.cdouble)
1258    def test_index_select_exhaustive_index_small(self, device, dtype, coalesced):
1259        # will trigger brute-force algo
1260        self._test_index_select_exhaustive_index((3, 3, 4), range(3), device, dtype, coalesced)
1261
1262    @coalescedonoff
1263    @dtypes(torch.double, torch.cdouble)
1264    def test_index_select_exhaustive_index_large(self, device, dtype, coalesced):
1265        # will trigger more sophisticated algos
1266        self._test_index_select_exhaustive_index((100, 50, 3, 3), (2, 3), device, dtype, coalesced)
1267
1268    @coalescedonoff
1269    @dtypes(torch.double, torch.cdouble)
1270    def test_index_select_empty_and_non_contiguous_index(self, device, dtype, coalesced):
1271        # empty index
1272        idx_empty = torch.tensor([], dtype=torch.long, device=device)
1273        t = make_tensor((5, 5), dtype=dtype, device=device)
1274        res_dense = t.index_select(0, idx_empty)
1275        res_sparse = t.to_sparse().index_select(0, idx_empty)
1276        self.assertEqual(res_dense, res_sparse)
1277
1278        # non-contiguous index
1279        idx = torch.randint(low=0, high=5, size=(10, 2), device=device)[:, 0]
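        # Selecting column 0 of a (10, 2) tensor yields a view with stride 2,
        # i.e. a non-contiguous index tensor.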
1280
1281        def run_test(sizes):
1282            # case nnz > size[d]
1283            t = make_tensor(sizes, dtype=dtype, device=device)
1284            res_dense = t.index_select(0, idx)
1285            res_sparse = t.to_sparse().index_select(0, idx)
1286            self.assertEqual(res_dense, res_sparse)
1287
1288            # case nnz <= size[d]
1289            t_small_sparse, _, _ = self._gen_sparse(len(sizes), 2, sizes, dtype, device, coalesced)
1290            res_sparse = t_small_sparse.index_select(0, idx)
1291            res_dense = t_small_sparse.to_dense().index_select(0, idx)
1292            self.assertEqual(res_dense, res_sparse)
1293
1294        # brute-force
1295        run_test((10, 10))
1296        # more sophisticated algos
1297        run_test((10, 100, 100))
1298
1299    @onlyCPU
1300    @coalescedonoff
1301    @dtypes(torch.double, torch.cdouble)
1302    def test_index_select_parallelization(self, device, dtype, coalesced):
1303        """
1304        Test with sizes that will trigger parallelization (i.e. with sizes
1305        that are >= at::internal::GRAIN_SIZE)
1306        """
1307        def run_test(nnz, size):
1308            t_sparse, _, _ = self._gen_sparse(1, nnz, (size,), dtype, device, coalesced)
1309            t_dense = t_sparse.to_dense()
1310
1311            # idx_small to (sort) and (binary) search into t_sparse
1312            idx_small = torch.randint(size, (nnz // 2,), device=device)
1313            # idx_large to (sort) and serve as the (binary) search target
1314            # NOTE: when coalesced=True, the (binary) search will be
1315            # done over t_sparse anyway, as it is already sorted.
1316            idx_large = torch.randint(size, (nnz * 2,), device=device)
1317            for idx in (idx_small, idx_large):
1318                res_dense = t_dense.index_select(0, idx)
1319                res_sparse = t_sparse.index_select(0, idx)
1320                self.assertEqual(res_dense, res_sparse)
1321
1322        # NOTE: GRAIN_SIZE = 32768
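        # With tlen = 70000 the nnz, the selected dimension, and the index
        # lengths all exceed GRAIN_SIZE (assuming it is still 32768), so the
        # parallel path is exercised in both cases below.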
1323        # case nnz <= size[d]
1324        tlen = 70000  # > 2 * GRAIN_SIZE
1325        run_test(tlen, tlen)
1326
1327        # case nnz > size[d]
1328        run_test(tlen, tlen // 2)
1329
1330    @onlyCPU
1331    @coalescedonoff
1332    @dtypes(torch.double, torch.cdouble)
1333    def test_mm(self, device, dtype, coalesced):
1334        def test_shape(di, dj, dk, nnz):
1335            x, _, _ = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced)
1336            t = torch.randn(di, dk, dtype=dtype, device=device)
1337            y = torch.randn(dj, dk, dtype=dtype, device=device)
1338            alpha = random.random()
1339            beta = random.random()
1340
1341            res = torch.addmm(t, x, y, beta=beta, alpha=alpha)
1342            expected = torch.addmm(t, self.safeToDense(x), y, beta=beta, alpha=alpha)
1343            self.assertEqual(res, expected)
1344
1345            res = torch.addmm(t, x, y)
1346            expected = torch.addmm(t, self.safeToDense(x), y)
1347            self.assertEqual(res, expected)
1348
1349            res = torch.mm(x, y)
1350            expected = torch.mm(self.safeToDense(x), y)
1351            self.assertEqual(res, expected)
1352
1353        test_shape(10, 100, 100, 20)
1354        test_shape(100, 1000, 200, 20)
1355        test_shape(64, 10000, 300, 20)
1356        test_shape(0, 100, 100, 0)
1357        test_shape(10, 0, 100, 0)
1358        test_shape(10, 100, 0, 0)
1359        test_shape(10, 100, 0, 20)
1360
1361    @unittest.skipIf(
1362        IS_WINDOWS and TEST_CUDA,
1363        "bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1"
1364    )
1365    @coalescedonoff
1366    @dtypes(torch.double)
1367    def test_bmm(self, device, dtype, coalesced):
1368        def test_shape(num_mats, dim_i, dim_j, dim_k, nnz):
1369            a_list = []
1370            b_list = []
1371            for mat_idx in range(num_mats):
1372                a_mat = self._gen_sparse(2, nnz, [dim_i, dim_j], dtype, device, coalesced)[0]
1373                b_mat = torch.randn([dim_j, dim_k], dtype=dtype, device=device)
1374                a_list.append(a_mat)
1375                b_list.append(b_mat)
1376
1377            a = torch.stack(a_list)
1378            b = torch.stack(b_list)
1379            ab = a.bmm(b)
1380
1381            # Compare each matrix against result from mm()
1382            for mat_idx in range(num_mats):
1383                a_mat = a_list[mat_idx]
1384                b_mat = b_list[mat_idx]
1385                ab_mat_bmm = ab[mat_idx]
1386                ab_mat_mm = a_mat.mm(b_mat)
1387                self.assertEqual(ab_mat_bmm, ab_mat_mm)
1388
1389        test_shape(10, 10, 100, 99, 20)
1390        test_shape(10, 100, 1000, 200, 20)
1391        test_shape(10, 64, 10000, 300, 20)
1392        test_shape(10, 0, 100, 99, 0)
1393        test_shape(10, 10, 0, 100, 0)
1394        test_shape(10, 10, 100, 0, 0)
1395        test_shape(10, 10, 100, 0, 20)
1396        test_shape(10, 10, 100, 0, 20)
1397
1398        a = torch.rand([10, 23, 32], dtype=dtype, device=device)
1399        a[3] = torch.zeros(23, 32, dtype=dtype, device=device)
1400        a[6] = torch.zeros(23, 32, dtype=dtype, device=device)
1401        a = a.to_sparse()
1402        b = torch.rand([10, 32, 10], dtype=dtype, device=device)
1403        b[4] = torch.zeros(32, 10, dtype=dtype, device=device)
1404        b[6] = torch.zeros(32, 10, dtype=dtype, device=device)
1405        ab = a.bmm(b)
1406        for mat_idx in range(ab.size(0)):
1407            ab_mat = ab[mat_idx]
1408            ab_mat_check = a[mat_idx].mm(b[mat_idx])
1409            self.assertEqual(ab_mat, ab_mat_check)
1410
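        # Cross-check via the identity (A @ B) == (B^T @ A^T)^T, which swaps the
        # sparse and dense roles of the two operands.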
1411        ab_transpose_check = b.transpose(1, 2).to_sparse().bmm(
1412            a.transpose(1, 2).to_dense()
1413        ).transpose(1, 2)
1414        self.assertEqual(ab, ab_transpose_check)
1415
1416    @onlyCUDA
1417    @coalescedonoff
1418    @dtypes(torch.double)
1419    @unittest.skipIf(
1420        IS_WINDOWS,
1421        "bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1"
1422    )
1423    def test_bmm_deterministic(self, device, dtype, coalesced):
1424        def test_shape(num_mats, dim_i, dim_j, dim_k, nnz):
1425            a_list = []
1426            b_list = []
1427            for mat_idx in range(num_mats):
1428                a_list.append(self._gen_sparse(2, nnz, [dim_i, dim_j], dtype, device, coalesced)[0])
1429                b_list.append(torch.randn([dim_j, dim_k], dtype=dtype, device=device))
1430
1431            a = torch.stack(a_list).cuda()
1432            b = torch.stack(b_list).cuda()
1433            with DeterministicGuard(torch.are_deterministic_algorithms_enabled()):
1434                torch.use_deterministic_algorithms(False)
1435                ab_nondeterministic = torch.bmm(a, b)
1436                torch.use_deterministic_algorithms(True)
1437                ab_deterministic = torch.bmm(a, b)
1438            diff_abs = (ab_deterministic - ab_nondeterministic).abs()
1439            diff_rel = diff_abs / ab_deterministic.abs()
1440            diff_rel[torch.isnan(diff_rel)] = 0
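            # NaNs arise from 0/0 entries where both results are exactly zero;
            # treat those entries as equal rather than as failures.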
1441
1442            # deterministic and non-deterministic results should either be
1443            # equal or within a small relative difference
1444            equal_abs_or_rel = diff_abs.eq(0).logical_or(diff_rel.lt(0.001))
1445            self.assertTrue(equal_abs_or_rel.all())
1446
1447        test_shape(10, 10, 100, 99, 20)
1448        test_shape(10, 100, 1000, 200, 20)
1449        test_shape(10, 64, 10000, 300, 20)
1450        test_shape(10, 0, 100, 99, 0)
1451        test_shape(10, 10, 0, 100, 0)
1452        test_shape(10, 10, 100, 0, 0)
1453        test_shape(10, 10, 100, 0, 20)
1454        test_shape(10, 10, 100, 0, 20)
1455
1456    @onlyCUDA
1457    @unittest.skipIf(
1458        not IS_WINDOWS or not TEST_WITH_ROCM,
1459        "this test ensures bmm sparse-dense CUDA gives an error when run on Windows with CUDA < 11.0"
1460    )
1461    @dtypes(torch.double)
1462    def test_bmm_windows_error(self, device, dtype):
1463        a = torch.rand(2, 2, 2, dtype=dtype).to_sparse().cuda()
1464        b = torch.rand(2, 2, 2, dtype=dtype).cuda()
1465        with self.assertRaisesRegex(
1466                RuntimeError,
1467                "bmm sparse-dense CUDA is not supported on Windows with cuda before 11.0"):
1468            ab = a.bmm(b)
1469
1470    @onlyCPU
1471    @coalescedonoff
1472    @dtypes(torch.double, torch.cdouble)
1473    def test_saddmm(self, device, dtype, coalesced):
1474        def test_shape(di, dj, dk, nnz):
1475            x = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced)[0]
1476            t = self._gen_sparse(2, nnz, [di, dk], dtype, device, coalesced)[0]
1477            y = torch.randn(dj, dk, dtype=dtype, device=device)
1478            alpha = random.random()
1479            beta = random.random()
1480
1481            res = torch.saddmm(t, x, y, beta=beta, alpha=alpha)
1482            expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y, beta=beta, alpha=alpha)
1483            self.assertEqual(self.safeToDense(res), expected)
1484
1485            res = torch.saddmm(t, x, y)
1486            expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y)
1487            self.assertEqual(self.safeToDense(res), expected)
1488
1489            res = torch.smm(x, y)
1490            expected = torch.mm(self.safeToDense(x), y)
1491            self.assertEqual(self.safeToDense(res), expected)
1492
1493        test_shape(7, 5, 3, 20)
1494        test_shape(1000, 100, 100, 20)
1495        test_shape(3000, 64, 300, 20)
1496        test_shape(0, 100, 100, 0)
1497        test_shape(1000, 0, 100, 0)
1498        test_shape(1000, 100, 0, 0)
1499
1500    @onlyCPU
1501    @coalescedonoff
1502    # adding a graph break before self.assertFalse(weight._indices().is_contiguous())
1503    # makes the test pass, which points to an existing sparse-related bug
1504    @skipIfTorchDynamo("skip")
1505    @dtypes(torch.double, torch.cdouble)
1506    def test_sspaddmm(self, device, dtype, coalesced):
1507
1508        def test_shape(di, dj, dk, nnz):
1509            x = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced)[0]
1510            t = self._gen_sparse(2, nnz, [di, dk], dtype, device, coalesced)[0]
1511            y = torch.randn(dj, dk, dtype=dtype, device=device)
1512            alpha = random.random()
1513            beta = random.random()
1514
1515            res = t.sspaddmm(x, y, beta=beta, alpha=alpha)
1516            expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y, beta=beta, alpha=alpha)
1517            self.assertEqual(self.safeToDense(res), expected)
1518
1519            res = t.sspaddmm(x, y)
1520            expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y)
1521            self.assertEqual(self.safeToDense(res), expected)
1522
1523        test_shape(7, 5, 3, 20)
1524        test_shape(1000, 100, 100, 20)
1525        test_shape(3000, 64, 300, 20)
1526        test_shape(0, 100, 100, 0)
1527        test_shape(1000, 0, 100, 0)
1528        test_shape(1000, 100, 0, 0)
1529
1530        # Test code from issue https://github.com/pytorch/pytorch/issues/45113
1531        batch_size, input_size, hidden_size = 5, 3, 7
1532
1533        # Create coalesced sparse tensor with non-contiguous indices
1534        weight = torch.randn(hidden_size, input_size, dtype=dtype, device=device).to_sparse()
1535        self.assertTrue(weight.is_coalesced())
1536        non_contig_indices = weight.indices().mT.contiguous().mT
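        # .mT.contiguous().mT keeps the index values intact but leaves the
        # result with transposed strides, i.e. a non-contiguous indices tensor.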
1537        weight = torch.sparse_coo_tensor(
1538            indices=non_contig_indices, values=weight.values(), size=weight.shape)
1539        weight._coalesced_(True)
1540        self.assertFalse(weight._indices().is_contiguous())
1541        # Create un/coalesced sparse tensor
1542        bias = torch.randn((hidden_size, 1), dtype=dtype, device=device).to_sparse()
1543        bias = torch.cat([bias] * batch_size, dim=1)
1544
1545        if coalesced:
1546            bias = bias.coalesce()
1547
1548        x = torch.randn(input_size, batch_size, dtype=dtype, device=device)
1549        res = bias.sspaddmm(weight, x)
1550
1551        true_result = (bias.to_dense() + torch.matmul(weight.to_dense(), x)).to_sparse()
1552        self.assertEqual(self.safeToDense(res), self.safeToDense(true_result))
1553
1554    @coalescedonoff
1555    @precisionOverride({torch.bfloat16: 5e-2, torch.float16: 5e-2})
1556    @dtypes(torch.double, torch.cdouble, torch.bfloat16, torch.float16)
1557    def test_sparse_addmm(self, device, dtype, coalesced):
1558        if (dtype is torch.bfloat16 or dtype is torch.float16) and device.startswith("cuda"):
1559            self.skipTest('addmm_sparse_cuda is not implemented for BFloat16 and Half')
1560
1561
1562        def test_shape(m, n, p, nnz, broadcast, alpha_beta=None):
1563            if alpha_beta is None:
1564                alpha = random.random()
1565                beta = random.random()
1566            else:
1567                alpha, beta = alpha_beta
1568            if broadcast:
1569                D1 = make_tensor((), dtype=dtype, device=device, requires_grad=True)
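                # A 0-dim D1 exercises broadcasting of the additive term against
                # the (n, p) result of S @ D2.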
1570            else:
1571                D1 = make_tensor([n, p], dtype=dtype, device=device, requires_grad=True)
1572            D2 = make_tensor([m, p], dtype=dtype, device=device, requires_grad=True)
1573            S = self._gen_sparse(2, nnz, [n, m], dtype, device, coalesced)[0]
1574            S_dense = S.to_dense().requires_grad_(True)
1575            S.requires_grad_(True)
1576            Y = torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)
1577            Y_dense = torch.addmm(D1, S_dense, D2, beta=beta, alpha=alpha)
1578            self.assertEqual(Y, Y_dense)
1579
1580            if dtype not in {torch.double, torch.cdouble}:
1581                # gradcheck will likely fail with low-precision input dtypes.
1582                return
1583
1584            def fn(S, D1, D2, beta=beta, alpha=alpha):
1585                return torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)
1586            gradcheck(fn, (S, D1, D2), masked=True)
1587
1588        test_shape(7, 8, 9, 20, False, None)
1589        test_shape(7, 8, 9, 20, True, None)
1590        test_shape(7, 8, 9, 20, False, (1, 0))
1591        test_shape(7, 8, 9, 20, True, (1, 0))
1592        test_shape(7, 8, 9, 20, False, (1, 1))
1593        test_shape(7, 8, 9, 20, True, (1, 1))
1594
1595    @coalescedonoff
1596    @dtypes(torch.double)
1597    @unittest.skipIf(TEST_WITH_CROSSREF, "unsupported generator triggers assertion error")
1598    def test_sparse_mm(self, device, dtype, coalesced):
1599        def test_shape(d1, d2, d3, nnz, transposed):
1600            if transposed:
1601                D = torch.randn(d3, d2, dtype=dtype,
1602                                device=device).t_().requires_grad_(True)
1603            else:
1604                D = torch.randn(d2, d3, dtype=dtype, device=device).requires_grad_(True)
1605            S = self._gen_sparse(2, nnz, [d1, d2], dtype, device, coalesced)[0]
1606            S_dense = S.to_dense().requires_grad_(True)
1607            S.requires_grad_(True)
1608            self.assertEqual(torch.sparse.mm(S, D), torch.mm(S_dense, D))
1609
1610            def fn(S, D):
1611                return torch.sparse.mm(S, D)
1612            gradcheck(fn, (S, D), masked=True)
1613
1614        test_shape(7, 8, 9, 20, False)
1615        test_shape(7, 8, 9, 20, True)
1616
1617    @coalescedonoff
1618    @dtypes(torch.double)
1619    @unittest.skipIf(TEST_WITH_CROSSREF, "unsupported generator triggers assertion error")
1620    @gradcheck_semantics()
1621    def test_sparse_mul(self, device, dtype, coalesced, gradcheck):
1622        # https://github.com/pytorch/pytorch/issues/79914
1623        a = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True)
1624        b = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True)
1625        gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(masked_grad=gradcheck.masked), [a, b])
1626
1627        def test_shape(sparse_dims, nnz, with_shape):
1628            a = self._gen_sparse(sparse_dims, nnz, with_shape, dtype, device, coalesced)[0].requires_grad_(True)
1629            b = self._gen_sparse(sparse_dims, nnz, with_shape, dtype, device, coalesced)[0].requires_grad_(True)
1630
1631            self.assertEqual((a * b).to_dense(), a.to_dense() * b.to_dense(), masked=True)
1632            gradcheck(lambda x, y: (x * y).to_dense(), [a, b])
1633            # Issues with 0-dim indices/values
1634            gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(), [a, b], masked=True)
1635
1636        # TODO: Re-enable these
1637        # test_shape(2, 3, [2, 3, 4, 5])
1638        # test_shape(2, 3, [2, 2, 0])
1639
1640    @coalescedonoff
1641    @dtypes(torch.double)
1642    def test_dsmm(self, device, dtype, coalesced):
1643        def test_shape(di, dj, dk, nnz):
1644            x = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced)[0]
1645            y = self.randn(dj, dk, dtype=dtype, device=device)
1646
1647            res = torch.dsmm(x, y)
1648            expected = torch.mm(self.safeToDense(x), y)
1649            self.assertEqual(res, expected)
1650
1651        test_shape(7, 5, 3, 20)
1652        test_shape(1000, 100, 100, 20)
1653        test_shape(3000, 64, 300, 20)
1654        test_shape(0, 100, 100, 0)
1655        test_shape(1000, 0, 100, 0)
1656        test_shape(1000, 100, 0, 0)
1657        test_shape(1000, 100, 0, 20)
1658
1659    @coalescedonoff
1660    @dtypes(torch.double)
1661    def test_hsmm(self, device, dtype, coalesced):
1662        def test_shape(di, dj, dk, nnz):
1663            x = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced)[0]
1664            y = self.randn(dj, dk, dtype=dtype, device=device)
1665
1666            res = torch.hsmm(x, y)
1667            expected = torch.mm(self.safeToDense(x), y)
1668            self.assertEqual(res.to_dense(), expected)
1669
1670        test_shape(7, 5, 3, 20)
1671        test_shape(1000, 100, 100, 20)
1672        test_shape(3000, 64, 300, 20)
1673        test_shape(0, 100, 100, 0)
1674        test_shape(1000, 0, 100, 0)
1675        test_shape(1000, 100, 0, 0)
1676        test_shape(1000, 100, 0, 20)
1677
1678    @coalescedonoff
1679    @dtypes(torch.double)
1680    def test_spadd(self, device, dtype, coalesced):
1681
1682        def _test_spadd_shape(nnz, shape_i, shape_v=None):
1683            shape = shape_i + (shape_v or [])
1684            x, _, _ = self._gen_sparse(len(shape_i), nnz, shape, dtype, device, coalesced)
1685            y = self.randn(*shape, dtype=dtype, device=device)
1686            r = random.random()
1687
1688            res = torch.add(y, x, alpha=r)
1689            expected = y + r * self.safeToDense(x)
1690
1691            self.assertEqual(res, expected)
1692
1693            # Non contiguous dense tensor
1694            s = list(shape)
1695            s[0] = shape[-1]
1696            s[-1] = shape[0]
1697            y = self.randn(*s, dtype=dtype, device=device)
1698            y.transpose_(0, len(s) - 1)
1699            r = random.random()
1700
1701            res = torch.add(y, x, alpha=r)
1702            expected = y + r * self.safeToDense(x)
1703
1704            self.assertEqual(res, expected)
1705
1706            x, i, v = self._gen_sparse(len(shape_i), nnz, shape, dtype, device, coalesced)
1707            nnz = i.size(1)
1708
1709            # Non contiguous sparse indices tensor
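            # (i[:, ::2] strides over every other column, so the indices tensor
            # handed to sparse_tensor is non-contiguous.)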
1710            x_ = self.sparse_tensor(i[:, ::2], v[:(nnz + 1) // 2], x.shape, dtype=dtype, device=device)
1711            res = torch.add(y, x_, alpha=r)
1712            expected = y + r * self.safeToDense(x_)
1713            self.assertEqual(res, expected)
1714
1715            # Non contiguous sparse values tensor
1716
1717            x_ = self.sparse_tensor(i[:, :(nnz + 1) // 2], v[::2], x.shape, dtype=dtype, device=device)
1718            res = torch.add(y, x_, alpha=r)
1719            expected = y + r * self.safeToDense(x_)
1720            self.assertEqual(res, expected)
1721
1722            # Non contiguous sparse indices and values tensors
1723            x_ = self.sparse_tensor(i[:, 1::2], v[1::2], x.shape, dtype=dtype, device=device)
1724            res = torch.add(y, x_, alpha=r)
1725            expected = y + r * self.safeToDense(x_)
1726            self.assertEqual(res, expected)
1727
1728        def _test_spadd():
1729            _test_spadd_shape(10, [5, 6])
1730            _test_spadd_shape(10, [10, 10, 10])
1731            _test_spadd_shape(10, [50, 30, 20])
1732            _test_spadd_shape(10, [5, 5, 5, 5, 5, 5])
1733            _test_spadd_shape(0, [0, 30, 20])
1734            _test_spadd_shape(0, [50, 0, 20])
1735            _test_spadd_shape(0, [50, 30, 0])
1736
1737        def _test_spadd_hybrid():
1738            _test_spadd_shape(10, [5, 6], [2, 3])
1739            _test_spadd_shape(10, [10, 10, 10], [3])
1740            _test_spadd_shape(10, [50, 30, 20], [2])
1741            _test_spadd_shape(10, [5, 5, 5, 5, 5, 5], [2])
1742            _test_spadd_shape(0, [0, 30, 20], [2, 0])
1743            _test_spadd_shape(0, [50, 0, 20], [2, 0])
1744            _test_spadd_shape(0, [50, 30, 0], [2, 0])
1745            _test_spadd_shape(10, [50, 30, 20], [2, 0])
1746
1747        _test_spadd()
1748        _test_spadd_hybrid()
1749
1750    @coalescedonoff
1751    @dtypes(torch.float)
1752    def test_sparse_add_out_bfloat16(self, device, dtype, coalesced):
1753        # fp32
1754        x, _, _ = self._gen_sparse(3, 5, 10, dtype, device, coalesced)
1755        y, _, _ = self._gen_sparse(3, 5, 10, dtype, device, coalesced)
1756        res_fp32 = torch.add(x, y)
1757
1758        # bfloat16
1759        x = x.bfloat16()
1760        y = y.bfloat16()
1761        res_bf16 = torch.add(x, y)
1762        res_bf16 = res_bf16.float()  # to compare with reference
1763        self.assertEqual(res_fp32, res_bf16, atol=1e-2, rtol=0)
1764
1765    @coalescedonoff
1766    @dtypes(torch.double, torch.cdouble)
1767    def test_norm(self, device, dtype, coalesced):
1768        def test_shape(sparse_dims, nnz, with_size):
1769            x, _, _ = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)
1770            y = x.coalesce()
1771            self.assertEqual(x.norm(), y._values().norm())
1772
1773        test_shape(3, 10, 100)
1774        test_shape(4, 10, [100, 100, 100, 5, 5, 5, 0])
1775        test_shape(4, 0, [0, 0, 100, 5, 5, 5, 0])
1776
1777        # Unsupported arguments should error
1778        kwarg_error_pairs = [
1779            ({'keepdim': True},
1780             RuntimeError, r'norm_sparse currently does not support keepdim=True'),
1781            ({'dim': 0},
1782             RuntimeError, r'norm_sparse currently only supports full reductions'),
1783            ({'dtype': torch.double, 'p': 'fro'},
1784             ValueError, r'dtype argument is not supported in frobenius norm'),
1785            ({'dtype': torch.double, 'p': 0},
1786             RuntimeError, r"norm_sparse currently does not support 'dtype' argument")
1787        ]
1788        x = self._gen_sparse(3, 10, 100, dtype, device, coalesced)[0]
1789        for kwargs, err, msg in kwarg_error_pairs:
1790            with self.assertRaisesRegex(err, msg):
1791                x.norm(**kwargs)
1792
1793    @coalescedonoff
1794    @dtypes(torch.double)
1795    @unittest.skipIf(TEST_WITH_CROSSREF, "fallback triggers cuda device error")
1796    def test_sparse_sum(self, device, dtype, coalesced):
1797
1798        def run_tests(S, td=None):
1799            D = S.coalesce().to_dense().detach().requires_grad_(True)
1800            if td is None:
1801                S_sum = torch.sparse.sum(S)
1802                D_sum = D.sum()
1803                self.assertEqual(S_sum.item(), D_sum.item())
1804
1805                def fn(S):
1806                    return torch.sparse.sum(S)
1807                gradcheck(fn, (S,), masked=True)
1808            else:
1809                S_sum = torch.sparse.sum(S, td)
1810                D_sum = D.sum(td)
1811                self.assertEqual(S_sum.to_dense() if S_sum.is_sparse else S_sum, D_sum)
1812
1813                def fn(S):
1814                    res = torch.sparse.sum(S, td)
1815                    return res.to_dense(masked_grad=True)
1816                gradcheck(fn, (S,), masked=True)
1817
1818        nnz = 10
1819        sparse_dims = 2
1820        with_size = [5, 5, 1, 4]  # use a dense dim = 1 to test for squeeze
1821        test_dims = []
1822        for i in range(1, 5):
1823            test_dims += itertools.combinations(range(len(with_size)), i)
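        # test_dims now holds every non-empty subset of {0, 1, 2, 3}, covering
        # reductions over sparse dims, dense dims, and mixtures of both.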
1824
1825        # https://github.com/pytorch/pytorch/issues/16501
1826        x = torch.tensor([[1., 0., 0., 1.],
1827                          [0., 1., 0., 0.],
1828                          [0., 1., 1., 0.],
1829                          [0., 1., 0., 2.]], dtype=dtype, device=device).to_sparse()
1830        self.assertEqual(torch.sparse.sum(x, dim=0), torch.sparse.sum(x, dim=-2))
1831        self.assertEqual(torch.sum(x.to_dense(), dim=0), torch.sparse.sum(x, dim=0).to_dense())
1832
1833        S = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
1834
1835        # dim out of range
1836        self.assertRaises(IndexError, lambda: torch.sparse.sum(S, 5))
1837
1838        # dim 0 appears multiple times in the list of dims
1839        self.assertRaises(RuntimeError, lambda: torch.sparse.sum(S, [0, 0]))
1840
1841        # sum an empty tensor
1842        empty_S = torch.sparse_coo_tensor(size=with_size, dtype=dtype, device=device)
1843        self.assertEqual(torch.sparse.sum(empty_S, [0]).to_dense(), torch.sum(empty_S.to_dense(), [0]))
1844        self.assertEqual(torch.sparse.sum(empty_S), torch.tensor(0, dtype=dtype, device=device))
1845        empty_S.requires_grad_(True)
1846        empty_S_sum = torch.sparse.sum(empty_S)
1847        empty_S_sum.backward()
1848        self.assertEqual(empty_S.grad.to_dense(), empty_S.clone().detach().to_dense())
1849
1850        # test values().sum()
1851        S = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
1852        run_tests(S.requires_grad_(True))
1853
1854        for test_dim in test_dims:
1855            S = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
1856            run_tests(S.requires_grad_(True), test_dim)
1857
1858    def _test_basic_ops_shape(self, nnz_x1, nnz_x2, shape_i, shape_v, dtype, device, coalesced):
1859        shape = shape_i + (shape_v)
1860        x1, _, _ = self._gen_sparse(len(shape_i), nnz_x1, shape, dtype, device, coalesced)
1861        x2, _, _ = self._gen_sparse(len(shape_i), nnz_x2, shape, dtype, device, coalesced)
1862
1863        y1 = x1 + x2
1864        y2 = x1.clone()
1865        y2.add_(x2)
1866        expected = self.safeToDense(x1) + self.safeToDense(x2)
1867        self.assertEqual(self.safeToDense(y1), expected)
1868        self.assertEqual(self.safeToDense(y2), expected)
1869
1870        y1 = x1 - x2
1871        y2 = x1.clone()
1872        y2.sub_(x2)
1873        expected = self.safeToDense(x1) - self.safeToDense(x2)
1874        self.assertEqual(self.safeToDense(y1), expected)
1875        self.assertEqual(self.safeToDense(y2), expected)
1876
1877        y1 = x1 * x2
1878        y2 = x1.clone()
1879        y2.mul_(x2)
1880        expected = self.safeToDense(x1) * self.safeToDense(x2)
1881        self.assertEqual(self.safeToDense(y1), expected)
1882        self.assertEqual(self.safeToDense(y2), expected)
1883
1884        y1 = x1 * 37.5
1885        y2 = x1.clone()
1886        y2.mul_(37.5)
1887        expected = self.safeToDense(x1) * 37.5
1888        self.assertEqual(self.safeToDense(y1), expected)
1889        self.assertEqual(self.safeToDense(y2), expected)
1890
1891        y1 = x1 / 37.5
1892        y2 = x1.clone()
1893        y2.div_(37.5)
1894        expected = self.safeToDense(x1) / 37.5
1895        self.assertEqual(self.safeToDense(y1), expected)
1896        self.assertEqual(self.safeToDense(y2), expected)
1897
1898        y1 = x1 // 37.5
1899        y2 = x1.clone()
1900        y2.floor_divide_(37.5)
1901        expected = self.safeToDense(x1) // 37.5
1902        self.assertEqual(self.safeToDense(y1), expected)
1903        self.assertEqual(self.safeToDense(y2), expected)
1904
1905        # TODO: add back inplace support
1906        y1 = x1 ** 2
1907        y2 = x1.clone()
1908        y2 = y2.pow(2)
1909        expected = self.safeToDense(x1) ** 2
1910        self.assertEqual(self.safeToDense(y1), expected)
1911        self.assertEqual(self.safeToDense(y2), expected)
1912
1913        y = x1.clone()
1914        y.zero_()
1915        expected = torch.zeros(x1.size(), dtype=dtype, device=device)
1916        self.assertEqual(self.safeToDense(y), expected)
1917
1918        self.assertEqual(x1.is_coalesced(), coalesced)
1919        y = x1.coalesce()
1920        z = x1.coalesce()
1921        self.assertEqual(x1.is_coalesced(), coalesced)
1922        self.assertTrue(y.is_coalesced())
1923        y._values().add_(1)
1924        if not x1.is_coalesced():
1925            # check that coalesce is out of place if the original tensor is not
1926            # coalesced.
1927            self.assertEqual(z._values() + 1, y._values())
1928        else:
1929            # check that coalesce is in-place if the original tensor is
1930            # coalesced.
1931            self.assertEqual(z._values(), y._values())
1932
1933    @coalescedonoff
1934    @dtypes(torch.double)
1935    def test_basic_ops(self, device, dtype, coalesced):
1936
1937        def _test_basic_ops():
1938            self._test_basic_ops_shape(9, 12, [5, 6], [], dtype, device, coalesced)
1939            self._test_basic_ops_shape(9, 12, [10, 10, 10], [], dtype, device, coalesced)
1940            self._test_basic_ops_shape(9, 12, [50, 30, 20], [], dtype, device, coalesced)
1941            self._test_basic_ops_shape(9, 12, [5, 5, 5, 5, 5, 5], [], dtype, device, coalesced)
1942            self._test_basic_ops_shape(0, 12, [10, 10, 10], [], dtype, device, coalesced)
1943            self._test_basic_ops_shape(9, 0, [10, 10, 10], [], dtype, device, coalesced)
1944            self._test_basic_ops_shape(0, 0, [10, 10, 10], [], dtype, device, coalesced)
1945            self._test_basic_ops_shape(0, 0, [10, 10, 0], [], dtype, device, coalesced)
1946            self._test_basic_ops_shape(0, 0, [], [], dtype, device, coalesced)
1947
1948        def _test_basic_ops_hybrid():
1949            self._test_basic_ops_shape(9, 12, [5, 6], [2, 3], dtype, device, coalesced)
1950            self._test_basic_ops_shape(9, 12, [10, 10, 10], [3], dtype, device, coalesced)
1951            self._test_basic_ops_shape(9, 12, [50, 30, 20], [2], dtype, device, coalesced)
1952            self._test_basic_ops_shape(9, 12, [5, 5, 5, 5, 5, 5], [2], dtype, device, coalesced)
1953            self._test_basic_ops_shape(0, 12, [10, 10, 10], [2], dtype, device, coalesced)
1954            self._test_basic_ops_shape(9, 0, [10, 10, 10], [2], dtype, device, coalesced)
1955            self._test_basic_ops_shape(0, 0, [10, 10, 10], [2], dtype, device, coalesced)
1956            self._test_basic_ops_shape(9, 12, [10, 10, 10], [2, 0], dtype, device, coalesced)
1957            self._test_basic_ops_shape(0, 12, [10, 10, 10], [2, 0], dtype, device, coalesced)
1958            self._test_basic_ops_shape(9, 0, [10, 10, 10], [2, 0], dtype, device, coalesced)
1959            self._test_basic_ops_shape(0, 0, [10, 10, 10], [2, 0], dtype, device, coalesced)
1960            self._test_basic_ops_shape(0, 0, [10, 10, 0], [2, 0], dtype, device, coalesced)
1961
1962        _test_basic_ops()
1963        _test_basic_ops_hybrid()
1964
1965    @dtypes(torch.double, torch.cdouble)
1966    def test_add_dense_sparse_mismatch(self, device, dtype):
1967        def test_shape(dense_size, sparse_dims_shape, dense_dims_shape, sparse_size):
1968            x = torch.zeros(dense_size, dtype=dtype, device=device)
1969            sparse_y = self.sparse_tensor(torch.zeros(sparse_dims_shape, dtype=torch.int64, device=device),
1970                                          torch.randn(dense_dims_shape, dtype=dtype, device=device),
1971                                          torch.Size(sparse_size))
1972            with self.assertRaisesRegex(
1973                    RuntimeError,
1974                    "add: expected 'self' and 'other' to have same size"):
1975                x + sparse_y
1976
1977        test_shape([3, 4], [1, 4], [4, 4, 4], [3, 4, 4])
1978        test_shape([3, 4, 0], [1, 4], [4, 4, 4, 0], [3, 4, 4, 0])
1979
1980    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
1981    @dtypes(torch.double, torch.cdouble)
1982    def test_add_noncontiguous(self, device, dtype):
1983        indices = self.index_tensor([[1, 2], [0, 2]], device=device)
1984        values = torch.tensor([1.], dtype=dtype, device=device).expand(2, 3, 4, 5)
1985        x = self.sparse_tensor(indices, values, dtype=dtype, device=device)
1986        assert not x._values().is_contiguous()
1987        y = x + x
1988        expected = self.safeToDense(x) + self.safeToDense(x)
1989        self.assertEqual(self.safeToDense(y), expected)
1990
1991    def _test_sparse_mask_shape(self, nnz_x1, nnz_x2, shape_i, shape_v, dtype, device, coalesced):
1992        shape = shape_i + (shape_v or [])
1993        x1, _, _ = self._gen_sparse(len(shape_i), nnz_x1, shape, dtype, device, coalesced)
1994        x2, _, _ = self._gen_sparse(len(shape_i), nnz_x2, shape, dtype, device, coalesced)
1995
1996        y1 = x1 + x2
1997        y2 = x1.clone()
1998        y2.add_(x2)
1999        expected = self.safeToDense(x1) + self.safeToDense(x2)
2000        self.assertEqual(self.safeToDense(y1), expected)
2001        self.assertEqual(self.safeToDense(y2), expected)
2002
2003    @coalescedonoff
2004    @dtypes(torch.double, torch.cdouble)
2005    def test_sparse_mask(self, device, dtype, coalesced):
2006        def _test_sparse_mask_fixed():
2007            i = self.index_tensor([
2008                [1, 3, 0, 4],
2009                [2, 1, 2, 3],
2010            ], device=device)
2011            v = torch.tensor([1, 2, 3, 4], dtype=dtype, device=device)
2012            x = self.sparse_tensor(i, v, torch.Size([5, 4]), dtype=dtype, device=device).coalesce()
2013            dense = torch.tensor([
2014                [1, 2, 3, 4],
2015                [5, 6, 7, 8],
2016                [9, 10, 11, 12],
2017                [13, 14, 15, 16],
2018                [17, 18, 19, 20],
2019            ], dtype=dtype, device=device)
2020            exp_v = torch.tensor([7, 14, 3, 20], dtype=dtype, device=device)
2021            res_dense_lhs = dense.sparse_mask(x)
2022            sparse = dense.to_sparse()
2023            res_sparse_lhs = sparse.sparse_mask(x)
2024            expected = self.sparse_tensor(i, exp_v, torch.Size([5, 4]), dtype=dtype, device=device)
2025            self.assertEqual(res_dense_lhs.coalesce(), expected.coalesce())
2026            # check no side effects for the coalesce flag.
2027            self.assertTrue(sparse.is_coalesced())
2028            self.assertEqual(res_sparse_lhs.coalesce(), expected.coalesce())
2029
2030            i = self.index_tensor([
2031                [1, 3, 0, 4],
2032                [2, 1, 2, 3],
2033            ], device=device)
2034            v = torch.empty([4, 0], dtype=dtype, device=device)
2035            x = self.sparse_tensor(i, v, torch.Size([5, 4, 0])).coalesce()
2036            dense = torch.empty([5, 4, 0], dtype=dtype, device=device)
2037            exp_v = torch.empty([4, 0], dtype=dtype, device=device)
2038            res_dense_lhs = dense.sparse_mask(x)
2039            sparse = dense.to_sparse(2)
2040            res_sparse_lhs = sparse.sparse_mask(x)
2041            expected = self.sparse_tensor(i, exp_v, torch.Size([5, 4, 0]), dtype=dtype, device=device)
2042            self.assertEqual(res_dense_lhs.coalesce(), expected.coalesce())
2043            # check no side effects for the coalesce flag.
2044            self.assertTrue(sparse.is_coalesced())
2045            self.assertEqual(res_sparse_lhs.coalesce(), expected.coalesce())
2046
2047        _test_sparse_mask_fixed()
2048
2049        self._test_sparse_mask_shape(9, 12, [5, 6], [], dtype, device, coalesced)
2050        self._test_sparse_mask_shape(9, 12, [10, 10, 10], [], dtype, device, coalesced)
2051        self._test_sparse_mask_shape(9, 12, [50, 30, 20], [], dtype, device, coalesced)
2052        self._test_sparse_mask_shape(9, 12, [5, 5, 5, 5, 5, 5], [], dtype, device, coalesced)
2053        self._test_sparse_mask_shape(0, 12, [10, 10, 10], [], dtype, device, coalesced)
2054        self._test_sparse_mask_shape(9, 0, [10, 10, 10], [], dtype, device, coalesced)
2055        self._test_sparse_mask_shape(0, 0, [10, 10, 10], [], dtype, device, coalesced)
2056        self._test_sparse_mask_shape(0, 0, [10, 10, 0], [], dtype, device, coalesced)
2057
2058        # check repetitions and matchings in the intersection
2059        lhs = torch.randint(0, 5, (100,), device=device)
2060        rhs = torch.randint(0, 5, (100,), device=device).to_sparse()
2061        self.assertEqual(lhs.to_sparse().sparse_mask(rhs), lhs.sparse_mask(rhs))
2062
2063        # check coalesce
2064        sparse_c = torch.rand(3, 3, device=device).to_sparse()
2065        sparse_unc = torch.rand(3, 3, device=device).to_sparse()._coalesced_(False)
2066        for lhs, rhs in [(sparse_c, sparse_unc), (sparse_unc, sparse_c)]:
2067            res_all_sparse = lhs.sparse_mask(rhs)
2068            res_dense_sparse = lhs.to_dense().sparse_mask(rhs)
2069            self.assertEqual(res_all_sparse.coalesce(), res_dense_sparse.coalesce())
2070            self.assertEqual(rhs.is_coalesced(), res_all_sparse.is_coalesced())
2071
2072    @coalescedonoff
2073    @dtypes(torch.double, torch.cdouble)
2074    def test_sparse_mask_hybrid(self, device, dtype, coalesced):
2075        def _test_sparse_mask_hybrid_fixed():
2076            i = self.index_tensor([
2077                [1, 3, 0, 4],
2078                [2, 1, 2, 3],
2079            ])
2080            v = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5]])
2081            # TODO: This is also testing that, if coalesce is a no-op,
2082            # the indices don't get permuted. I don't know if we actually
2083            # want to give this invariant.
2084            x = self.sparse_tensor(i, v, torch.Size([5, 4, 2])).coalesce()
2085            dense = torch.tensor([
2086                [[1, 3], [2, 2], [3, 3], [4, 2]],
2087                [[5, 7], [6, 7], [7, 9], [8, 9]],
2088                [[9, 2], [10, 4], [11, 1], [12, 3]],
2089                [[13, 5], [14, 1], [15, 1], [16, 6]],
2090                [[17, 7], [18, 2], [19, 7], [20, 1]],
2091            ])
2092            res_dense_lhs = dense.sparse_mask(x)
2093            sparse = dense.to_sparse(2)
2094            res_sparse_lhs = sparse.sparse_mask(x)
2095            exp_v = torch.tensor([[7, 9], [14, 1], [3, 3], [20, 1]])
2096            expected = self.sparse_tensor(i, exp_v, torch.Size([5, 4, 2]))
2097            self.assertEqual(res_dense_lhs.coalesce(), expected.coalesce())
2098            # check no side effects for the coalesce flag
2099            self.assertTrue(sparse.is_coalesced())
2100            self.assertEqual(res_sparse_lhs.coalesce(), expected.coalesce())
2101
2102            i = self.index_tensor([
2103                [1, 3, 0, 4],
2104                [2, 1, 2, 3],
2105            ])
2106            v = torch.empty(4, 2, 0)
2107            x = self.sparse_tensor(i, v, torch.Size([5, 4, 2, 0])).coalesce()
2108            dense = torch.empty(5, 4, 2, 0)
2109            res_dense_lhs = dense.sparse_mask(x)
2110            sparse = dense.to_sparse(2)
2111            res_sparse_lhs = sparse.sparse_mask(x)
2112            exp_v = torch.empty(4, 2, 0)
2113            expected = self.sparse_tensor(i, exp_v, torch.Size([5, 4, 2, 0]))
2114            self.assertEqual(res_dense_lhs.coalesce(), expected.coalesce())
2115            # check no side effects for the coalesce flag
2116            self.assertTrue(sparse.is_coalesced())
2117            self.assertEqual(res_sparse_lhs.coalesce(), expected.coalesce())
2118
2119        _test_sparse_mask_hybrid_fixed()
2120
2121        self._test_sparse_mask_shape(9, 12, [5, 6], [2, 3], dtype, device, coalesced)
2122        self._test_sparse_mask_shape(9, 12, [10, 10, 10], [3], dtype, device, coalesced)
2123        self._test_sparse_mask_shape(9, 12, [50, 30, 20], [2], dtype, device, coalesced)
2124        self._test_sparse_mask_shape(9, 12, [5, 5, 5, 5, 5, 5], [2], dtype, device, coalesced)
2125        self._test_sparse_mask_shape(0, 12, [10, 10, 10], [2], dtype, device, coalesced)
2126        self._test_sparse_mask_shape(9, 0, [10, 10, 10], [2], dtype, device, coalesced)
2127        self._test_sparse_mask_shape(0, 0, [10, 10, 10], [2], dtype, device, coalesced)
2128        self._test_sparse_mask_shape(9, 12, [10, 10, 10], [2, 0], dtype, device, coalesced)
2129        self._test_sparse_mask_shape(0, 12, [10, 10, 10], [2, 0], dtype, device, coalesced)
2130        self._test_sparse_mask_shape(9, 0, [10, 10, 10], [2, 0], dtype, device, coalesced)
2131        self._test_sparse_mask_shape(0, 0, [10, 10, 10], [2, 0], dtype, device, coalesced)
2132        self._test_sparse_mask_shape(0, 0, [10, 10, 0], [2, 0], dtype, device, coalesced)
2133
2134    @dtypes(torch.double, torch.cdouble)
2135    @skipIfCrossRef
2136    def test_sparse_mask_backward(self, device, dtype):
2137        from itertools import product, repeat
2138
2139        shape = (5, 5)
2140        sparse_dims = len(shape)
2141        nnzs = (0, 5, 15, 25)
2142
2143        lhs_data = torch.arange(1, 26, device=device).reshape(shape).to(dtype).to_sparse(sparse_dims)
2144        rhs_data = lhs_data.clone()
2145
2146        for nnz in nnzs:
2147            for lhs_is_coalesced, rhs_is_coalesced in product(*repeat((True, False), 2)):
2148                lhs = torch.sparse_coo_tensor(
2149                    lhs_data._indices()[:, :nnz],
2150                    lhs_data._values()[:nnz],
2151                    lhs_data.shape
2152                ).clone()._coalesced_(lhs_is_coalesced).requires_grad_(True)
2153
2154                rhs = torch.sparse_coo_tensor(
2155                    lhs_data._indices()[:, -nnz:],
2156                    lhs_data._values()[-nnz:],
2157                    lhs_data.shape
2158                ).clone()._coalesced_(rhs_is_coalesced)
2159
2160                # To test masked semantics we need to make sure that
2161                # sparsity_pattern(lhs) == sparsity_pattern(lhs.grad).
2162                # lhs.sparse_mask(lhs_mask) accomplishes that.
2163                lhs_mask = lhs.detach().clone()
2164                gradcheck(lambda x: x.sparse_mask(lhs_mask).sparse_mask(rhs).to_dense(masked_grad=True), (lhs,), masked=True)
2165                gradcheck(lambda x: x.sparse_mask(rhs).to_dense(masked_grad=False), (lhs,), masked=False)
2166
2167    @coalescedonoff
2168    @dtypes(torch.double, torch.cdouble)
2169    def test_zeros(self, device, dtype, coalesced):
2170        def _test_zeros(nnzs, shape, out_shape_i, out_shape_v=None):
2171            out_shape = out_shape_i + (out_shape_v or [])
2172            for nnz in nnzs:
2173                out, _, _ = self._gen_sparse(len(out_shape_i), nnz, out_shape, dtype, device, coalesced)
2174                torch.zeros(*shape, out=out, dtype=dtype, device=device)
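                # Writing zeros into a sparse `out` resizes it to `shape` and
                # drops all specified elements, as the checks below verify.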
2175                self.assertEqual(tuple(out.size()), tuple(shape))
2176                self.assertTrue(out._indices().numel() == out._values().numel() == 0)
2177                self.assertEqual(out._nnz(), 0)
2178                self.assertEqual(out.sparse_dim(), len(shape))
2179                self.assertEqual(out.dense_dim(), 0)
2180
2181        def test_shape(i_shapes, v_shapes, shape, nnzs):
2182            for i_dim in range(1, len(i_shapes) + 1):
2183                for v_dim in range(len(v_shapes) + 1):
2184                    _test_zeros(nnzs, shape, i_shapes[:i_dim], v_shapes[:v_dim])
2185        test_shape([2, 3, 4], [3, 4, 5, 6], [2, 3, 4], [9, 12])
2186        test_shape([0, 3, 4], [3, 4, 5, 6], [2, 3, 4], [0])
2187        test_shape([2, 3, 4], [0, 4, 5, 6], [2, 3, 4], [9, 12])
2188        test_shape([2, 3, 4], [3, 4, 5, 6], [2, 3, 0], [9, 12])
2189        test_shape([0, 3, 4], [3, 4, 5, 6], [2, 3, 0], [0])
2190        test_shape([2, 3, 4], [0, 4, 5, 6], [2, 3, 0], [9, 12])
2191
2192    @coalescedonoff
2193    @dtypes(torch.double, torch.cdouble)
2194    def test_zeros_like(self, device, dtype, coalesced):
2195        def _test_zeros_like(nnzs, template_shape_i, template_shape_v=None):
2196            template_shape_v = template_shape_v or []
2197            template_shape = template_shape_i + template_shape_v
2198            for nnz in nnzs:
2199                t, _, _ = self._gen_sparse(len(template_shape_i), nnz, template_shape, dtype, device, coalesced)
2200                res = torch.zeros_like(t)
2201                self.assertEqual(tuple(res.size()), tuple(template_shape))
2202                self.assertTrue(res._indices().numel() == res._values().numel() == 0)
2203                self.assertEqual(res._nnz(), 0)
2204                self.assertEqual(res.sparse_dim(), len(template_shape_i))
2205                self.assertEqual(res.dense_dim(), len(template_shape_v))
2206
2207        def test_shape(i_shapes, v_shapes, nnzs):
2208            for i_dim in range(1, len(i_shapes) + 1):
2209                for v_dim in range(len(v_shapes) + 1):
2210                    _test_zeros_like(nnzs, i_shapes[:i_dim], v_shapes[:v_dim])
2211        test_shape([2, 3, 4], [3, 4, 5, 6], [9, 12])
2212        test_shape([0, 3, 4], [3, 4, 5, 6], [0])
2213        test_shape([2, 3, 4], [0, 4, 5, 6], [9, 12])
2214        test_shape([2, 3, 4], [3, 4, 5, 6], [9, 12])
2215        test_shape([0, 3, 4], [3, 4, 5, 6], [0])
2216        test_shape([2, 3, 4], [0, 4, 5, 6], [9, 12])
2217
2218        sparse_tensor, _, _ = self._gen_sparse(len([2, 3]), 9, [2, 3] + [5, 6], dtype, device, coalesced)
2219        data = (sparse_tensor, sparse_tensor, sparse_tensor, sparse_tensor.unsqueeze(0))
2220        mem_formats = [torch.channels_last, torch.contiguous_format, torch.preserve_format, torch.channels_last_3d]
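        # The memory format option is only supported by strided tensors, so
        # zeros_like on a sparse input must either raise or be asked for a
        # strided result via layout=torch.strided.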
2221        for x, mem_format in zip(data, mem_formats):
2222
2223            with self.assertRaisesRegex(RuntimeError, "memory format option is only supported by strided tensors"):
2224                result = torch.zeros_like(x, memory_format=mem_format)
2225
2226            result = torch.zeros_like(x, layout=torch.strided, memory_format=mem_format)
2227            self.assertTrue(result.layout == torch.strided)
2228
2229        dense_tensor = sparse_tensor.to_dense()
2230        result = torch.zeros_like(dense_tensor, layout=torch.sparse_coo)
2231        self.assertEqual(dense_tensor.shape, result.shape)
2232        self.assertEqual(result.layout, torch.sparse_coo)
2233
2234        sparse_zeros = torch.sparse_coo_tensor(dense_tensor.shape)
2235        self.assertEqual(result._indices().shape, sparse_zeros._indices().shape)
2236        self.assertEqual(result._values().shape, sparse_zeros._values().shape)
2237
2238    def _assert_sparse_invars(self, t):
2239        # SparseTensor has the following invariants:
2240        # - sparse_dim + dense_dim = len(SparseTensor.shape)
2241        # - SparseTensor._indices().shape = (sparse_dim, nnz)
2242        # - SparseTensor._values().shape = (nnz, SparseTensor.shape[sparse_dim:])
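        # For example (illustrative only):
        #   t = torch.sparse_coo_tensor(torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]]),
        #                               torch.randn(4, 3), (2, 2, 3))
        # has sparse_dim() == 2, dense_dim() == 1, _nnz() == 4, and therefore
        # _indices().shape == (2, 4) and _values().shape == (4, 3).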
2243        self.assertEqual(t.sparse_dim() + t.dense_dim(), len(t.shape))
2244        self.assertEqual(tuple(t._indices().shape), (t.sparse_dim(), t._nnz()))
2245        self.assertEqual(tuple(t._values().shape), (t._nnz(), ) + t.shape[t.sparse_dim():])
2246
2247    def _test_empty_like(self, sparse_tensor, dtype, device, coalesced):
2248
2249        result = torch.empty_like(sparse_tensor)
2250        self.assertTrue(result.is_sparse)
2251        self._assert_sparse_invars(result)
2252        self.assertEqual(result.shape, sparse_tensor.shape)
2253        self.assertEqual(result.dtype, sparse_tensor.dtype)
2254        self.assertEqual(result.device, sparse_tensor.device)
2255        self.assertEqual(result.sparse_dim(), sparse_tensor.sparse_dim())
2256        self.assertEqual(result.dense_dim(), sparse_tensor.dense_dim())
2257
2258        sparse_tensor, _, _ = self._gen_sparse(len([2, 3]), 9, [2, 3] + [5, 6], dtype, device, coalesced)
2259        data = (sparse_tensor, sparse_tensor, sparse_tensor, sparse_tensor.unsqueeze(0))
2260        mem_formats = [torch.channels_last, torch.contiguous_format, torch.preserve_format, torch.channels_last_3d]
2261        for x, mem_format in zip(data, mem_formats):
2262
2263            with self.assertRaisesRegex(RuntimeError, "memory format option is only supported by strided tensors"):
2264                result = torch.empty_like(x, memory_format=mem_format)
2265
2266            result = torch.empty_like(x, layout=torch.strided, memory_format=mem_format)
2267            self.assertTrue(result.layout == torch.strided)
2268
2269        with self.assertRaisesRegex(
2270            RuntimeError, r"Could not run 'aten::empty_strided' with arguments from the 'Sparse(CPU|CUDA)' backend"
2271        ):
2272            dense_tensor = sparse_tensor.to_dense()
2273            result = torch.empty_like(dense_tensor, layout=torch.sparse_coo)
2274
2275    @coalescedonoff
2276    @dtypes(torch.double, torch.cdouble)
2277    def test_empty_like(self, device, dtype, coalesced):
2278        # tests https://github.com/pytorch/pytorch/issues/43699
2279
2280        if coalesced:
2281            input_coalesced = torch.sparse_coo_tensor(
2282                indices=torch.tensor([[0, 1, 2]]),
2283                values=torch.tensor([3.0, -4.0, 5.0]),
2284                size=[3, ],
2285                dtype=dtype,
2286                device=device
2287            ).coalesce()
2288            self._test_empty_like(input_coalesced, dtype, device, coalesced)
2289
2290            # hybrid sparse input
2291            input_coalesced = torch.sparse_coo_tensor(
2292                indices=torch.tensor([[1, 3], [2, 4]]),
2293                values=torch.tensor([[-1.0, 3.0], [-5.0, 7.0]]),
2294                size=[4, 5, 2],
2295                dtype=dtype,
2296                device=device
2297            ).coalesce()
2298            self._test_empty_like(input_coalesced, dtype, device, coalesced)
2299
2300        if not coalesced:
2301            # test uncoalesced input
2302            input_uncoalesced = torch.sparse_coo_tensor(
2303                indices=torch.tensor([[0], [1], [2], [0], [1], [2]]).transpose(1, 0),
2304                values=torch.tensor([2.0, -3.0, -4.0, 1.0, -1.0, 1.5]),
2305                size=[3, ],
2306                dtype=dtype,
2307                device=device
2308            )
2309            self._test_empty_like(input_uncoalesced, dtype, device, coalesced)
2310
2311            # test on empty sparse tensor
2312            input_uncoalesced = torch.sparse_coo_tensor(
2313                indices=torch.zeros([2, 0]),
2314                values=torch.zeros([0, 5, 5, 5, 5, 5, 5, 0]),
2315                size=[0, 0, 5, 5, 5, 5, 5, 5, 0],
2316                dtype=dtype,
2317                device=device
2318            )
2319            self._test_empty_like(input_uncoalesced, dtype, device, coalesced)
2320
2321    def _test_narrow(self, input, narrow_args):
2322        expected = input.to_dense().narrow(*narrow_args)
2323        self.assertEqual(expected, input.narrow_copy(*narrow_args).to_dense())
2324
2325    def _all_narrow_combs(self, shape):
2326        for dim, dim_sz in enumerate(shape):
2327            for start in range(dim_sz):
2328                for length in range(dim_sz - start):
2329                    yield [dim, start, length]
2330
2331    @coalescedonoff
2332    @dtypes(torch.double, torch.cdouble)
2333    def test_narrow(self, device, dtype, coalesced):
2334        shape = [3, 3, 4, 2]
2335        input, _, _ = self._gen_sparse(4, 19, shape, dtype, device, coalesced)
2336        for narrow_args in self._all_narrow_combs(shape):
2337            self._test_narrow(input, narrow_args)
2338
2339        self.assertRaises(RuntimeError, lambda: input.narrow_copy(-1, 0, 3))  # dim < 0
2340        self.assertRaises(RuntimeError, lambda: input.narrow_copy(10, 0, 3))  # dim > input.dim()
2341        self.assertRaises(RuntimeError, lambda: input.narrow_copy(0, shape[0] + 1, 3))  # start > size of dim
2342        self.assertRaises(RuntimeError, lambda: input.narrow_copy(0, 2, shape[0]))  # start+length > size of dim
2343
2344        with_dense, _, _ = self._gen_sparse(2, 7, shape, dtype, device, coalesced)
2345        for narrow_args in self._all_narrow_combs(shape):
2346            self._test_narrow(with_dense, narrow_args)
2347
2348        self.assertRaises(RuntimeError, lambda: with_dense.narrow_copy(10, 0, 3))  # dim > sparseDim + denseDim
2349
2350    def _test_log1p_tensor(self, sparse_tensor, coalesced):
2351        def is_integral(dtype):
2352            return dtype in integral_types()
2353
2354        dense_tensor = sparse_tensor.to_dense()
2355        expected_output = dense_tensor.log1p()
2356        is_integral_dtype = is_integral(sparse_tensor.dtype)
2357        self.assertEqual(expected_output, sparse_tensor.log1p().to_dense())
2358        if is_integral_dtype:
2359            with self.assertRaisesRegex(RuntimeError, "result type .* can't be cast to"):
2360                sparse_tensor.coalesce().log1p_()
2361        else:
2362            self.assertEqual(expected_output, sparse_tensor.coalesce().log1p_().to_dense())
2363
2364        if not coalesced:
2365            # test in-place op on uncoalesced input
2366            with self.assertRaisesRegex(RuntimeError, "log1p_ requires coalesced input"):
2367                sparse_tensor.log1p_()
2368
2369        if is_integral_dtype:
2370            with self.assertRaisesRegex(RuntimeError, "only Tensors of floating point dtype can require gradients"):
2371                sparse_tensor.requires_grad_()
2372
2373    @coalescedonoff
2374    @dtypes(*all_types())
2375    def test_log1p(self, device, dtype, coalesced):
2376        if coalesced:
2377            input_coalesced = torch.sparse_coo_tensor(
2378                indices=torch.tensor([[0], [1], [2]]).transpose(1, 0),
2379                values=torch.tensor([3.0, 4.0, 5.0]),
2380                size=[3, ],
2381                device=device,
2382                dtype=dtype
2383            ).coalesce()
2384            self._test_log1p_tensor(input_coalesced, coalesced)
2385
2386            # hybrid sparse input
2387            input_coalesced = torch.sparse_coo_tensor(
2388                indices=torch.tensor([[1, 3], [2, 4]]),
2389                values=torch.tensor([[1.0, 3.0], [5.0, 7.0]]),
2390                size=[4, 5, 2],
2391                device=device,
2392                dtype=dtype
2393            ).coalesce()
2394            self._test_log1p_tensor(input_coalesced, coalesced)
2395
2396        if not coalesced:
2397            # test uncoalesced input
2398            input_uncoalesced = torch.sparse_coo_tensor(
2399                indices=torch.tensor([[0], [1], [2], [0], [1], [2]]).transpose(1, 0),
2400                values=torch.tensor([2.0, 3.0, 4.0, 1.0, 1.0, 1.0]),
2401                size=[3, ],
2402                device=device,
2403                dtype=dtype
2404            )
2405            self._test_log1p_tensor(input_uncoalesced, coalesced)
2406
2407            # test on empty sparse tensor
2408            input_uncoalesced = torch.sparse_coo_tensor(
2409                indices=torch.zeros([2, 0]),
2410                values=torch.zeros([0, 5, 5, 5, 5, 5, 5, 0]),
2411                size=[0, 0, 5, 5, 5, 5, 5, 5, 0],
2412                device=device,
2413                dtype=dtype
2414            )
2415            # empty tensors are coalesced at creation (nnz < 2); we must force the uncoalesced state
2416            input_uncoalesced._coalesced_(False)
2417            self._test_log1p_tensor(input_uncoalesced, coalesced)
2418
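    def _sketch_log1p_sparse(self):
        # Illustrative sketch added for exposition; not invoked by the suite.
        # log1p on a sparse COO tensor is applied to the specified values only,
        # which matches the dense result because log1p(0) == 0.
        dense = torch.tensor([0., 3., 0., 7.])
        sparse = dense.to_sparse()
        self.assertEqual(sparse.log1p().to_dense(), dense.log1p())
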
2419    def _test_neg_negative(self, sparse_tensor):
2420        dense_tensor = sparse_tensor.to_dense()
2421        expected_output = dense_tensor.neg()
2422
2423        ops = (
2424            torch.neg, torch.Tensor.neg, torch.Tensor.neg_,
2425            torch.negative, torch.Tensor.negative, torch.Tensor.negative_,
2426            operator.neg
2427        )
2428        for op in ops:
2429            sparse_tensor_copy = sparse_tensor.clone()
2430            self.assertEqual(expected_output, op(sparse_tensor_copy).to_dense())
2431
2432            if op in (torch.neg, torch.negative):
2433                sparse_tensor_out = torch.zeros_like(sparse_tensor)
2434                op(sparse_tensor, out=sparse_tensor_out)
2435                self.assertEqual(expected_output, sparse_tensor_out.to_dense())
2436
2437    @coalescedonoff
2438    @dtypes(torch.double, torch.cdouble)
2439    def test_neg_negative(self, device, dtype, coalesced):
2440
2441        if coalesced:
2442            input_coalesced = torch.sparse_coo_tensor(
2443                indices=torch.tensor([[0, 1, 2]]),
2444                values=torch.tensor([3.0, -4.0, 5.0]),
2445                size=[3, ],
2446                dtype=dtype,
2447                device=device
2448            ).coalesce()
2449            self._test_neg_negative(input_coalesced)
2450
2451            # hybrid sparse input
2452            input_coalesced = torch.sparse_coo_tensor(
2453                indices=torch.tensor([[1, 3], [2, 4]]),
2454                values=torch.tensor([[-1.0, 3.0], [-5.0, 7.0]]),
2455                size=[4, 5, 2],
2456                dtype=dtype,
2457                device=device
2458            ).coalesce()
2459            self._test_neg_negative(input_coalesced)
2460
2461        if not coalesced:
2462            # test uncoalesced input
2463            input_uncoalesced = torch.sparse_coo_tensor(
2464                indices=torch.tensor([[0], [1], [2], [0], [1], [2]]).transpose(1, 0),
2465                values=torch.tensor([2.0, -3.0, -4.0, 1.0, -1.0, 1.5]),
2466                size=[3, ],
2467                dtype=dtype,
2468                device=device
2469            )
2470            self._test_neg_negative(input_uncoalesced)
2471
2472            # test on empty sparse tensor
2473            input_uncoalesced = torch.sparse_coo_tensor(
2474                indices=torch.zeros([2, 0]),
2475                values=torch.zeros([0, 5, 5, 5, 5, 5, 5, 0]),
2476                size=[0, 0, 5, 5, 5, 5, 5, 5, 0],
2477                dtype=dtype,
2478                device=device
2479            )
2480            self._test_neg_negative(input_uncoalesced)
2481
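    def _sketch_hybrid_coo(self):
        # Illustrative sketch added for exposition; not invoked by the suite.
        # A "hybrid" COO tensor, as used in the tests above, stores a dense
        # block of values per index, so sparse_dim() + dense_dim() == dim().
        indices = torch.tensor([[0, 2]])
        values = torch.tensor([[1., 2.], [3., 4.]])  # one 2-element block per index
        hybrid = torch.sparse_coo_tensor(indices, values, (3, 2))
        self.assertEqual((hybrid.sparse_dim(), hybrid.dense_dim()), (1, 1))
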
2482    def _test_asin_arcsin(self, sparse_tensor, coalesced):
2483        def is_integral(dtype):
2484            return dtype in integral_types()
2485        is_integral_dtype = is_integral(sparse_tensor.dtype)
2486
2487        dense_tensor = sparse_tensor.to_dense()
2488        expected_output = dense_tensor.asin()
2489
2490        ops = (
2491            torch.asin, torch.Tensor.asin,
2492            torch.arcsin, torch.Tensor.arcsin,
2493        )
2494        for op in ops:
2495            self.assertEqual(expected_output, op(sparse_tensor).to_dense())
2496            if op in (torch.asin, torch.arcsin):
2497                sparse_tensor_out = torch.zeros_like(sparse_tensor)
2498                if not is_integral_dtype:
2499                    op(sparse_tensor, out=sparse_tensor_out)
2500                    self.assertEqual(expected_output, sparse_tensor_out.to_dense())
2501                else:
2502                    with self.assertRaisesRegex(RuntimeError, "result type .* can't be cast to"):
2503                        op(sparse_tensor, out=sparse_tensor_out)
2504
2505        for op in (torch.Tensor.asin_, torch.Tensor.arcsin_):
2506            if is_integral_dtype:
2507                # test coalesce on integral dtype tensor
2508                with self.assertRaisesRegex(RuntimeError, "result type .* can't be cast to"):
2509                    op(sparse_tensor.clone().coalesce()).to_dense()
2510            else:
2511                self.assertEqual(expected_output, op(sparse_tensor.clone().coalesce()).to_dense())
2512
2513            if not coalesced:
2514                # test in-place op on uncoalesced input
2515                with self.assertRaisesRegex(RuntimeError, "asin_ requires coalesced input"):
2516                    op(sparse_tensor)
2517
2518    @coalescedonoff
2519    @dtypes(*all_types())
2520    def test_asin_arcsin(self, device, dtype, coalesced):
2521        if coalesced:
2522            input_coalesced = torch.sparse_coo_tensor(
2523                indices=torch.tensor([[0, 1, 2, 3]]),
2524                values=torch.tensor([0.5, -0.5, 0.7, -0.7]),
2525                size=[4, ],
2526                dtype=dtype,
2527                device=device
2528            ).coalesce()
2529            self._test_asin_arcsin(input_coalesced, coalesced)
2530
2531            # hybrid sparse input
2532            input_coalesced = torch.sparse_coo_tensor(
2533                indices=torch.tensor([[1, 3], [2, 4]]),
2534                values=torch.tensor([[-0.1, 0.24], [-0.44, 0.1]]),
2535                size=[4, 5, 2],
2536                dtype=dtype,
2537                device=device
2538            ).coalesce()
2539            self._test_asin_arcsin(input_coalesced, coalesced)
2540
2541        if not coalesced:
2542            # test uncoalesced input
2543            input_uncoalesced = torch.sparse_coo_tensor(
2544                indices=torch.tensor([[0], [1], [2], [0], [1], [2]]).transpose(1, 0),
2545                values=torch.tensor([0.3, -0.3, -0.4, 0.3, -0.5, 0.15]),
2546                size=[3, ],
2547                dtype=dtype,
2548                device=device
2549            )
2550            self._test_asin_arcsin(input_uncoalesced, coalesced)
2551
2552            # test on empty sparse tensor
2553            input_uncoalesced = torch.sparse_coo_tensor(
2554                indices=torch.zeros([2, 0]),
2555                values=torch.zeros([0, 5, 5, 5, 5, 5, 5, 0]),
2556                size=[0, 0, 5, 5, 5, 5, 5, 5, 0],
2557                dtype=dtype,
2558                device=device
2559            )
2560            # empty tensors are coalesced at creation (nnz < 2); we must force the uncoalesced state
2561            input_uncoalesced._coalesced_(False)
2562            self._test_asin_arcsin(input_uncoalesced, coalesced)
2563
2564    @coalescedonoff
2565    @dtypes(torch.double)
2566    def test_mv(self, device, dtype, coalesced):
2567        def test_shape(di, dj, dk, nnz):
2568            x, _, _ = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced)
2569            t = torch.randn(dk, dtype=dtype, device=device)
2570
2571            res = x.matmul(t)
2572            expected = self.safeToDense(x).matmul(t)
2573            self.assertEqual(res, expected)
2574
2575        test_shape(10, 100, 100, 20)
2576        test_shape(100, 1000, 1000, 20)
2577        test_shape(64, 10000, 10000, 20)
2578        test_shape(0, 100, 100, 0)
2579        test_shape(10, 0, 0, 0)
2580        test_shape(10, 100, 100, 0)
2581        test_shape(10, 100, 100, 20)
2582
2583        with self.assertRaisesRegex(RuntimeError, r"mv: expected self\.size\(-1\) == vec\.size\(-1\)"):
2584            test_shape(10, 100, 10, 20)
2585
2586        with self.assertRaisesRegex(RuntimeError, "mv: two tensor dim should be 2 and 1"):
2587            x, _, _ = self._gen_sparse(2, 20, [10, 100], dtype, device, coalesced)
2588            y, _, _ = self._gen_sparse(2, 20, [10, 100], dtype, device, coalesced)
2589            res = x.mv(y)
2590
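    def _sketch_sparse_mv(self):
        # Illustrative sketch added for exposition; not invoked by the suite.
        # A 2-D sparse COO matrix times a dense vector gives the same dense
        # result as the plain dense product.
        mat = torch.tensor([[1., 0.], [0., 2.]])
        vec = torch.tensor([3., 4.])
        self.assertEqual(mat.to_sparse().matmul(vec), mat.mv(vec))
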
2591    @dtypes(*floating_and_complex_types())
2592    def test_sparse_add_coalesce(self, device, dtype):
2593        i = self.index_tensor([[1, 2, 1]], device=device)
2594        v = torch.tensor([3, 4, 5], dtype=dtype, device=device)
2595        x = self.sparse_tensor(i, v, torch.Size([3]))
2596        y = self.sparse_tensor(i, v, torch.Size([3]))
2597        z = x + y
2598
2599        self.assertFalse(z._indices().numel() != 2 and z.is_coalesced())
2600
2601        i = self.index_tensor([[1, 2, 1]], device=device)
2602        v = torch.empty([3, 0], dtype=dtype, device=device)
2603        x = self.sparse_tensor(i, v, torch.Size([3, 0]))
2604        y = self.sparse_tensor(i, v, torch.Size([3, 0]))
2605        z = x + y
2606
2607        self.assertFalse(z._indices().numel() != 2 and z.is_coalesced())
2608
2609    @onlyCUDA
2610    def test_storage_not_null(self, device):
2611        x = torch.sparse_coo_tensor((2,), dtype=torch.float32, device=device)
2612        self.assertNotEqual(x.get_device(), -1)
2613
2614        x = torch.sparse_coo_tensor((2, 0), dtype=torch.float32, device=device)
2615        self.assertNotEqual(x.get_device(), -1)
2616
2617    @onlyCUDA
2618    @deviceCountAtLeast(2)
2619    def test_same_gpu(self, devices):
2620        def check_device(x, device_id):
2621            self.assertEqual(x.get_device(), device_id)
2622            self.assertEqual(x._values().get_device(), device_id)
2623            self.assertEqual(x._indices().get_device(), device_id)
2624
2625        dev1, dev2 = devices[0], devices[1]
2626
2627        i = self.index_tensor([[2]], device=dev2)
2628        v = torch.tensor([5], device=dev2)
2629        x = self.sparse_tensor(i, v, torch.Size([3]), device=1)
2630        check_device(x, 1)
2631
2632        i = self.index_tensor([[2]], device=dev2)
2633        v = torch.empty(1, 0, device=dev2)
2634        x = self.sparse_tensor(i, v, torch.Size([3, 0]), device=1)
2635        check_device(x, 1)
2636
2637        x = self.sparse_empty(3, device=1)
2638        check_device(x, 1)
2639
2640        x = self.sparse_empty(3, 0, device=1)
2641        check_device(x, 1)
2642
2643    def _test_new_device(self, size, device=torch.cuda):
2644        with torch.cuda.device(device):
2645            x = torch.sparse_coo_tensor(size, device='cuda', dtype=torch.float64)
2646        self.assertEqual(x.get_device(), device)
2647        x1 = x.new()
2648        x2 = x.new(2, 3)
2649        self.assertEqual(x1.get_device(), device)
2650        self.assertEqual(x2.get_device(), device)
2651
2652    @onlyCUDA
2653    def test_new_device_single_gpu(self):
2654        self._test_new_device((), 0)
2655        self._test_new_device((30, 20), 0)
2656        self._test_new_device((30, 20, 10), 0)
2657        self._test_new_device((30, 20, 10, 0), 0)
2658
2659    @onlyCUDA
2660    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
2661    def test_new_device_multi_gpu(self):
2662        self._test_new_device((), 1)
2663        self._test_new_device((30, 20), 1)
2664        self._test_new_device((30, 20, 10), 1)
2665        self._test_new_device((30, 20, 10, 0), 1)
2666
2667    @coalescedonoff
2668    @dtypes(torch.double, torch.cdouble)
2669    def test_new(self, device, dtype, coalesced):
2670        def test_shape(sparse_dims, nnz, with_size):
2671            x, indices, values = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)
2672            if not x.is_cuda:
2673                # CUDA sparse tensors currently require the size to be
2674                # specified if nDimV > 0
2675                out = x.new(indices, values).coalesce()
2676                x_c = x.coalesce()
2677                self.assertEqual((out.indices(), out.values()), (x_c.indices(), x_c.values()))
2678            self.assertEqual(x.new(indices, values, x.size()), x)
2679
2680        test_shape(3, 10, 100)
2681        test_shape(3, 0, [100, 100, 0])
2682
2683    @onlyCPU  # not really CPU-only, but we only want to run this once
2684    @dtypes(torch.float64, torch.float32, torch.float16, torch.cfloat, torch.cdouble)
2685    def test_factory(self, device, dtype):
2686        for test_empty_tensor in [True, False]:
2687            if test_empty_tensor:
2688                default_size = torch.Size([1, 3, 0])
2689                size = torch.Size([3, 3, 0])
2690            else:
2691                default_size = torch.Size([1, 3])
2692                size = torch.Size([3, 3])
2693            for include_size in [True, False]:
2694                for use_tensor_idx in [True, False]:
2695                    for use_tensor_val in [True, False]:
2696                        for use_cuda in ([False] if not torch.cuda.is_available() else [True, False]):
2697                            # have to include size with cuda sparse tensors
2698                            include_size = include_size or use_cuda
2699                            long_dtype = torch.int64
2700                            device = torch.device('cpu') if not use_cuda else \
2701                                torch.device(torch.cuda.device_count() - 1)
2702                            indices = torch.tensor(([0], [2]), dtype=long_dtype) if use_tensor_idx else ([0], [2])
2703                            if test_empty_tensor:
2704                                values = torch.empty(1, 0).to(dtype)
2705                            else:
2706                                if use_tensor_val:
2707                                    values = torch.tensor([1.], dtype=dtype)
2708                                else:
2709                                    values = 1.
2710                            if include_size:
2711                                sparse_tensor = torch.sparse_coo_tensor(indices, values, size, dtype=dtype,
2712                                                                        device=device, requires_grad=True)
2713                            else:
2714                                sparse_tensor = torch.sparse_coo_tensor(indices, values, dtype=dtype,
2715                                                                        device=device, requires_grad=True)
2716                            self.assertEqual(indices, sparse_tensor._indices())
2717                            self.assertEqual(values, sparse_tensor._values())
2718                            self.assertEqual(size if include_size else default_size, sparse_tensor.size())
2719                            self.assertEqual(dtype, sparse_tensor.dtype)
2720                            if use_cuda:
2721                                self.assertEqual(device, sparse_tensor._values().device)
2722                            self.assertEqual(True, sparse_tensor.requires_grad)
2723
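    def _sketch_factory_size_inference(self):
        # Illustrative sketch added for exposition; not invoked by the suite.
        # When no size is passed, the COO factory infers each sparse dimension
        # as max(index) + 1 and takes the dense dimensions from the trailing
        # shape of values.
        indices = torch.tensor([[0, 2], [1, 0]])
        values = torch.tensor([1., 2.])
        t = torch.sparse_coo_tensor(indices, values)
        self.assertEqual(t.size(), torch.Size([3, 2]))
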
2724    @dtypes(torch.double, torch.cdouble)
2725    def test_factory_size_check(self, device, dtype):
2726        indices = self.index_tensor([[1, 2],
2727                                    [0, 2]], device=device)
2728        values = torch.tensor([.5, .5], dtype=dtype, device=device)
2729        sizes = torch.Size([2, 3])
2730        with self.assertRaisesRegex(RuntimeError, "size is inconsistent with indices"):
2731            torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device)
2732
2733        indices.fill_(-1)
2734        with self.assertRaisesRegex(RuntimeError, "found negative index"):
2735            torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device)
2736
2737        indices = self.index_tensor([[1, 2],
2738                                    [0, 2]], device=device)
2739        values = torch.empty([2, 1, 0], dtype=dtype, device=device)
2740        sizes = torch.Size([2, 3, 1, 0])
2741        with self.assertRaisesRegex(RuntimeError, "size is inconsistent with indices"):
2742            torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device)
2743
2744        indices = self.index_tensor([[1, 2],
2745                                    [0, 2]], device=device)
2746        values = torch.empty([2, 2, 2], dtype=dtype, device=device)
2747        sizes = torch.Size([0, 0, 2, 2])
2748        with self.assertRaisesRegex(RuntimeError, "size is inconsistent with indices"):
2749            torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device)
2750
2751        indices = self.index_tensor([[1, 2],
2752                                    [0, 2]], device=device)
2753        values = torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=dtype, device=device)
2754        sizes = torch.Size([3, 3, 2])
2755        with self.assertRaisesRegex(RuntimeError, "values has incorrect size"):
2756            torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device)
2757
2758        indices = self.index_tensor([[1, 2],
2759                                    [0, 2]], device=device)
2760        values = torch.empty([2, 1, 0], dtype=dtype, device=device)
2761        sizes = torch.Size([3, 3, 2, 0])
2762        with self.assertRaisesRegex(RuntimeError, "values has incorrect size"):
2763            torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device)
2764
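    def _sketch_factory_size_consistency(self):
        # Illustrative sketch added for exposition; not invoked by the suite.
        # For every sparse dimension, the declared size must be strictly
        # greater than the largest index used in that dimension.
        indices = torch.tensor([[1, 2], [0, 2]])
        values = torch.tensor([.5, .5])
        torch.sparse_coo_tensor(indices, values, (3, 3))  # ok: 3 > max index 2
        with self.assertRaises(RuntimeError):
            torch.sparse_coo_tensor(indices, values, (2, 3))  # dim 0: 2 <= max index 2
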
2765    def test_factory_empty_indices(self, device):
2766        tensor = torch.sparse_coo_tensor(torch.Size([2, 0]), device=device)
2767        expected_indices = torch.empty((2, 0), dtype=torch.long, device=device)
2768        self.assertEqual(tensor._indices(), expected_indices)
2769
2770        tensor = torch.sparse_coo_tensor(torch.Size([2, 2, 0]), device=device)
2771        expected_indices = torch.empty((3, 0), dtype=torch.long, device=device)
2772        self.assertEqual(tensor._indices(), expected_indices)
2773
2774        tensor = torch.sparse_coo_tensor(torch.Size([2, 2, 0, 0]), device=device)
2775        expected_indices = torch.empty((4, 0), dtype=torch.long, device=device)
2776        self.assertEqual(tensor._indices(), expected_indices)
2777
2778    @dtypes(torch.double, torch.cdouble)
2779    def test_factory_nnz(self, device, dtype):
2780        indices = self.index_tensor([[0]], device=device)  # (sparse_dim, nnz): (1, 1)
2781        values = torch.tensor([[1, 1], [1, 1]], dtype=dtype, device=device)  # (nnz, ...): (2, 2)
2782        sizes = torch.Size([2, 2])
2783        with self.assertRaisesRegex(RuntimeError, "indices and values must have same nnz"):
2784            torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device)
2785
2786        indices = self.index_tensor([[0]], device=device)  # (sparse_dim, nnz): (1, 1)
2787        values = torch.empty([2, 0], dtype=dtype, device=device)  # (nnz, ...): (2, 0)
2788        sizes = torch.Size([2, 0])
2789        with self.assertRaisesRegex(RuntimeError, "indices and values must have same nnz"):
2790            torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device)
2791
2792    @dtypes(torch.double, torch.cdouble)
2793    def test_factory_nnz_zero(self, device, dtype):
2794        def test_shape(i_shape, v_shape, size, expected_size):
2795            if size:
2796                t = torch.sparse_coo_tensor(torch.empty(i_shape), torch.empty(v_shape), torch.Size(size),
2797                                            dtype=dtype, device=device)
2798            else:
2799                t = torch.sparse_coo_tensor(torch.empty(i_shape), torch.empty(v_shape), dtype=dtype, device=device)
2800            expected_indices = torch.empty(i_shape, device=device, dtype=torch.int64)
2801            expected_values = torch.empty(v_shape, device=device, dtype=dtype)
2802            expected_size = torch.Size(expected_size)
2803            self.assertEqual(t._indices(), expected_indices)
2804            self.assertEqual(t._values(), expected_values)
2805            self.assertEqual(t.size(), expected_size)
2806
2807        test_shape([1, 0], [0, 2, 4, 0], None, [0, 2, 4, 0])
2808        test_shape([3, 0], [0, 2, 4, 0], None, [0, 0, 0, 2, 4, 0])
2809        test_shape([1, 0], [0, 2, 4, 0], [0, 2, 4, 0], [0, 2, 4, 0])
2810        test_shape([3, 0], [0, 2, 4, 0], [0, 0, 0, 2, 4, 0], [0, 0, 0, 2, 4, 0])
2811        test_shape([3, 0], [0, 2, 4, 0], [1, 2, 3, 2, 4, 0], [1, 2, 3, 2, 4, 0])
2812
2813    @dtypes(torch.double, torch.cdouble)
2814    def test_factory_dense_dim(self, device, dtype):
2815        indices = self.index_tensor([[0]], device=device)
2816        values = torch.tensor([[[1, 1, 1], [1, 1, 1]]], dtype=dtype, device=device)
2817        sizes = torch.Size([1, 3, 4])
2818        with self.assertRaisesRegex(RuntimeError, "values has incorrect size"):
2819            torch.sparse_coo_tensor(indices, values, sizes)
2820
2821        indices = self.index_tensor([[0]], device=device)
2822        values = torch.empty([1, 2, 3, 0], dtype=dtype, device=device)
2823        sizes = torch.Size([1, 3, 4, 0])
2824        with self.assertRaisesRegex(RuntimeError, "values has incorrect size"):
2825            torch.sparse_coo_tensor(indices, values, sizes)
2826
2827    @onlyCPU
2828    @dtypes(torch.float16, torch.float32, torch.float64, torch.cfloat, torch.cdouble, torch.int64)
2829    def test_factory_type_inference(self, device, dtype):
2830        t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.tensor([1.], dtype=dtype))
2831        self.assertEqual(dtype, t.dtype)
2832        t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.tensor([1]))
2833        self.assertEqual(torch.int64, t.dtype)
2834
2835        t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.HalfTensor(1, 0))
2836        self.assertEqual(torch.float16, t.dtype)
2837        t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.FloatTensor(1, 0))
2838        self.assertEqual(torch.float32, t.dtype)
2839        t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.DoubleTensor(1, 0))
2840        self.assertEqual(torch.float64, t.dtype)
2841        t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.LongTensor(1, 0))
2842        self.assertEqual(torch.int64, t.dtype)
2843
2844    @onlyCUDA
2845    def test_factory_device_type_inference(self, device):
2846        # both indices/values are CUDA
2847
2848        cpu_cuda = ('cpu', 'cuda')
2849        cpu_cuda_none = cpu_cuda + (None,)
2850        for indices_device, values_device, device in itertools.product(cpu_cuda,
2851                                                                       cpu_cuda,
2852                                                                       cpu_cuda_none):
2853            indices = torch.tensor(([0], [2]), device=indices_device)
2854            values = torch.tensor([1.], device=values_device)
2855            empty_values = torch.empty(1, 0).to(values_device)
2856            shape = (1, 3)
2857            empty_shape = (1, 3, 0)
2858            if device is None and indices_device != values_device:
2859                with self.assertRaises(RuntimeError):
2860                    torch.sparse_coo_tensor(indices, values, shape, device=device)
2861                with self.assertRaises(RuntimeError):
2862                    torch.sparse_coo_tensor(indices, empty_values, empty_shape, device=device)
2863            else:
2864                t = torch.sparse_coo_tensor(indices, values, shape, device=device)
2865                t_empty = torch.sparse_coo_tensor(indices, empty_values, empty_shape, device=device)
2866                should_be_cuda = (device == 'cuda' or (device is None and values_device == 'cuda'))
2867                self.assertEqual(should_be_cuda, t.is_cuda)
2868                self.assertEqual(t.is_cuda, t_empty.is_cuda)
2869
2870    @onlyCPU
2871    def test_factory_copy(self, device):
2872        def test_tensor(indices, values, indices_equal, values_equal):
2873            sparse_tensor = torch.sparse_coo_tensor(indices, values, dtype=torch.float64, device=device)
2874            if indices_equal:
2875                self.assertEqual(indices.data_ptr(), sparse_tensor._indices().data_ptr())
2876            else:
2877                self.assertNotEqual(indices.data_ptr(), sparse_tensor._indices().data_ptr())
2878            if values_equal:
2879                self.assertEqual(values.data_ptr(), sparse_tensor._values().data_ptr())
2880            else:
2881                self.assertNotEqual(values.data_ptr(), sparse_tensor._values().data_ptr())
2882
2883        # both correct
2884        indices = torch.tensor(([0], [2]), dtype=torch.int64)
2885        values = torch.tensor([1.], dtype=torch.float64)
2886        test_tensor(indices, values, True, True)
2887
2888        indices = torch.tensor(([0], [2]), dtype=torch.int64)
2889        values = torch.DoubleTensor(1, 0)
2890        test_tensor(indices, values, True, True)
2891
2892        # only indices correct
2893        indices = torch.tensor(([0], [2]), dtype=torch.int64)
2894        values = torch.tensor([1.], dtype=torch.float32)
2895        test_tensor(indices, values, True, False)
2896
2897        indices = torch.tensor(([0], [2]), dtype=torch.int64)
2898        values = torch.tensor([1.], dtype=torch.float16)
2899        test_tensor(indices, values, True, False)
2900
2901        indices = torch.tensor(([0], [2]), dtype=torch.int64)
2902        values = torch.FloatTensor(1, 0)
2903        test_tensor(indices, values, True, True)  # An empty tensor's data_ptr is always equal to 0
2904
2905        # only values correct
2906        indices = torch.tensor(([0], [2]), dtype=torch.int32)
2907        values = torch.tensor([1.], dtype=torch.float64)
2908        test_tensor(indices, values, False, True)
2909
2910        indices = torch.tensor(([0], [2]), dtype=torch.int32)
2911        values = torch.DoubleTensor(1, 0)
2912        test_tensor(indices, values, False, True)
2913
2914        # neither correct
2915        indices = torch.tensor(([0], [2]), dtype=torch.int32)
2916        values = torch.tensor([1.], dtype=torch.float32)
2917        test_tensor(indices, values, False, False)
2918
2919        indices = torch.tensor(([0], [2]), dtype=torch.int32)
2920        values = torch.FloatTensor(1, 0)
2921        test_tensor(indices, values, False, True)  # An empty tensor's data_ptr is always equal to 0
2922
2923        # complex support
2924        indices = torch.tensor(([0], [2]), dtype=torch.int64)
2925        values = make_tensor([1, ], dtype=torch.cdouble, device=device)
2926        test_tensor(indices, values, True, False)
2927
2928        indices = torch.tensor(([0], [2]), dtype=torch.int32)
2929        values = make_tensor([1, 1], dtype=torch.cdouble, device=device)
2930        test_tensor(indices, values, False, False)
2931
2932    @onlyCPU  # run just once; the body tests both cpu and cuda
2933    def test_legacy_new_device(self, device):
2934        i = torch.tensor([[0, 1, 1], [2, 0, 2]])
2935        v = torch.tensor([3., 4., 5.])
2936        size = torch.Size([2, 3])
2937
2938        x = torch.sparse_coo_tensor(i, v, size, device='cpu')
2939        self.assertRaises(RuntimeError, lambda: x.new(device='cuda'))
2940        self.assertRaises(RuntimeError, lambda: x.new(i, v, device='cuda'))
2941        self.assertRaises(RuntimeError, lambda: x.new(i, v, size, device='cuda'))
2942        self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cuda'))
2943
2944        if torch.cuda.is_available():
2945            x = torch.sparse_coo_tensor(i, v, size, device='cuda')
2946            self.assertRaises(RuntimeError, lambda: x.new(device='cpu'))
2947            self.assertRaises(RuntimeError, lambda: x.new(i, v, device='cpu'))
2948            self.assertRaises(RuntimeError, lambda: x.new(i, v, size, device='cpu'))
2949            self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cpu'))
2950
2951    def test_legacy_new(self, device):
2952        i = torch.tensor([[0, 1, 1], [2, 0, 2]])
2953        v = torch.tensor([3., 4., 5.])
2954        size = torch.Size([2, 3])
2955        s = torch.sparse_coo_tensor(i, v, size)
2956
2957        self.assertEqual(torch.sparse_coo, s.new(device='cpu').layout)
2958        self.assertRaises(TypeError, lambda: s.new(v.untyped_storage()))
2959        self.assertRaises(TypeError, lambda: s.new(v))
2960        self.assertEqual(torch.sparse_coo, s.new(torch.Size([2, 3])).layout)
2961        self.assertRaises(TypeError, lambda: s.new([6]))
2962
2963    @onlyCPU  # not really CPU-only, but we only want to run this once
2964    def test_dtypes(self, device):
2965        all_sparse_dtypes = all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)
2966        do_test_dtypes(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cpu'))
2967        if torch.cuda.is_available():
2968            do_test_dtypes(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cuda:0'))
2969
2970    def _test_empty_full(self, device, dtype, requires_grad):
2971        shape = (2, 3)
2972        layout = torch.sparse_coo
2973
2974        def check_value(tensor, value=None, dtype=dtype, requires_grad=requires_grad):
2975            self.assertEqual(shape, tensor.shape)
2976            self.assertIs(dtype, tensor.dtype)
2977            self.assertIs(layout, tensor.layout)
2978            self.assertEqual(tensor.requires_grad, requires_grad)
2979            if tensor.is_cuda and device is not None:
2980                self.assertEqual(device, tensor.device)
2981            if value is not None:
2982                fill = tensor.empty(shape, dtype=dtype).fill_(value)
2983                self.assertEqual(tensor, fill)
2984
2985        v = torch.sparse_coo_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad)
2986        check_value(v)
2987
2988        out = v.new()
2989        check_value(torch.zeros(shape, out=out, device=device, requires_grad=requires_grad))
2990
2991        int64_dtype = torch.int64
2992        check_value(v.new_empty(shape), requires_grad=False)
2993        check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=False),
2994                    dtype=int64_dtype, requires_grad=False)
2995        check_value(torch.empty_like(v), requires_grad=False)
2996        check_value(torch.empty_like(v, dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
2997                    dtype=int64_dtype, requires_grad=False)
2998
2999    @onlyCPU  # not really CPU-only, but we only want to run this once
3000    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3001    @parametrize('requires_grad', (True, False))
3002    def test_empty_full(self, device, dtype, requires_grad):
3003        if requires_grad and not (dtype.is_floating_point or dtype.is_complex):
3004            self.skipTest(f'requires_grad==True requires float or complex dtype, got {dtype}')
3005
3006        self._test_empty_full(device, dtype, requires_grad)
3007        if torch.cuda.is_available():
3008            self._test_empty_full(None, dtype, requires_grad)
3009            self._test_empty_full(torch.device('cuda:0'), dtype, requires_grad)
3010
3011    def test_is_sparse(self, device):
3012        x = torch.randn(3, 3)
3013        self.assertFalse(x.is_sparse)
3014
3015        x = torch.randn(3, 3, 0)
3016        self.assertFalse(x.is_sparse)
3017
3018        x = self.sparse_empty(1, 0, device=device)
3019        self.assertTrue(x.is_sparse)
3020
3021    def test_resize_as(self, device):
3022        def do_test(t):
3023            y = t.new().resize_as_(t).zero_()
3024            self.assertEqual(y.shape, t.shape)
3025            # Check that y can be added to t. Currently, this requires that
3026            # sparse_dim and dense_dim match.
3027            self.assertEqual(t, t + y)
3028
3029        do_test(self.sparse_empty([3, 0], device=device))
3030        do_test(self.sparse_empty([3, 3], device=device))
3031
3032    def _test_resize_shape(self, x_i, x_v, x_size, y_i, y_v, y_size, dtype, device):
3033        x_v_numel = torch.zeros(x_v).numel()
3034        x = torch.sparse_coo_tensor(torch.zeros(x_i),
3035                                    torch.arange(x_v_numel).resize_(x_v).to(torch.float),
3036                                    torch.Size(x_size), dtype=dtype, device=device)
3037        x_dense = x.to_dense()
3038        y = torch.sparse_coo_tensor(torch.zeros(y_i),
3039                                    torch.ones(y_v).to(torch.float),
3040                                    torch.Size(y_size), dtype=dtype, device=device)
3041        y_dense = y.to_dense()
3042        x.resize_as_(y)
3043        x_dense.resize_as_(y_dense)
3044        self.assertEqual(x.shape, y.shape)
3045        self.assertEqual(x.sparse_dim(), y.sparse_dim())
3046        self.assertEqual(x.dense_dim(), y.dense_dim())
3047        self.assertEqual(x.shape, x_dense.shape)
3048        self.assertEqual(y.shape, y_dense.shape)
3049        # Here we make sure that the original data are preserved after resizing
3050        self.assertEqual(x.to_dense().view(-1)[0:x_v_numel].view(x_v),
3051                         x_dense.view(-1)[0:x_v_numel].view(x_v))
3052
3053    @dtypes(torch.double, torch.cdouble)
3054    def test_resize(self, device, dtype):
3055        # 1. Expand the size of some dense dimensions [Supported]
3056        self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
3057                                [1, 1], [1, 2, 4], [2, 2, 4],
3058                                dtype=dtype, device=device)
3059
3060        self._test_resize_shape([1, 1], [1, 2, 0], [2, 2, 0],
3061                                [1, 1], [1, 2, 4], [2, 2, 4],
3062                                dtype=dtype, device=device)
3063
3064        # 2. Expand the size of some sparse dimensions [Supported]
3065        self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
3066                                [1, 1], [1, 2, 3], [4, 2, 3],
3067                                dtype=dtype, device=device)
3068
3069        # 3. Change the shapes of both sparse and dense dimensions when nnz is zero [Supported]
3070        self._test_resize_shape([1, 0], [0, 2, 3], [2, 2, 3],
3071                                [2, 0], [0, 2, 4, 5], [1, 1, 2, 4, 5],
3072                                dtype=dtype, device=device)
3073
3074        self._test_resize_shape([1, 0], [0, 2, 3], [2, 2, 3],
3075                                [2, 0], [0, 2, 4, 0], [1, 1, 2, 4, 0],
3076                                dtype=dtype, device=device)
3077
3078        # 4. Add dims to dense dimensions [Not Supported]
3079        with self.assertRaisesRegex(RuntimeError, "changing the number of dense dimensions"):
3080            self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
3081                                    [1, 1], [1, 2, 3, 4], [2, 2, 3, 4],
3082                                    dtype=dtype, device=device)
3083
3084        with self.assertRaisesRegex(RuntimeError, "changing the number of dense dimensions"):
3085            self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
3086                                    [1, 1], [1, 2, 3, 0], [2, 2, 3, 0],
3087                                    dtype=dtype, device=device)
3088
3089        # 5. Remove dims from dense dimensions [Not Supported]
3090        with self.assertRaisesRegex(RuntimeError, "changing the number of dense dimensions"):
3091            self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
3092                                    [1, 1], [1, 2], [2, 2],
3093                                    dtype=dtype, device=device)
3094
3095        # 6. Change the number of sparse dimensions on a non-empty sparse tensor [Not Supported]
3096        with self.assertRaisesRegex(RuntimeError, "changing the number of sparse dimensions"):
3097            self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
3098                                    [2, 1], [1, 2, 3], [1, 2, 2, 3],
3099                                    dtype=dtype, device=device)
3100
3101        # 7. Shrink the size of some sparse dimensions on a non-empty sparse tensor [Not Supported]
3102        with self.assertRaisesRegex(RuntimeError, "shrinking the size of sparse dimensions"):
3103            self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
3104                                    [1, 1], [1, 2, 3], [1, 2, 3],
3105                                    dtype=dtype, device=device)
3106
3107        # 8. Shrink the size of some dense dimensions on a non-empty sparse tensor [Not Supported]
3108        with self.assertRaisesRegex(RuntimeError, "shrinking the size of dense dimensions"):
3109            self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
3110                                    [1, 1], [1, 2, 2], [2, 2, 2],
3111                                    dtype=dtype, device=device)
3112
3113        with self.assertRaisesRegex(RuntimeError, "shrinking the size of dense dimensions"):
3114            self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
3115                                    [1, 1], [1, 2, 0], [2, 2, 0],
3116                                    dtype=dtype, device=device)
3117
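    def _sketch_sparse_resize(self):
        # Illustrative sketch added for exposition; not invoked by the suite.
        # Growing sparse dimensions in place is supported, while changing the
        # number of dense dimensions of a non-empty sparse tensor is not.
        t = torch.tensor([[1., 0.], [0., 2.]]).to_sparse()
        t.sparse_resize_((3, 3), t.sparse_dim(), t.dense_dim())
        self.assertEqual(t.shape, torch.Size([3, 3]))
        with self.assertRaises(RuntimeError):
            t.sparse_resize_((3, 3, 2), t.sparse_dim(), t.dense_dim() + 1)
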
3118    def test_is_nonzero(self, device):
3119        self.assertTrue(torch.sparse_coo_tensor(([0],), 1., (1,), device=device).is_nonzero())
3120        self.assertFalse(torch.sparse_coo_tensor(([0],), 0., (1,), device=device).is_nonzero())
3121        self.assertFalse(torch.sparse_coo_tensor(([0], [0]), 0., (1, 1), device=device).is_nonzero())
3122        self.assertFalse(torch.sparse_coo_tensor(([0, 0],), (0., 0.), (1,), device=device).is_nonzero())
3123        self.assertFalse(torch.sparse_coo_tensor(([0, 0],), (-1., 1.), (1,), device=device).is_nonzero())
3124
3125        # scalar sparse tensor
3126        self.assertTrue(torch.sparse_coo_tensor(torch.zeros(0, 1), 12.3, [], device=device).is_nonzero())
3127        with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"):
3128            torch.sparse_coo_tensor(([0, 1],), torch.empty(2, 0), (4, 0), device=device).is_nonzero()
3129        self.assertTrue(torch.sparse_coo_tensor(([0],), 2.3 - 4.5j, (1,), dtype=torch.cfloat, device=device)
3130                        .is_nonzero())
3131        self.assertTrue(torch.sparse_coo_tensor(([0],), 2.3 - 4.5j, (1,), dtype=torch.cdouble, device=device)
3132                        .is_nonzero())
3133        self.assertFalse(torch.sparse_coo_tensor(([0],), 0. + 0j, (1,), dtype=torch.cfloat, device=device)
3134                         .is_nonzero())
3135        self.assertFalse(torch.sparse_coo_tensor(([0],), 0. + 0j, (1,), dtype=torch.cdouble, device=device)
3136                         .is_nonzero())
3137
3138    @dtypes(torch.double, torch.cdouble)
3139    def test_change_tensor_metadata(self, device, dtype):
3140        i = self.index_tensor([[0], [1]], device=device)
3141        v = torch.tensor([[3, 4, 5]], dtype=dtype, device=device)
3142        t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]), dtype=dtype, device=device)
3143        i.resize_(2, 3)
3144        v.resize_(4, 5)
3145        self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
3146        self.assertEqual(list(t.coalesce().values().size()), [1, 3])
3147
3148        i = self.index_tensor([[0], [1]], device=device)
3149        v = torch.tensor([[3, 4, 5]], dtype=dtype, device=device)
3150        t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
3151        i.resize_as_(self.index_tensor([0, 1], device=device))
3152        v.resize_as_(torch.tensor([3, 4, 5], dtype=dtype, device=device))
3153        self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
3154        self.assertEqual(list(t.coalesce().values().size()), [1, 3])
3155
3156        i = self.index_tensor([[0], [1]], device=device)
3157        v = torch.tensor([[3, 4, 5]], dtype=dtype, device=device)
3158        t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
3159        i.as_strided_((2, 1), (1, 1))
3160        v.as_strided_((1, 3), (1, 1))
3161        self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
3162        self.assertEqual(list(t.coalesce().values().size()), [1, 3])
3163
3164        i = self.index_tensor([[0], [1]], device=device)
3165        v = torch.tensor([[3, 4, 5]], dtype=dtype, device=device)
3166        t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
3167        i.set_(self.index_tensor([0, 1], device=device))
3168        v.set_(torch.tensor([3, 4, 5], dtype=dtype, device=device))
3169        self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
3170        self.assertEqual(list(t.coalesce().values().size()), [1, 3])
3171
3172        i = self.index_tensor([[0], [1]], device=device)
3173        v = torch.tensor([[3, 4, 5]], dtype=dtype, device=device)
3174        t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
3175        i.transpose_(0, 1)
3176        v.transpose_(0, 1)
3177        self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
3178        self.assertEqual(list(t.coalesce().values().size()), [1, 3])
3179
3180    @coalescedonoff
3181    @dtypes(torch.double)
3182    def test_pickle(self, device, dtype, coalesced):
3183        import pickle
3184
3185        shape_sparse_dim_nnz = [
3186            ((), 0, 2),
3187            ((0,), 0, 10),
3188            ((2,), 0, 3),
3189            ((100, 3), 1, 3),
3190            ((100, 20, 3), 2, 0),
3191            ((10, 0, 3), 0, 3),
3192            ((10, 0, 3), 0, 0),
3193        ]
3194
3195        for shape, sparse_dim, nnz in shape_sparse_dim_nnz:
3196            indices_shape = torch.Size((sparse_dim, nnz))
3197            values_shape = torch.Size((nnz,) + shape[sparse_dim:])
3198            indices = torch.arange(indices_shape.numel(), dtype=self.index_tensor(0).dtype,
3199                                   device=device).view(indices_shape)
3200            for d in range(sparse_dim):
3201                indices[d].clamp_(max=(shape[d] - 1))  # make it valid index
3202            if not coalesced and indices.numel() > 0:
3203                indices[:, -1] = indices[:, 0]  # make it uncoalesced
3204            values_numel = values_shape.numel()
3205            values = torch.arange(values_numel, dtype=dtype,
3206                                  device=device).view(values_shape).div_(values_numel / 2.)
3207            sp_tensor = self.sparse_tensor(indices, values, shape)
3208            serialized = pickle.dumps(sp_tensor)
3209            sp_tensor_loaded = pickle.loads(serialized)
3210            self.assertEqual(sp_tensor, sp_tensor_loaded)
3211
3212    def test_any(self, device):
3213        t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [2, 0])), torch.tensor([False, False]), device=device)
3214        t_any = torch.tensor(False)
3215        self.assertEqual(torch.any(t), t_any)
3216        t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [2, 0])), torch.tensor([True, False]), device=device)
3217        t_any = torch.tensor(True)
3218        self.assertEqual(torch.any(t), t_any)
3219
3220    def test_isnan(self, device):
3221        t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [0, 2])), torch.tensor([1, 4]), device=device)
3222        t_nan = torch.sparse_coo_tensor(torch.tensor(([0, 0], [0, 2])), torch.tensor([False, False]), device=device)
3223        self.assertEqual(torch.isnan(t).int(), t_nan.int())
3224        t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [0, 2])), torch.tensor([1, float("nan")]), device=device)
3225        t_nan = torch.sparse_coo_tensor(torch.tensor(([0, 0], [0, 2])), torch.tensor([False, True]), device=device)
3226        self.assertEqual(torch.isnan(t).int(), t_nan.int())
3227
3228    @coalescedonoff
3229    @dtypes(torch.float32, torch.float64)
3230    def test_div_rounding_mode(self, device, dtype, coalesced):
3231        sparse, _, _ = self._gen_sparse(2, 10, (10, 10), dtype,
3232                                        device, coalesced)
3233        dense = self.safeToDense(sparse)
3234
3235        for mode in (None, 'floor', 'trunc'):
3236            actual = sparse.div(-2, rounding_mode=mode)
3237            expect = dense.div(-2, rounding_mode=mode)
3238            self.assertEqual(self.safeToDense(actual), expect)
3239
3240            # Test inplace
3241            actual = sparse.clone().div_(-2, rounding_mode=mode)
3242            self.assertEqual(self.safeToDense(actual), expect)
3243
3244            # Test out argument
3245            actual.zero_()
3246            torch.div(sparse, -2, rounding_mode=mode, out=actual)
3247            self.assertEqual(self.safeToDense(actual), expect)
3248
3249    def test_div_by_sparse_error(self, device):
3250        self.assertRaisesRegex(RuntimeError, 'Sparse division requires',
3251                               lambda: torch.tensor(1., device=device).to_sparse()
3252                               / torch.tensor(1., device=device).to_sparse())
3253
3254    def test_floor_divide_by_sparse_error(self, device):
3255        self.assertRaisesRegex(RuntimeError, 'Sparse floor division requires',
3256                               lambda: torch.tensor(1., device=device).to_sparse()
3257                               // torch.tensor(1., device=device).to_sparse())
3258
3259    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
3260    @onlyCPU
3261    def test_sparse_to_numpy(self, device):
3262        t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [2, 0])), torch.tensor([1, 4]))
3263        self.assertRaises(TypeError, lambda: t.numpy())
3264
3265    @coalescedonoff
3266    @dtypes(torch.double)
3267    def test_softmax(self, device, dtype, coalesced):
3268        import torch.nn.functional as F
3269
3270        def to_dense(sparse, fill_value=None):
3271            """
3272            Return a dense tensor from a sparse tensor using the given fill value.
3273            """
3274            if fill_value is None or fill_value == 0:
3275                return sparse.to_dense()
3276            sparse = sparse.coalesce()
3277            dense = torch.full(sparse.shape, fill_value, dtype=sparse.dtype, device=sparse.device)
3278            for idx, value in zip(sparse._indices().t(), sparse._values()):
3279                dense[tuple(idx)] = value
3280            return dense
3281
3282        def softmax_to_dense(sparse, dim):
3283            """Dense softmax of a sparse tensor. Useful only for testing softmax
3284            correctness.
3285
3286            When computing softmax of a sparse tensor, the value of
3287            unspecified items is negative infinity rather than zero so
3288            that
3289
3290              softmax(sparse.to_dense(fill_value=-inf), dim) == softmax(sparse, dim).to_dense()
3291
3292            holds for non-empty lines. On empty lines, the softmax
3293            values are defined as 0 in order to preserve the sparsity
3294            of the result.
3295
3296            Note that in PyTorch, the ``to_dense`` method does not
3297            implement the ``fill_value`` keyword argument.
3298            """
3299            dtype = sparse.dtype
3300            device = sparse.device
3301            dense = to_dense(sparse, fill_value=-float('inf'))
3302            r = F.softmax(dense, dim)
3303            # softmax on empty lines results in nan; replace with zeros to match the definition
3304            r[r != r] = 0
3305            return r
3306
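        # A minimal illustration added for exposition (not part of the original
        # test): unspecified entries act like -inf, so sparse softmax
        # normalizes over the specified values only.
        _s = torch.sparse_coo_tensor([[0, 2]], [1.0, 2.0], (3,),
                                     dtype=torch.double, device=device)
        _expected = torch.tensor([1.0, 2.0], dtype=torch.double, device=device).softmax(0)
        self.assertEqual(torch.sparse.softmax(_s, 0).coalesce().values(), _expected)
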
3307        def sparse_softmax(sparse, dim):
3308            """Pure Python softmax of a sparse tensor. Assuming -inf for
3309            unspecified sparse tensor data. This is a prototype of the
3310            sparse softmax algorithm in Python.
3311            """
3312            dtype = sparse.dtype
3313            device = sparse.device
3314
3315            # softmax is a non-linear operation, so sparse tensors must
3316            # be coalesced.
3317            sparse = sparse.coalesce()
3318            inf = float('inf')
3319            indices = sparse._indices()
3320            values = sparse._values()
3321
3322            if dim < sparse.sparse_dim():
3323                nnz = sparse._nnz()
3324
3325                # compute pool indices
3326                size = sparse.size()
3327                strides = torch.ones((sparse.sparse_dim(), 1), dtype=indices.dtype, device=indices.device)
3328                for i in reversed(range(sparse.sparse_dim() - 1)):
3329                    strides[i, 0] = strides[i + 1, 0] * size[i + 1]
3330                strides[dim, 0] = 0
3331
3332                pool = (indices * strides).sum(dim=0)
3333                i2p = {}
3334                for i in range(nnz):
3335                    c = int(pool[i])
3336                    if c not in i2p:
3337                        i2p[c] = len(i2p)
3338                    pool[i] = i2p[c]
3339
3340                # compute max
3341                dense_size = tuple(size[sparse.sparse_dim():])
3342                mx = torch.empty((pool.max() + 1,) + dense_size, dtype=dtype, device=device)
3343                mx[:] = -inf
3344                for n in range(nnz):
3345                    p = pool[n]
3346                    mx[p] = torch.max(mx[p], values[n])
3347
3348                # apply exp to (v - mx) and sum the results
3349                exp_values = torch.empty_like(values)
3350                exp_sums = torch.zeros_like(mx)
3351                for n in range(nnz):
3352                    p = pool[n]
3353                    v = exp_values[n] = (values[n] - mx[p]).exp()
3354                    exp_sums[p] = exp_sums[p] + v
3355
3356                # normalize with the sum of exponents
3357                for n in range(nnz):
3358                    p = pool[n]
3359                    exp_values[n] = exp_values[n] / exp_sums[p]
3360
3361                return torch.sparse_coo_tensor(indices,
3362                                               exp_values,
3363                                               sparse.size(),
3364                                               dtype=dtype, device=device)
3365
3366            elif dim < sparse.sparse_dim() + sparse.dense_dim():
3367                return torch.sparse_coo_tensor(indices,
3368                                               F.softmax(values, dim - sparse.sparse_dim() + 1),
3369                                               sparse.size(),
3370                                               dtype=dtype, device=device)
3371            else:
3372                raise ValueError(
3373                    f'`dim(={dim})` must be smaller than `sparse_dim(={sparse.sparse_dim()}) + dense_dim(={sparse.dense_dim()})`')
3374
3375        def softmax_jacobian_analytic(x, dim):
3376            """Return Jacobian of softmax using analytic formula
3377
3378               D_jS_i = S_i * (1[i==j] - S_j).
3379
3380            where S = softmax(x, dim), x is a dense tensor, and i, j are in
3381            range(x.shape[dim]).
3382            """
3383            y = F.softmax(x, dim)
3384            y[y != y] = 0  # replace nan-s with zeros
3385            J = torch.zeros((x.shape[dim],) + tuple(x.shape), dtype=x.dtype, device=x.device)
3386            si = [slice(None)] * len(y.shape)
3387            sj = [slice(None)] * len(y.shape)
3388            s = [slice(None)] * len(J.shape)
3389            for i in range(y.shape[dim]):
3390                si[dim] = i
3391                s[dim + 1] = i
3392                yi = y[tuple(si)]
3393                for j in range(y.shape[dim]):
3394                    sj[dim] = j
3395                    s[0] = j
3396                    if i == j:
3397                        J[tuple(s)] = yi * (1 - yi)
3398                    else:
3399                        yj = y[tuple(sj)]
3400                        J[tuple(s)] = - yi * yj
3401                    sj[dim] = slice(None)
3402                si[dim] = slice(None)
3403                s[dim + 1] = slice(None)
3404            return J
3405
3406        def softmax_jacobian_autograd(x, dim, log=False):
3407            """Return Jacobian of softmax using PyTorch autograd feature.
3408
3409            x can be dense or sparse tensor.
3410            """
3411            import itertools
3412
3413            if x.is_sparse:
3414                x = x.coalesce()
3415
3416            dtype = x.dtype
3417            device = x.device
3418            shape = tuple(x.shape)
3419            J = torch.zeros((shape[dim],) + shape, dtype=dtype, device=device)
3420            for i in range(shape[dim]):
3421                if x.is_sparse:
3422                    sparse_dim = x.sparse_dim()
3423                    dense_dim = x.dense_dim()
3424                    if dim < sparse_dim:
3425                        ranges = []
3426                        for j, sz in enumerate(shape[:sparse_dim]):
3427                            if dim == j:
3428                                ranges.append([i])
3429                            else:
3430                                ranges.append(list(range(sz)))
3431                        indices = torch.tensor(list(itertools.product(*ranges)), dtype=torch.long, device=device).t()
3432                        values = torch.ones((indices.shape[1],) + shape[sparse_dim:], dtype=dtype, device=device)
3433                    else:
3434                        ranges = []
3435                        for j, sz in enumerate(shape[:sparse_dim]):
3436                            ranges.append(list(range(sz)))
3437                        indices = torch.tensor(list(itertools.product(*ranges)), dtype=torch.long, device=device).t()
3438                        values = torch.zeros((indices.shape[1],) + shape[sparse_dim:], dtype=dtype, device=device)
3439                        sv = [slice(None)] * (dense_dim + 1)
3440                        sv[dim - sparse_dim + 1] = i
3441                        values[tuple(sv)] = 1
3442                    v = torch.sparse_coo_tensor(indices, values, shape, dtype=dtype, device=device)
3443                else:
3444                    v = torch.zeros_like(x)
3445                    sv = [slice(None)] * len(v.shape)
3446                    sv[dim] = i
3447                    v[tuple(sv)] = 1
3448                x_ = x.clone()
3449                x_.requires_grad_(True)
3450
3451                if log:
3452                    if x_.is_sparse:
3453                        y = torch.sparse.log_softmax(x_, dim)
3454                    else:
3455                        y = F.log_softmax(x_, dim)
3456                else:
3457                    if x_.is_sparse:
3458                        y = torch.sparse.softmax(x_, dim)
3459                    else:
3460                        y = F.softmax(x_, dim)
3461                        # replace nan-s with zeros
3462                        y.data[y != y] = 0
3463                y.backward(v)
3464                g = x_.grad
3465                if not g.is_sparse:
3466                    # replace nan-s with zeros
3467                    g.data[g != g] = 0
3468                J[i] = g.to_dense() if g.is_sparse else g
3469            return J
3470
3471        @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1166")
3472        def test_op(sparse_dims, nnz, with_size, coalesced):
3473            if isinstance(with_size, Number):
3474                with_size = [with_size] * sparse_dims
3475
3476            x, i, v = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)
3477
3478            def sparse_log(x):
3479                return torch.sparse_coo_tensor(x._indices(), x._values().log(),
3480                                               x.size(), dtype=x.dtype, device=x.device)
3481
3482            # Check dim out of bounds
3483            with self.assertRaisesRegex(IndexError, r"Dimension out of range"):
3484                torch.sparse.softmax(x, x.dim())
3485            with self.assertRaisesRegex(IndexError, r"Dimension out of range"):
3486                torch.sparse.softmax(x, -x.dim() - 1)
3487
3488            for dim in range(x.dim()):
3489                # Check sparse softmax definition
3490
3491                # check Python sparse softmax
3492                y = sparse_softmax(x, dim)
3493                r1 = softmax_to_dense(x, dim)
3494                r2 = y.to_dense()
3495                self.assertEqual(r1, r2)
3496
3497                # check C++ sparse softmax
3498                for d in (dim, dim - x.dim()):
3499                    y1 = torch.sparse.softmax(x, d)
3500                    self.assertEqual(y, y1)
3501
3502                    # check C++ sparse log_softmax
3503                    ly1 = torch.sparse.log_softmax(x, d)
3504                    self.assertEqual(ly1, sparse_log(y1))
3505
3506                # Check autograd support on sparse softmax
3507
3508                # check softmax Jacobian definition for dense input
3509                x1 = to_dense(x, fill_value=float('-inf'))
3510                J = softmax_jacobian_analytic(x1, dim)
3511                assert J.shape[0] == x.shape[dim]
3512                assert J.shape[dim + 1] == x.shape[dim]
3513
3514                # check softmax Jacobian from autograd, dense input
3515                J2 = softmax_jacobian_autograd(x1, dim)
3516                self.assertEqual(J, J2)
3517
3518                # check softmax Jacobian from autograd, sparse input
3519                J3 = softmax_jacobian_autograd(x, dim)
3520                self.assertEqual(J, J3)
3521
3522                '''
3523                y = softmax(x, dim)
3524                z = log(y) = log_softmax(x, dim)
3525                Dy/Dx = J
3526                Dz/Dx = Dz/Dy Dy/Dx = 1/y * J
3527                => J = J_log * y
3528                '''
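                # Element-wise this means J == J_log * y with y = softmax(x, dim).
                # The transposes below swap the Jacobian's output axis with the
                # input axis along `dim`, so that broadcasting J2_log * r1 (and
                # J3_log * r1) multiplies each entry by the matching softmax
                # value, r1 being softmax(x, dim) in dense form.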
3529                # log_softmax Jacobian from autograd, dense input
3530                J2_log = softmax_jacobian_autograd(x1, dim, log=True)
3531
3532                # log_softmax Jacobian from autograd, sparse input
3533                J3_log = softmax_jacobian_autograd(x, dim, log=True)
3534
3535                J = J.transpose(0, dim + 1)
3536                J2_log = J2_log.transpose(0, dim + 1)
3537                J3_log = J3_log.transpose(0, dim + 1)
3538                self.assertEqual(J, J2_log * r1)
3539                self.assertEqual(J, J3_log * r1)
3540
3541                if dim == 0:
3542                    # check dtype argument
3543                    other_dtype = torch.float32
3544                    y2 = torch.sparse.softmax(x, dim, dtype=other_dtype)
3545                    self.assertEqual(y2.dtype, other_dtype)
3546                    self.assertEqual(y2, y1.type(other_dtype))
3547
3548                    ly2 = torch.sparse.log_softmax(x, dim, dtype=other_dtype)
3549                    self.assertEqual(ly2.dtype, other_dtype)
3550                    self.assertEqual(ly2, ly1.type(other_dtype))
3551
3552        test_op(1, 10, [3], coalesced)
3553        test_op(1, 10, [2, 3], coalesced)
3554        test_op(1, 10, [3, 2], coalesced)
3555        test_op(2, 10, [2, 3, 4], coalesced)
3556        test_op(2, 10, [3, 4], coalesced)
3557        test_op(2, 5, [5, 4], coalesced)
3558        test_op(2, 10, [3, 4, 2], coalesced)
3559        test_op(3, 10, [3, 4, 2], coalesced)
3560        test_op(3, 100, [3, 4, 2], coalesced)
3561        test_op(3, 100, [3, 4, 2, 3], coalesced)
3562        test_op(3, 100, [3, 4, 2, 3, 5, 2], coalesced)
3563        test_op(4, 100, [3, 4, 2, 3, 5, 2], coalesced)
3564
3565
3566    def _check_zero_nnz_softmax_op(self, func, ndim, device, dtype):
3567        # create a sparse tensor with shape (0,..., 3) it has no materialize values
3568        # create a sparse tensor with shape (0, ..., 3); it has no materialized values
3569        out = func(t, 0)
3570        self.assertEqual(out, torch.zeros_like(t))
3571
3572        # gradient
3573        t = t.requires_grad_()
3574        gradcheck(lambda x: func(x, 0).to_dense(), (t,), masked=True)
3575
3576
3577    @dtypes(torch.double, torch.float)
3578    @unittest.skipIf(TEST_WITH_CROSSREF, "unsupported generator triggers assertion error")
3579    def test_softmax_zero_nnz(self, device, dtype):
3580        self._check_zero_nnz_softmax_op(torch.sparse.softmax, 1, device, dtype)
3581        self._check_zero_nnz_softmax_op(torch.sparse.softmax, 10, device, dtype)
3582
3583    @dtypes(torch.double, torch.float)
3584    @unittest.skipIf(TEST_WITH_CROSSREF, "unsupported generator triggers assertion error")
3585    def test_log_softmax_zero_nnz(self, device, dtype):
3586        self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 1, device, dtype)
3587        self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 10, device, dtype)
3588
3589    # TODO: Check why ROCm's cusparseXcsrgemm2Nnz function doesn't return the same nnz value as CUDA
3590    @skipIfRocm
3591    @coalescedonoff
3592    @dtypes(*floating_and_complex_types())
3593    @dtypesIfCUDA(*floating_types_and(*[torch.half] if SM53OrLater else [],
3594                                      *[torch.bfloat16] if SM80OrLater else [],
3595                                      torch.complex64,
3596                                      *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
3597    @unittest.skipIf(TEST_WITH_CROSSREF, "not working with fake tensor")
3598    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2, torch.complex64: 1e-2, torch.float32: 1e-2})
3599    def test_sparse_matmul(self, device, dtype, coalesced):
3600        """
3601        This function tests `torch.sparse.mm` when both mat1 and mat2 are sparse tensors.
3602        """
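        # A minimal usage sketch of the operation under test (illustrative only,
        # arbitrary shapes):
        #
        #   a = torch.randn(3, 4).relu().to_sparse()
        #   b = torch.randn(4, 5).relu().to_sparse()
        #   c = torch.sparse.mm(a, b)  # sparse @ sparse -> sparse COO result
        #   assert torch.allclose(c.to_dense(), a.to_dense() @ b.to_dense())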
3603
3604        def ref_sparse_mm(a, b):
3605            return a.to_dense() @ b.to_dense()
3606
3607        def grad_with_custom_sparsity_pattern_test_helper(sparse_dims, nnz, shape_a, shape_b):
3608            def test_grad_dense(a_s, b_s, g_s):
3609                a = a_s.to_dense().detach()
3610                b = b_s.to_dense().detach()
3611                g = g_s.to_dense().detach()
3612
3613                a.requires_grad_(True)
3614                b.requires_grad_(True)
3615                c = a @ b
3616                c.backward(g)
3617                return a.grad.sparse_mask(a_s.coalesce()), b.grad.sparse_mask(b_s.coalesce())
3618
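            # test_grad_dense above compares against plain dense autograd:
            # sparse_mask keeps only the entries of the dense gradient at the
            # indices of the (coalesced) sparse operand, matching the masked
            # (same-sparsity-pattern) gradients produced by torch.sparse.mm.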
3619            a, _, _ = self._gen_sparse(sparse_dims, nnz, shape_a, dtype, device, coalesced)
3620            b, _, _ = self._gen_sparse(sparse_dims, nnz, shape_b, dtype, device, coalesced)
3621            a.requires_grad_(True)
3622            b.requires_grad_(True)
3623
3624            c = torch.sparse.mm(a, b)
3625            c2 = c.to_dense().detach()
3626            c2 = torch.rand_like(c2)
3627            g = c2.sparse_mask(c.coalesce())
3628
3629            c.backward(g)
3630
3631            a_grad, b_grad = test_grad_dense(a, b, g)
3632
3633            # We convert grad to dense since dense and sparse mm
3634            # implementations handle materialized zeroes differently.
3635            self.assertEqual(a.grad.to_dense(), a_grad.to_dense())
3636            self.assertEqual(b.grad.to_dense(), b_grad.to_dense())
3637
3638        def test_sparse_matmul(sparse_dims, nnz, shape_a, shape_b):
3639            a, i_a, v_a = self._gen_sparse(sparse_dims, nnz, shape_a, dtype, device, coalesced)
3640            b, i_b, v_b = self._gen_sparse(sparse_dims, nnz, shape_b, dtype, device, coalesced)
3641
3642            # dense implementation
3643            r1 = ref_sparse_mm(a, b)
3644
3645            # cpp implementation
3646            r2 = torch.sparse.mm(a, b)
3647            self.assertEqual(r1, r2.to_dense())
3648
3649            # Check result is truly coalesced
3650            self.assertTrue(r2.is_coalesced() and is_coalesced_indices(r2))
3651
3652            if dtype in [torch.double, torch.cdouble]:
3653                a.requires_grad_(True)
3654                b.requires_grad_(True)
3655
3656                # check autograd support on sparse matmul
3657                def fn(D1, D2):
3658                    return torch.sparse.mm(D1, D2).to_dense()
3659
3660                if a.is_cuda:
3661                    # For CUDA, `nondet_tol` is set to `1e-5` because cuSPARSE
3662                    # sometimes returns near-zero values (on the order of 1e-323)
3663                    # when chaining `torch.sparse.mm` operations.
3664                    # TODO: Check this cuSPARSE issue.
3665                    gradcheck(fn, (a, b), nondet_tol=1e-5, masked=True)
3666                else:
3667                    gradcheck(fn, (a, b), masked=True)
3668                grad_with_custom_sparsity_pattern_test_helper(sparse_dims, nnz, shape_a, shape_b)
3669
3670        def test_error_cases():
3671            def fn(sparse_dims, nnz, shape_a, shape_b):
3672                a, i_a, v_a = self._gen_sparse(sparse_dims, nnz, shape_a, dtype, device, coalesced)
3673                b, i_b, v_b = self._gen_sparse(sparse_dims, nnz, shape_b, dtype, device, coalesced)
3674                r2 = torch.sparse.mm(a, b)
3675
3676            # This is not a matrix
3677            self.assertRaises(RuntimeError, lambda: fn(3, 4, [2, 2, 2], [2, 2, 2]))
3678
3679            # Shapes do not match
3680            self.assertRaisesRegex(RuntimeError,
3681                                   r"mat1 and mat2 shapes cannot be multiplied \(2x3 and 4x2\)",
3682                                   lambda: fn(2, 10, [2, 3], [4, 2]))
3683
3684            def different_dtypes():
3685                a, i_a, v_a = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced)
3686                b, i_b, v_b = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced)
3687                r2 = torch.sparse.mm(a.to(torch.float64), a.to(torch.float32))
3688
3689            self.assertRaisesRegex(RuntimeError, 'mat1 dtype Double does not match mat2 dtype Float', different_dtypes)
3690
3691        def test_backward_noncontiguous():
3692            # Sparse.mm backward used to be wrong with non-contiguous grads,
3693            # see https://github.com/pytorch/pytorch/issues/102493.
3694            n_reps = 7
3695            for _ in range(n_reps):
3696                A = torch.eye(5).to_sparse().requires_grad_(True)
3697                B = torch.eye(5).to_sparse()
3698                out = torch.sparse.mm(A, B)
3699                out.coalesce().values().sum().backward()
3700                self.assertEqual(A.grad, A)
3701
3702        for n in range(2, 5):
3703            for m in range(2, 8):
3704                for p in range(2, 8):
3705                    test_sparse_matmul(2, 10, [n, m], [m, p])
3706
3707        test_sparse_matmul(2, 0, [0, 0], [0, 0])
3708        test_sparse_matmul(2, 0, [0, 10], [10, 0])
3709        test_error_cases()
3710        test_backward_noncontiguous()
3711
3712    @coalescedonoff
3713    @dtypes(torch.double)
3714    def test_assign(self, device, dtype, coalesced):
3715        def assign_to():
3716            a, i_a, v_a = self._gen_sparse(2, 5, [2, 3], dtype, device, coalesced)
3717            a[0] = 100
3718
3719        self.assertRaises(TypeError, assign_to)
3720
3721    @dtypes(torch.double, torch.cdouble)
3722    def test_full_broadcast_to(self, device, dtype):
3723        def can_broadcast(s0, s1):
3724            s0 = tuple(reversed(s0))
3725            s1 = tuple(reversed(s1))
3726            for i in range(len(s0)):
3727                if s0[i] != 1 and s0[i] != s1[i]:
3728                    return False
3729            return True
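        # For example, can_broadcast((3, 1), (4, 3, 2)) is True (trailing dims
        # align as 1 -> 2 and 3 -> 3, and a leading dim may be prepended),
        # while can_broadcast((2,), (1, 1)) is False because the trailing
        # size 2 neither equals 1 nor matches the target size 1.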
3730        sizes = (
3731            (), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)
3732        )
3733        for s0, s1 in itertools.combinations(sizes, r=2):
3734            t = make_tensor(s0, dtype=dtype, device=device, low=-9, high=9)
3735            for sparse_dims in range(1, len(s0) + 1):
3736                s = t.to_sparse(sparse_dims)
3737                if can_broadcast(s0, s1):
3738                    t_res = torch.broadcast_to(t, s1)
3739                    s_res = torch._sparse_broadcast_to(s, s1)
3740                    torch._validate_sparse_coo_tensor_args(s_res._indices(), s_res._values(), s_res.shape)
3741                    if s_res.is_coalesced():
3742                        # ensure that is_coalesced is estimated correctly
3743                        self.assertEqual(s_res, torch.sparse_coo_tensor(s_res._indices(), s_res._values(), s_res.shape).coalesce())
3744                    self.assertEqual(s_res.to_dense(), t_res)
3745                else:
3746                    with self.assertRaisesRegex(RuntimeError,
3747                                                r"The expanded size of the tensor \(\d\) "
3748                                                r"must match the existing size \(\d\)"):
3749                        torch._sparse_broadcast_to(s, s1)
3750
3751    @coalescedonoff
3752    @dtypes(torch.double, torch.cdouble)
3753    def test_sparse_broadcast_to(self, device, dtype, coalesced):
3754        def test(sparse_dims, nnz, with_size, new_size):
3755            x = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
3756            y = self.safeToDense(x)
3757            x1 = torch._sparse_broadcast_to(x, new_size)
3758            y1 = y.broadcast_to(new_size)
3759            self.assertEqual(self.safeToDense(x1), y1)
3760
3761        test(4, 6, [7, 3, 1, 3, 0], [7, 3, 4, 3, 0])
3762        test(4, 6, [7, 3, 1, 3, 0], [2, 7, 3, 1, 3, 0])
3763        test(4, 6, [7, 3, 1, 3, 1, 3], [7, 3, 1, 3, 2, 3])
3764        test(4, 6, [7, 3, 1, 3, 2, 1], [7, 3, 1, 3, 2, 3])
3765
3766    def _test_mul_skips(self, device, dtype, coalesced):
3767        skipTestIfUncoalesced = False
3768        # This case always coalesces the inputs, which could lead to loss of precision,
3769        # hence for float16/bfloat16 the test runs only with already coalesced tensors.
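        # (For instance, in float16 2048.0 + 1.0 == 2048.0, so accumulating
        # duplicate entries in half precision is sensitive to whether the
        # duplicates are summed during coalescing or during to_dense.)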
3770        if not coalesced and dtype in {torch.float16, torch.bfloat16}:
3771            skipTestIfUncoalesced = True
3772        # to_dense is problematic for boolean non-coalesced CUDA tensors
3773        # see https://github.com/pytorch/pytorch/issues/81648
3774        if not coalesced and dtype == torch.bool and torch.device(device).type == "cuda":
3775            skipTestIfUncoalesced = True
3776
3777        if skipTestIfUncoalesced:
3778            self.skipTest(f"Test with dtype={dtype}, device={device} runs only with coalesced inputs")
3779
3780    @coalescedonoff
3781    # NOTE: addcmul_out is not implemented for bool.
3782    @dtypes(*all_types_and_complex_and(torch.bfloat16, torch.float16))
3783    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
3784    def test_sparse_sparse_mul(self, device, dtype, coalesced):
3785        self._test_mul_skips(device, dtype, coalesced)
3786
3787        shape = (2, 3, 4, 10)
3788        nnz = 10
3789
3790        def check(self, x, y):
3791            res_sparse = x * y
3792            res_dense = x.to_dense() * y.to_dense()
3793            self.assertEqual(res_sparse.to_dense(), res_dense)
3794
3795        def check_empty(sparse_shape, nnz, dense_shape, coalesce):
3796            from itertools import product
3797            for nnz_val, shape_suffix in product((nnz, 0), ((), (0,))):
3798                empty_sparse_shape = sparse_shape + shape_suffix
3799                empty_dense_shape = dense_shape + shape_suffix
3800                x = self._gen_sparse(sparse_dim, nnz_val, empty_sparse_shape, dtype, device, coalesce)[0]
3801                check(self, x, x)
3802
3803        # TODO: uncomment once backward is implemented for sparse tensors that broadcast in dense dims.
3804        # def check_autograd(x, y):
3805        #     if dtype in {torch.double, torch.cdouble}:
3806        #         xa = x.detach().clone().requires_grad_(True)
3807        #         ya = y.detach().clone().requires_grad_(True)
3808        #         gradcheck(lambda a, b: (a * b).to_dense(), (xa, ya), masked=True)
3809        #         gradcheck(lambda a, b: (a * b).to_dense(), (ya, xa), masked=True)
3810
3811        for dim in range(len(shape) + 1):
3812            sub_shape = shape[dim:]
3813            sparse_dim = len(sub_shape) // 2
3814
3815            check_empty(sub_shape, nnz, shape, coalesced)
3816
3817            x = self._gen_sparse(sparse_dim, nnz, sub_shape, dtype, device, coalesced)[0]
3818            y = self._gen_sparse(sparse_dim, nnz, sub_shape, dtype, device, coalesced)[0]
3819            check(self, x, y)
3820            # TODO: uncomment once supported
3821            # check_autograd(x, y)
3822
3823            # check broadcasting in dense dims
3824            for d in range(sparse_dim, len(sub_shape)):
3825                new_shape = sub_shape[:d] + (1,) + sub_shape[d + 1:]
3826                y = self._gen_sparse(sparse_dim, nnz, new_shape, dtype, device, coalesced)[0]
3827                check(self, x, y)
3828                # TODO: uncomment once supported
3829                # check_autograd(x, y)
3830
3831    @coalescedonoff
3832    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
3833    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
3834    def test_sparse_dense_mul(self, device, dtype, coalesced):
3835        self._test_mul_skips(device, dtype, coalesced)
3836
3837        shape = (2, 3, 4, 10)
3838        nnz = 10
3839
3840        def check(self, s, d):
3841            res = d * s
3842
3843            # check commutativity
3844            self.assertEqual(res, s * d)
3845
3846            # check correctness
3847            self.assertEqual(res.to_dense(), s.to_dense() * d)
3848
3849            # check in-placeness for dense
3850            if d.dim() >= s.dim():
3851                dc = d.clone()
3852                self.assertEqual(d.mul_(s), dc.mul_(s.to_dense()))
3853
3854            # check in-placeness for sparse
3855            if s.dim() >= d.dim():
3856                # for sparse
3857                sc = s.clone()
3858                self.assertEqual(s.mul_(d).to_dense(), sc.to_dense().mul_(d))
3859
3860        for dim in range(len(shape) + 1):
3861            sub_shape = shape[dim:]
3862            sparse_dim = len(sub_shape) // 2
3863
3864            def check_empty(sparse_shape, nnz, dense_shape, coalesce):
3865                from itertools import product
3866                for nnz_val, shape_suffix in product((nnz, 0), ((), (0,))):
3867                    empty_sparse_shape = sparse_shape + shape_suffix
3868                    empty_dense_shape = dense_shape + shape_suffix
3869                    s = self._gen_sparse(sparse_dim, nnz_val, empty_sparse_shape, dtype, device, coalesce)[0]
3870                    d = make_tensor(empty_dense_shape, dtype=dtype, device=device)
3871                    check(self, s, d)
3872
3873            # check scalar multiplication
3874            s = self._gen_sparse(sparse_dim, nnz, sub_shape, dtype, device, coalesced)[0]
3875            for scalar in (True, 1, 1.0):
3876                res_sparse_right = s * scalar
3877                res_sparse_left = scalar * s
3878                res_dense = s.to_dense() * scalar
3879                # check correctness and dtype
3880                self.assertEqual(s.to(res_sparse_right.dtype), res_sparse_right)
3881                self.assertEqual(res_sparse_right, res_sparse_left)
3882                self.assertEqual(res_sparse_right.dtype, res_dense.dtype)
3883                self.assertEqual(res_sparse_left.dtype, res_dense.dtype)
3884                # check scalar as 0-dim sparse tensor
3885                tscalar = torch.tensor(scalar, device=device)
3886                sscalar = tscalar.to_sparse()
3887                res_sparse_right = s * sscalar
3888                res_sparse_left = sscalar * s
3889                self.assertEqual(res_sparse_right, res_sparse_left)
3890                self.assertEqual(s.to(res_sparse_right.dtype), res_sparse_right)
3891
3892            # check non-coalesced 0-dim scalar
3893            # we skip torch.bool because for such tensors
3894            # coalesce().to_dense() != to_dense()
3895            if dtype == torch.bool:
3896                return
3897
3898            for scalar_dtype in (int, float):
3899                scalar = scalar_dtype(1)
3900                idx = torch.tensor([], device=device).reshape(0, 2)
3901                val = torch.tensor([scalar, scalar], device=device)
3902                sscalar = torch.sparse_coo_tensor(idx, val, ())
3903                res_dense = s.to_dense() * sscalar.to_dense()
3904                self.assertEqual((s * sscalar).to_dense(), res_dense)
3905                self.assertEqual((sscalar * s).to_dense(), res_dense)
3906
3907            # Case 1: sparse broadcasts over dense
3908            s = self._gen_sparse(sparse_dim, nnz, sub_shape, dtype, device, coalesced)[0]
3909            d = make_tensor(shape, dtype=dtype, device=device)
3910            check(self, s, d)
3911            check_empty(sub_shape, nnz, shape, coalesced)
3912
3913            # Case 2: dense broadcasts over sparse
3914            s = self._gen_sparse(3, nnz, shape, dtype, device, coalesced)[0]
3915            d = make_tensor(sub_shape, dtype=dtype, device=device)
3916            check(self, s, d)
3917            check_empty(shape, nnz, sub_shape, coalesced)
3918
3919    @unittest.skipIf(not TEST_NUMPY, "NumPy is not available")
3920    @onlyCPU
3921    @dtypes(*all_types_and_complex_and(torch.bool))
3922    def test_sparse_spdiags(self, device, dtype):
3923
3924        make_diags = functools.partial(make_tensor, dtype=dtype, device=device)
3925        make_offsets = functools.partial(torch.tensor, dtype=torch.long, device=device)
3926
3927        if TEST_SCIPY:
3928            def reference(diags, offsets, shape):
3929                return scipy.sparse.spdiags(diags, offsets, *shape).toarray()
3930
3931        else:
3932            def reference(diags, offsets, shape):
3933                result = torch.zeros(shape, dtype=dtype, device=device)
3934                for i, off in enumerate(offsets):
3935                    res_view = result.diagonal(off)
3936                    data = diags[i]
3937                    if off > 0:
3938                        data = data[off:]
3939
3940                    m = min(res_view.shape[0], data.shape[0])
3941                    res_view[:m] = data[:m]
3942                return result
3943
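        # The placement convention exercised here: a single diagonal row
        # [1, 2, 3] put at offset 1 of a 3x3 output fills element (i, i + 1)
        # with diags[0, i + 1], so the superdiagonal becomes [2, 3] and the
        # leading entry of the row is ignored (matching both the scipy
        # reference and the fallback above).
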
3944        def check_valid(diags, offsets, shape, layout=None):
3945            ref_out = reference(diags, offsets, shape)
3946            out = torch.sparse.spdiags(diags, offsets, shape, layout=layout)
3947            if layout is None:
3948                ex_layout = torch.sparse_coo
3949            else:
3950                ex_layout = layout
3951            out_dense = out.to_dense()
3952            self.assertTrue(out.layout == ex_layout, f"Output layout {out.layout} expected {ex_layout}")
3953            self.assertEqual(out_dense, ref_out, f"Result:\n{out_dense} does not match reference:\n{ref_out}")
3954
3955        def check_invalid(args, error):
3956            with self.assertRaisesRegex(RuntimeError, error):
3957                torch.sparse.spdiags(*args)
3958
3959        def valid_cases():
3960            # some normal cases
3961            yield (make_diags((1, 5)), make_offsets([0]), (5, 5))
3962            yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4))
3963            # noncontiguous diags
3964            yield (make_diags((5, 4), noncontiguous=True), make_offsets([-1, 1, 0, 2, -2]), (5, 5))
3965            # noncontiguous offsets
3966            yield (make_diags((3, 4)), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5))
3967            # noncontiguous diags + offsets
3968            yield (make_diags((3, 4), noncontiguous=True), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5))
3969            # correct dimensionality (2d diags, 2d output) and matching shapes, but the number of diagonals is zero
3970            yield (make_diags((0, 3)), make_offsets([]), (3, 3))
3971            # forward rotation of upper diagonals
3972            yield (make_diags((3, 8)), make_offsets([1, 2, 3]), (4, 4))
3973            # rotation exhausts the input space to read from
3974            yield (make_diags((2, 3)), make_offsets([2, 1]), (3, 3))
3975            # Simple cases repeated with special output format
3976            yield (make_diags((1, 5)), make_offsets([0]), (5, 5), torch.sparse_csc)
3977            yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4), torch.sparse_csr)
3978            # vector diags
3979            yield (make_diags((3, )), make_offsets([1]), (4, 4))
3980            # Scalar offset
3981            yield (make_diags((1, 3)), make_offsets(2), (4, 4))
3982            # offsets out of range
3983            yield (make_diags((1, 3)), make_offsets([3]), (3, 3))
3984            yield (make_diags((1, 3)), make_offsets([-3]), (3, 3))
3985
3986        for case in valid_cases():
3987            check_valid(*case)
3988
3989        def invalid_cases():
3990            yield (make_diags((1, 3)), make_offsets([0]), (3, 2, 3)), "Output shape must be 2d"
3991            yield (make_diags((2, 3)), make_offsets([[1, 2], [0, 3]]), (3, 3)), "Offsets must be scalar or vector"
3992            yield (make_diags((3, 2, 3)), make_offsets([0, 1, 2]), (4, 4)), "Diagonals must be vector or matrix"
3993            yield (make_diags((3, 3)), make_offsets([-1, 0]), (3, 3)), \
3994                r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
3995            yield (make_diags((5,)), make_offsets([0, 1, 2, 3, 4]), (3, 3)), \
3996                r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
3997            yield (make_diags((2, 2)), make_offsets([-1, 0]), (2, 3), torch.strided), \
3998                r"Only output layouts \(\w+, \w+, \w+\) are supported, got \w+"
3999            yield (make_diags((2, 5)), make_offsets([0, 0]), (5, 5)), "Offset tensor contains duplicate values"
4000            yield (make_diags((1, 5)), make_offsets([0]).to(torch.int32), (5, 5)), r"Offset Tensor must have dtype Long but got \w+"
4001
4002
4003        for case, error_regex in invalid_cases():
4004            check_invalid(case, error_regex)
4005
4006    def test_small_nnz_coalesced(self):
4007        # a coo tensor created with nnz == 0 is always coalesced
4008        self.assertTrue(torch.sparse_coo_tensor([[], []], [], (2, 2)).is_coalesced())
4009        # same for a coo tensor with only 1 nnz
4010        self.assertTrue(torch.sparse_coo_tensor([[0], [0]], [1], (2, 2)).is_coalesced())
4011        # with two or more nnz, is_coalesced is False as it can't be verified without an expensive check
4012        self.assertFalse(torch.sparse_coo_tensor([[0, 0], [0, 0]], [1, 2], (2, 2)).is_coalesced())
4013        # even if there are no duplicates
4014        self.assertFalse(torch.sparse_coo_tensor([[0, 1], [0, 1]], [1, 2], (2, 2)).is_coalesced())
4015
4016    @coalescedonoff
4017    @dtypes(*all_types_and_complex_and(torch.bool))
4018    def test_sum(self, device, dtype, coalesced):
4019        def run_test(shape, nnz):
4020            a = self._gen_sparse(2, nnz, shape, dtype, device, coalesced)[0]
4021            self.assertEqual(a.sum(), a._values().sum())
4022            if dtype.is_floating_point or dtype.is_complex:
4023                a.requires_grad_(True)
4024                a_inter = a.sum()
4025                a_inter.abs().backward()
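                # Since (for the dtypes exercised here) the gradient of
                # |sum(a)| w.r.t. every element is sgn(sum(a)), the dense
                # gradient should be a full tensor of sgn(a_inter):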
4026                with torch.no_grad():
4027                    self.assertEqual(a.grad, torch.ones(shape, dtype=dtype, device=device) * torch.sgn(a_inter))
4028        for shape in [(10, 5), (10, 10)]:
4029            run_test(shape, 0)
4030            run_test(shape, max(shape))
4031            run_test(shape, shape[0] * shape[1])
4032
4033
4034class TestSparseOneOff(TestCase):
4035    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
4036    def test_cuda_from_cpu(self):
4037        with self.assertRaisesRegex(
4038                RuntimeError,
4039                "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"):
4040            torch.sparse_coo_tensor(torch.zeros(1, 4).long().cuda(),
4041                                    torch.randn(4, 4, 4),
4042                                    [3, 4, 4])
4043
4044        with self.assertRaisesRegex(
4045                RuntimeError,
4046                "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"):
4047            torch.sparse_coo_tensor(torch.zeros(1, 4).long().cuda(),
4048                                    torch.randn(4, 4, 4, 0),
4049                                    [3, 4, 4, 0])
4050
4051        with self.assertRaisesRegex(
4052                RuntimeError,
4053                "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"):
4054            torch.sparse_coo_tensor(torch.empty(1, 0).long().cuda(),
4055                                    torch.randn(0, 4, 4, 0),
4056                                    [0, 4, 4, 0])
4057
4058    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
4059    def test_cuda_sparse_cpu_dense_add(self):
4060        x = torch.zeros(3, 4, 4)
4061        sparse_y = torch.sparse_coo_tensor(torch.zeros(1, 4).long().cuda(),
4062                                           torch.randn(4, 4, 4).cuda(),
4063                                           [3, 4, 4])
4064        with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"):
4065            x + sparse_y
4066
4067        x = torch.zeros(3, 4, 4, 0)
4068        sparse_y = torch.sparse_coo_tensor(torch.zeros(1, 4).long().cuda(),
4069                                           torch.randn(4, 4, 4, 0).cuda(),
4070                                           [3, 4, 4, 0])
4071        with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"):
4072            x + sparse_y
4073
4074        x = torch.zeros(0, 4, 4, 0)
4075        sparse_y = torch.sparse_coo_tensor(torch.empty(1, 0).long().cuda(),
4076                                           torch.randn(0, 4, 4, 0).cuda(),
4077                                           [0, 4, 4, 0])
4078        with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"):
4079            x + sparse_y
4080
4081
4082def _sparse_to_dense(tensor):
4083    if tensor.dtype != torch.bool:
4084        return tensor.to_dense(masked_grad=True)
4085
4086    # to_dense uses coalesce which isn't implemented for bool
4087    return tensor.to(torch.int8).to_dense().to(torch.bool)
4088
4089
4090_sparse_unary_ops = ops(sparse_unary_ufuncs, dtypes=OpDTypes.supported,
4091                        allowed_dtypes=all_types_and_complex())
4092class TestSparseUnaryUfuncs(TestCase):
4093    exact_dtype = True
4094
4095
4096    @_sparse_unary_ops
4097    def test_sparse_consistency(self, device, dtype, op):
4098        sample = first_sample(self, op.sample_inputs(device, dtype))
4099        assert isinstance(sample.input, torch.Tensor)
4100
4101        expected = op(sample.input, *sample.args, **sample.kwargs)
4102        assert torch.is_tensor(expected)
4103        output = op(sample.input.to_sparse(), *sample.args, **sample.kwargs)
4104        assert torch.is_tensor(output)
4105        self.assertEqual(_sparse_to_dense(output), expected)
4106
4107    @_sparse_unary_ops
4108    def test_out(self, device, dtype, op):
4109        if not op.supports_out:
4110            self.skipTest("Skipped! Out not supported")
4111
4112        sample = first_sample(self, op.sample_inputs(device, dtype))
4113        sample.input = sample.input.to_sparse()
4114        expect = op(sample.input, *sample.args, **sample.kwargs)
4115
4116        out = torch.sparse_coo_tensor(sample.input.shape, device=device,
4117                                      dtype=expect.dtype)
4118        op(sample.input, *sample.args, **sample.kwargs, out=out)
4119        self.assertEqual(out, expect)
4120
4121    @_sparse_unary_ops
4122    def test_inplace(self, device, dtype, op):
4123        if op.inplace_variant is None:
4124            self.skipTest("Skipped! Inplace not supported")
4125
4126        sample = first_sample(self, op.sample_inputs(device, dtype))
4127        sample.input = sample.input.to_sparse().coalesce()
4128        expect = op(sample.input, *sample.args, **sample.kwargs)
4129
4130        if not torch.can_cast(expect.dtype, dtype):
4131            with self.assertRaisesRegex(RuntimeError, "result type .* can't be cast to"):
4132                op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
4133            return
4134
4135        actual = op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
4136        self.assertIs(actual, sample.input)
4137        self.assertEqual(actual, expect)
4138
4139    @_sparse_unary_ops
4140    def test_sparse_zero_dims(self, device, dtype, op):
4141        # test 0x0 sparse_coo_tensor
4142        indices = torch.empty(2, 0, dtype=torch.int64)
4143        values = torch.empty(0, dtype=dtype)
4144        sparse_0x0 = torch.sparse_coo_tensor(indices, values, (0, 0))
4145        expected = torch.sparse_coo_tensor(indices, op(values), (0, 0))
4146        actual = op(sparse_0x0)
4147        self.assertEqual(expected, actual)
4148
4149    @_sparse_unary_ops
4150    def test_sparse_zeros(self, device, dtype, op):
4151        samples = op.sample_inputs(device, dtype)
4152
4153        zero_input = torch.zeros((), device=device, dtype=dtype)
4154        sparse_input = torch.sparse_coo_tensor((), dtype=dtype, device=device)
4155
4156        expect = op(zero_input)
4157        actual = op(sparse_input)
4158        self.assertEqual(expect, _sparse_to_dense(actual))
4159
4160    @ops(sparse_unary_ufuncs, dtypes=OpDTypes.supported,
4161         allowed_dtypes=[torch.double, torch.cdouble])
4162    def test_sparse_fn_grad(self, device, dtype, op):
4163        if not op.supports_autograd:
4164            self.skipTest("Skipped! Op doesn't support autograd")
4165
4166        for sample in op.sample_inputs(device, dtype):
4167            sparse_input = sample.input.to_sparse().detach().requires_grad_(True)
4168
4169            def fn(x):
4170                return _sparse_to_dense(
4171                    op(x, *sample.args, **sample.kwargs))
4172
4173            self.assertTrue(gradcheck(
4174                fn,
4175                (sparse_input,),
4176                check_batched_grad=False,
4177                check_grad_dtypes=True,
4178                nondet_tol=op.gradcheck_nondet_tol,
4179                fast_mode=op.gradcheck_fast_mode,
4180                masked=True))
4181
4182
4183class TestSparseMaskedReductions(TestCase):
4184    exact_dtype = True
4185
4186    fp16_low_precision_list = {
4187        'masked.prod',
4188    }
4189
4190    @ops(sparse_masked_reduction_ops)
4191    def test_future_empty_dim(self, device, dtype, op):
4192        """Currently, `dim=()` in reduction operations means "reduce over
4193        all dimensions" while in future, it will read "no reduce". See
4194        https://github.com/pytorch/pytorch/issues/29137
4195
4196        For sparse masked reductions, we'll implement the current behavior.
4197
4198        For testing, we'll use samples with `dim=0` and map it to
4199        `dim=()` until
4200        torch.testing._internal.common_methods_invocations._generate_reduction_kwargs
4201        is made to generate samples with `dim=()` for non-scalar
4202        inputs. With this and after gh-29137 is resolved, this test
4203        can be deleted. See also `torch.masked._canonical_dim`
4204        implementation about changing the `dim=()` behavior.
4205        """
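        # For instance (illustrative only, following the note above), under the
        # current semantics torch.masked.sum(t, dim=()) performs a full
        # reduction, i.e. it matches torch.masked.sum(t) rather than leaving t
        # unreduced.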
4206
4207        samples = op.sample_inputs_func(op, device, dtype, requires_grad=False)
4208        op_name = op.name.replace('masked.', '')
4209        for sample_input in samples:
4210            if sample_input.kwargs.get('dim') != 0:
4211                continue
4212            sample_input_kwargs = dict(sample_input.kwargs)
4213            sample_input_kwargs['dim'] = ()    # reduce over all dimensions
4214
4215            t = sample_input.input
4216            mask = sample_input_kwargs.get('mask')
4217            if mask is None and op_name in {'prod', 'amax', 'amin'}:
4218                # FIXME: for now reductions with non-zero reduction identity and
4219                # unspecified mask are not supported for sparse COO
4220                # tensors, see torch.masked.prod implementation
4221                # for details.
4222                continue
4223            sparse_op_kwargs = dict(sample_input_kwargs)
4224            actual = op(t.to_sparse(), *sample_input.args, **sample_input_kwargs)
4225            self.assertEqual(actual.layout, torch.sparse_coo)
4226
4227            expected = op(t, *sample_input.args, **sample_input_kwargs).to_sparse()
4228            atol = None
4229            rtol = None
4230            if op.name in self.fp16_low_precision_list and dtype == torch.half:
4231                atol = 1e-5
4232                rtol = 2e-3
4233            self.assertEqual(actual, expected, atol=atol, rtol=rtol)
4234
4235
4236class TestSparseMeta(TestCase):
4237    exact_dtype = True
4238
4239    def _test_meta_sparse_coo(self, dtype):
4240        r = torch.empty(4, 4, layout=torch.sparse_coo, device='meta', dtype=dtype)
4241        self.assertTrue(r.is_meta)
4242        self.assertEqual(r.device.type, "meta")
4243        r2 = torch.empty_like(r)
4244        self.assertTrue(r2.is_meta)
4245        self.assertEqual(r, r2)
4246        r3 = torch.sparse_coo_tensor(size=(4, 4), device='meta', dtype=dtype)
4247        self.assertTrue(r3.is_meta)
4248        self.assertEqual(r, r3)
4249        r.sparse_resize_((4, 4), 1, 1)
4250        r.sparse_resize_and_clear_((4, 4, 4), 2, 1)
4251        self.assertEqual(r.sparse_dim(), 2)
4252        self.assertEqual(r.dense_dim(), 1)
4253        self.assertEqual(r._dimV(), 1)
4254        self.assertEqual(r._nnz(), 0)
4255        # sparse tensors with zero nnz should always be coalesced at creation
4256        self.assertEqual(r.is_coalesced(), True)
4257        # but we can force them into the uncoalesced state
4258        r._coalesced_(False)
4259        self.assertEqual(r.is_coalesced(), False)
4260        # restore the coalesced state for indices/values access
4261        r._coalesced_(True)
4262        # TODO: this sort of aliasing will need to be handled by
4263        # functionalization
4264        self.assertEqual(r._indices(), torch.empty(2, 0, device='meta', dtype=torch.int64))
4265        self.assertEqual(r._values(), torch.empty(0, 4, device='meta', dtype=dtype))
4266        self.assertEqual(r.indices(), torch.empty(2, 0, device='meta', dtype=torch.int64))
4267        self.assertEqual(r.values(), torch.empty(0, 4, device='meta', dtype=dtype))
4268
4269    def _test_meta_sparse_compressed(self, dtype, layout, batchsize, densesize):
4270        index_dtype = torch.int64
4271        blocksize = (2, 3) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
4272        sparsesize = (4, 6)
4273        nnz = 0
4274
4275        shape = (*batchsize, *sparsesize, *densesize)
4276        compressed_dim = 0 if layout in {torch.sparse_csr, torch.sparse_bsr} else 1
4277        nof_compressed_indices = (sparsesize[compressed_dim] // blocksize[compressed_dim] + 1 if blocksize
4278                                  else sparsesize[compressed_dim] + 1)
4279        compressed_indices = torch.empty((*batchsize, nof_compressed_indices), device='meta', dtype=index_dtype)
4280        plain_indices = torch.empty((*batchsize, nnz), device='meta', dtype=index_dtype)
4281
4282        values = torch.empty((*batchsize, nnz, *blocksize, *densesize), device='meta', dtype=dtype)
4283        r = torch.sparse_compressed_tensor(
4284            compressed_indices,
4285            plain_indices,
4286            values,
4287            shape,
4288            layout=layout
4289        )
4290        self.assertTrue(r.is_meta)
4291        self.assertEqual(r.device.type, "meta")
4292
4293        self.assertEqual(r.sparse_dim(), 2)
4294        self.assertEqual(r.dense_dim(), len(densesize))
4295        self.assertEqual(r._nnz(), nnz)
4296        batch_dims = r.ndim - r.sparse_dim() - r.dense_dim()
4297        r_blocksize = r.values().shape[batch_dims + 1: batch_dims + 1 + len(blocksize)]
4298        self.assertEqual(r_blocksize, blocksize)
4299
4300        r_compressed_indices = r.crow_indices() if layout in {torch.sparse_csr, torch.sparse_bsr} else r.ccol_indices()
4301        r_plain_indices = r.col_indices() if layout in {torch.sparse_csr, torch.sparse_bsr} else r.row_indices()
4302
4303        self.assertEqual(r_compressed_indices,
4304                         torch.empty((*batchsize, nof_compressed_indices), device='meta', dtype=index_dtype))
4305        self.assertEqual(r_plain_indices, torch.empty((*batchsize, nnz), device='meta', dtype=index_dtype))
4306        self.assertEqual(r.values(), torch.empty((*batchsize, nnz, *blocksize, *densesize), device='meta', dtype=dtype))
4307
4308        r2 = torch.empty_like(r)
4309        self.assertTrue(r2.is_meta)
4310        self.assertEqual(r2, r)
4311
4312        if layout in {torch.sparse_csr, torch.sparse_csc}:
4313            r3 = torch.empty((*batchsize, *sparsesize), dtype=dtype, layout=layout, device="meta")
4314            self.assertTrue(r3.is_meta)
4315            if not densesize:
4316                # dense dimensions cannot be specified for torch.empty
4317                self.assertEqual(r3, r)
4318
4319    @all_sparse_layouts('layout', include_strided=False)
4320    @parametrize("dtype", [torch.float64])
4321    def test_meta(self, dtype, layout):
4322        if layout is torch.sparse_coo:
4323            self._test_meta_sparse_coo(dtype)
4324        else:
4325            for batchsize, densesize in itertools.product([(), (2,)], [(), (3,)]):
4326                self._test_meta_sparse_compressed(dtype, layout, batchsize, densesize)
4327
4328    def _test_print_meta_data(self, dtype, layout, batchsize, sparsesize, densesize):
4329        index_dtype = torch.int64
4330        nnz = 0
4331        blocksize = (2, 3) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
4332        shape = (*batchsize, *sparsesize, *densesize)
4333        values = torch.empty((*batchsize, nnz, *blocksize, *densesize), device='meta', dtype=dtype)
4334        if layout is torch.sparse_coo:
4335            indices = torch.empty((len(sparsesize), nnz), device='meta', dtype=index_dtype)
4336            x = torch.sparse_coo_tensor(indices, values, shape)
4337        else:
4338            compressed_dim = 0 if layout in {torch.sparse_csr, torch.sparse_bsr} else 1
4339            nof_compressed_indices = (sparsesize[compressed_dim] // blocksize[compressed_dim] + 1 if blocksize
4340                                      else sparsesize[compressed_dim] + 1)
4341            compressed_indices = torch.empty((*batchsize, nof_compressed_indices), device='meta', dtype=index_dtype)
4342            plain_indices = torch.empty((*batchsize, nnz), device='meta', dtype=index_dtype)
4343            x = torch.sparse_compressed_tensor(
4344                compressed_indices,
4345                plain_indices,
4346                values,
4347                shape,
4348                layout=layout
4349            )
4350
4351        printed = []
4352        printed.append(f"########## {dtype}/{index_dtype}/size={batchsize}+{sparsesize}+{blocksize}+{densesize} ##########")
4353        printed.append("# sparse meta tensor")
4354        printed.append(str(x))
4355
4356        return printed
4357
4358    @all_sparse_layouts('layout', include_strided=False)
4359    @parametrize("dtype", [torch.float64])
4360    def test_print_meta(self, dtype, layout):
4361        printed = []
4362        for batchsize, sparsesize, densesize in itertools.product(
4363                [(), (2,)], [(4, 6), (3, 5, 7)], [(), (3,)]
4364        ):
4365            if layout is torch.sparse_coo and batchsize:
4366                # COO tensors don't have batch dimensions
4367                continue
4368            if layout is not torch.sparse_coo and len(sparsesize) != 2:
4369                # CSR/CSC/BSR/BSC tensors must have 2 sparse dimensions
4370                continue
4371            printed += self._test_print_meta_data(dtype, layout, batchsize, sparsesize, densesize)
4372
4373        orig_maxDiff = self.maxDiff
4374        self.maxDiff = None
4375        try:
4376            self.assertExpected('\n'.join(printed))
4377            self.maxDiff = orig_maxDiff
4378        except Exception:
4379            self.maxDiff = orig_maxDiff
4380            raise
4381
4382    def assertEqualMeta(self, x, y, expected_nnz):
4383        self.assertEqual(x.layout, y.layout)
4384        self.assertEqual(x.shape, y.shape)
4385        self.assertEqual(x.dtype, y.dtype)
4386        self.assertEqual(x.sparse_dim(), y.sparse_dim())
4387        self.assertEqual(x.dense_dim(), y.dense_dim())
4388
4389        def assertEqualAttrs(x, y, expected_shape):
4390            self.assertEqual(x.shape, expected_shape)
4391            self.assertEqual(x.dtype, y.dtype)
4392            self.assertEqual(x.layout, y.layout)
4393            if not x.is_meta:
4394                self.assertEqual(x.device, y.device)
4395
4396        if x.layout is torch.sparse_coo:
4397            assertEqualAttrs(x._indices(), y._indices(), (*y._indices().shape[:-1], expected_nnz))
4398            assertEqualAttrs(x._values(), y._values(), (expected_nnz, *y._values().shape[1:]))
4399        elif x.layout in {torch.sparse_csr, torch.sparse_bsr}:
4400            assertEqualAttrs(x.crow_indices(), y.crow_indices(), y.crow_indices().shape)
4401            assertEqualAttrs(x.col_indices(), y.col_indices(), (*y.col_indices().shape[:-1], expected_nnz))
4402            batch_dim = x.col_indices().ndim - 1
4403            values_shape = (*y.values().shape[:batch_dim], expected_nnz, *y.values().shape[batch_dim + 1:])
4404            self.assertEqual(x.values().layout, y.values().layout)
4405            self.assertEqual(x.values().dtype, y.values().dtype)
4406            self.assertEqual(x.values().shape, values_shape)
4407        elif x.layout in {torch.sparse_csc, torch.sparse_bsc}:
4408            assertEqualAttrs(x.ccol_indices(), y.ccol_indices(), y.ccol_indices().shape)
4409            assertEqualAttrs(x.row_indices(), y.row_indices(), (*y.row_indices().shape[:-1], expected_nnz))
4410            batch_dim = x.row_indices().ndim - 1
4411            values_shape = (*y.values().shape[:batch_dim], expected_nnz, *y.values().shape[batch_dim + 1:])
4412            self.assertEqual(x.values().layout, y.values().layout)
4413            self.assertEqual(x.values().dtype, y.values().dtype)
4414            self.assertEqual(x.values().shape, values_shape)
4415
4416    @all_sparse_layouts('layout', include_strided=False)
4417    @parametrize("dtype", [torch.float64])
4418    def test_to_meta(self, dtype, layout):
4419        index_dtype = torch.int64
4420        device = 'cpu'
4421        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
4422            m = t.to(device="meta")
4423            self.assertEqual(m.device.type, "meta")
4424            self.assertEqualMeta(m, t, 0)
4425
4426    @all_sparse_layouts('layout', include_strided=False)
4427    @parametrize("dtype", [torch.float64])
4428    def test_zeros_like_meta(self, dtype, layout):
4429        index_dtype = torch.int64
4430        device = 'cpu'
4431        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
4432            m = torch.zeros_like(t, device="meta")
4433            self.assertEqual(m.device.type, "meta")
4434            self.assertEqualMeta(m, t, 0)
4435
4436    @all_sparse_layouts('layout', include_strided=False)
4437    @parametrize("dtype", [torch.float64])
4438    def test_fake(self, dtype, layout):
4439        from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
4440        fake_mode = FakeTensorMode()
4441        index_dtype = torch.int64
4442        device = 'cpu'
4443        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
4444            f = FakeTensor.from_tensor(t, fake_mode)
4445            self.assertIsInstance(f, FakeTensor)
4446            self.assertEqualMeta(f, t, 0)
4447
4448            d = f.detach()
4449            self.assertIsInstance(d, FakeTensor)
4450            self.assertEqualMeta(d, t, 0)
4451
4452    @all_sparse_layouts('layout', include_strided=False)
4453    @parametrize("dtype", [torch.float64])
4454    def test_zeros_like_fake(self, dtype, layout):
4455        from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
4456        from torch.utils._mode_utils import no_dispatch
4457        fake_mode = FakeTensorMode()
4458        index_dtype = torch.int64
4459        device = 'cpu'
4460        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
4461            f = FakeTensor.from_tensor(t, fake_mode)
4462            expected = torch.zeros_like(t)
4463            with no_dispatch():
4464                result = torch.zeros_like(f, device=f.fake_device)
4465            self.assertEqual(result, expected)
4466            self.assertEqualMeta(result, expected, 0)
4467
4468    @all_sparse_layouts('layout', include_strided=False)
4469    @parametrize("dtype", [torch.float64])
4470    def test_sum_meta(self, dtype, layout):
4471        device = 'cpu'
4472        index_dtype = torch.int64
4473        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
4474            m = t.to(device='meta')
4475            r = torch.sum(m)
4476            expected = torch.sum(t).to(device="meta")
4477            self.assertTrue(r.is_meta)
4478            self.assertEqualMeta(r, expected, 0)
4479
4480    @all_sparse_layouts('layout', include_strided=False)
4481    @parametrize("dtype", [torch.float64])
4482    def test_add_meta(self, dtype, layout):
4483        device = 'cpu'
4484        index_dtype = torch.int64
4485        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
4486            expected = torch.add(t, t).to(device='meta')
4487            m = t.to(device='meta')
4488            r = torch.add(m, m)
4489            self.assertEqualMeta(r, expected, 0)
4490
4491
4492class _SparseDataset(torch.utils.data.Dataset):
4493    # A utility class used in the TestSparseAny.test_dataloader method.
4494
4495    def __init__(self, sparse_tensors):
4496        self.sparse_tensors = sparse_tensors
4497
4498    def __len__(self):
4499        return len(self.sparse_tensors)
4500
4501    def __getitem__(self, index):
4502        return self.sparse_tensors[index]
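    # A minimal usage sketch (hypothetical; the real use lives in
    # TestSparseAny.test_dataloader): iterating with batch_size=None avoids
    # having to collate sparse samples into batches.
    #
    #   ds = _SparseDataset([torch.eye(3).to_sparse() for _ in range(4)])
    #   for sample in torch.utils.data.DataLoader(ds, batch_size=None):
    #       assert sample.is_sparse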
4503
4504
4505class TestSparseAny(TestCase):
4506
4507    @onlyCPU
4508    @all_sparse_layouts('layout', include_strided=False)
4509    @torch.sparse.check_sparse_tensor_invariants(enable=False)
4510    def test_check_sparse_tensor_invariants(self, layout):
4511
4512        if layout is torch.sparse_coo:
4513
4514            def create_invalid_tensor(check_invariants=None):
4515                shape = (2, 2)
4516                invalid_indices = torch.tensor([[0], [3]])  # column index is out of range
4517                values = torch.tensor([1])
4518                if check_invariants is None:
4519                    return torch.sparse_coo_tensor(invalid_indices, values, shape)
4520                else:
4521                    return torch.sparse_coo_tensor(invalid_indices, values, shape, check_invariants=check_invariants)
4522
4523            expected_exception_message = 'size is inconsistent with indices: for dim 1, size is 2 but found index 3'
4524
4525        elif layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
4526
4527            def create_invalid_tensor(check_invariants=None):
4528                shape = (2, 2)
4529                compressed_indices = torch.tensor([0, 0, 1])
4530                invalid_plain_indices = torch.tensor([3])  # index is out of range
4531                if layout in {torch.sparse_bsr, torch.sparse_bsc}:
4532                    values = torch.tensor([[[1]]])
4533                else:
4534                    values = torch.tensor([1])
4535                if check_invariants is None:
4536                    return torch.sparse_compressed_tensor(compressed_indices, invalid_plain_indices, values, shape, layout=layout)
4537                else:
4538                    return torch.sparse_compressed_tensor(compressed_indices, invalid_plain_indices, values, shape, layout=layout,
4539                                                          check_invariants=check_invariants)
4540
4541            if layout in {torch.sparse_csr, torch.sparse_bsr}:
4542                expected_exception_message = r'`0 <= col_indices < ncols` is not satisfied.'
4543            else:
4544                expected_exception_message = r'`0 <= row_indices < nrows` is not satisfied.'
4545
4546        else:
4547            raise NotImplementedError(layout)
4548
4549        # First, consider the case where invariant checks are disabled
4550        # "globally" (read: within the context of this test method
4551        # caller) as defined by check_sparse_tensor_invariants(False)
4552        # decorator:
4553        self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4554
4555        # Enable the invariant checks in a local context:
4556        with torch.sparse.check_sparse_tensor_invariants():
4557            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4558
4559        # Leaving the local context must restore the "global" state of
4560        # the invariant check feature:
4561        self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4562
4563        # Since invariant checks are disabled by default, we can
4564        # create an invalid sparse tensor without raising an
4565        # exception:
4566        r = create_invalid_tensor()
4567        self.assertEqual(r.layout, layout)
4568
4569        # Or, when disabling the invariants check explicitly:
4570        r = create_invalid_tensor(check_invariants=False)
4571        self.assertEqual(r.layout, layout)
4572
4573        # Enabling invariant check via constructor's optional argument
4574        # will raise an exception when sparse tensor invariants are
4575        # violated:
4576        with self.assertRaisesRegex(RuntimeError, expected_exception_message):
4577            create_invalid_tensor(check_invariants=True)
4578
4579        # Check that the global invariant check flag has been restored
4580        # after raising the exception above:
4581        self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4582
4583        # Next, consider the case where invariant checks are enabled
4584        # within a local context:
4585        with torch.sparse.check_sparse_tensor_invariants():
4586            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4587
4588            # Since invariant checks are now enabled by default, an
4589            # attempt to create an invalid sparse tensor will lead to
4590            # an exception:
4591            with self.assertRaisesRegex(RuntimeError, expected_exception_message):
4592                create_invalid_tensor()
4593
4594            # Similarly, when enabling the invariant checks
4595            # explicitly, invalid sparse tensor construction will lead
4596            # to an exception:
4597            with self.assertRaisesRegex(RuntimeError, expected_exception_message):
4598                create_invalid_tensor(check_invariants=True)
4599
4600            # However, the invariant checks can be disabled via the
4601            # constructor's optional argument so that the invalid
4602            # tensor is successfully constructed:
4603            r = create_invalid_tensor(check_invariants=False)
4604            self.assertEqual(r.layout, layout)
4605
4606            # Check that the invariant check flag has been restored
4607            # when leaving the constructor:
4608            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4609
4610        # Double-check restoring the global state when leaving the
4611        # local context:
4612        self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4613
4614        # Test nesting of pre-defined context managers
4615        check_ctx = torch.sparse.check_sparse_tensor_invariants(True)
4616        no_check_ctx = torch.sparse.check_sparse_tensor_invariants(False)
4617        with check_ctx:
4618            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4619            with no_check_ctx:
4620                self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4621            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4622        self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4623
4624        # Test an attempt to re-use an active context manager instance
4625        check_ctx2 = torch.sparse.check_sparse_tensor_invariants(True)
4626        with check_ctx:
4627            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4628            with no_check_ctx:
4629                self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4630                with self.assertRaisesRegex(RuntimeError, "This context manager instance is already activated."
4631                                            " Use a different context manager instance for context nesting"):
4632                    with check_ctx:
4633                        self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4634                self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4635                with check_ctx2:
4636                    self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4637                self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4638            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
4639        self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
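
        # Illustrative sketch, not part of the original checks above: a
        # *valid* sparse COO tensor is constructed successfully regardless
        # of whether invariant checking is requested (the names below are
        # local to this sketch).
        valid = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 2.0], (2, 2),
                                        check_invariants=True)
        self.assertEqual(valid.layout, torch.sparse_coo)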
4640
4641    def test_generate_simple_inputs(self):
4642        layouts = [torch.strided, torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc]
4643
4644        tested_combinations = set()
4645        for tensors in zip(*map(self.generate_simple_inputs, layouts)):
4646            for i, t in enumerate(tensors):
4647                self.assertEqual(t.layout, layouts[i])
4648
4649                # all layouts must produce semantically the same tensors
4650                self.assertEqual(t, tensors[0])
4651
4652                if t.layout is torch.strided:
4653                    is_hybrid = None
4654                else:
4655                    is_hybrid = t.dense_dim() > 0
4656                if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
4657                    is_batch = t.crow_indices().ndim > 1
4658                elif t.layout in {torch.sparse_csc, torch.sparse_bsc}:
4659                    is_batch = t.ccol_indices().ndim > 1
4660                else:
4661                    is_batch = None
4662                if t.layout in {torch.sparse_bsr, torch.sparse_bsc}:
4663                    blocksize = t.values().shape[1:3]
4664                    nontrivial_blocksize = 1 not in blocksize
4665                else:
4666                    nontrivial_blocksize = None
4667                if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
4668                    contiguous_indices = t.crow_indices().is_contiguous() and t.col_indices().is_contiguous()
4669                    contiguous_values = t.values().is_contiguous()
4670                elif t.layout in {torch.sparse_csc, torch.sparse_bsc}:
4671                    contiguous_indices = t.ccol_indices().is_contiguous() and t.row_indices().is_contiguous()
4672                    contiguous_values = t.values().is_contiguous()
4673                elif t.layout is torch.sparse_coo:
4674                    contiguous_indices = t._indices().is_contiguous()
4675                    contiguous_values = t._values().is_contiguous()
4676                else:
4677                    contiguous_indices = None
4678                    contiguous_values = t.is_contiguous()
4679
4680                tested_combinations.add((t.layout, is_hybrid, is_batch, nontrivial_blocksize,
4681                                         contiguous_indices, contiguous_values))
4682
4683        # Ensure that the input generation covers all layout,
4684        # non-hybrid/hybrid, non-batch/batch, blocksize, and
4685        # contiguity combinations:
4686        untested_combinations = set()
4687        for layout in layouts:
4688            for is_hybrid in [False, True]:
4689                if layout is torch.strided:
4690                    is_hybrid = None
4691                for is_batch in [False, True]:
4692                    if layout in {torch.sparse_coo, torch.strided}:
4693                        is_batch = None
4694                    for nontrivial_blocksize in [False, True]:
4695                        if layout not in {torch.sparse_bsr, torch.sparse_bsc}:
4696                            nontrivial_blocksize = None
4697                        for contiguous_indices in [False, True]:
4698                            if layout is torch.strided:
4699                                contiguous_indices = None
4700                            elif not is_batch:
4701                                # indices are contiguous per batch
4702                                contiguous_indices = True
4703                            for contiguous_values in [False, True]:
4704                                key = (layout, is_hybrid, is_batch, nontrivial_blocksize,
4705                                       contiguous_indices, contiguous_values)
4706                                if key not in tested_combinations:
4707                                    untested_combinations.add(
4708                                        f'layout={layout}, is_hybrid={is_hybrid}, is_batch={is_batch},'
4709                                        f' nontrivial_blocksize={nontrivial_blocksize},'
4710                                        f' contiguous_indices={contiguous_indices}, contiguous_values={contiguous_values}')
4711        assert not untested_combinations, untested_combinations
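
        # Minimal concrete sketch of the probing logic above (illustrative
        # only): for a non-batch BSR tensor the block size is recoverable
        # from values().shape[1:3] and the batch dimensionality from
        # crow_indices().ndim - 1.
        demo = torch.eye(4).to_sparse(layout=torch.sparse_bsr, blocksize=(2, 2))
        self.assertEqual(tuple(demo.values().shape[1:3]), (2, 2))
        self.assertEqual(demo.crow_indices().ndim - 1, 0)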
4712
4713    @all_sparse_layouts('layout', include_strided=False)
4714    def test_constructor_autograd(self, device, layout):
4715
4716        def specific_constructor(*args, **kwargs):
4717            if layout is torch.sparse_csr:
4718                return torch.sparse_csr_tensor(*args, **kwargs)
4719            elif layout is torch.sparse_csc:
4720                return torch.sparse_csc_tensor(*args, **kwargs)
4721            elif layout is torch.sparse_bsc:
4722                return torch.sparse_bsc_tensor(*args, **kwargs)
4723            elif layout is torch.sparse_bsr:
4724                return torch.sparse_bsr_tensor(*args, **kwargs)
4725            elif layout is torch.sparse_coo:
4726                return torch.sparse_coo_tensor(*args, **kwargs)
4727            else:
4728                raise NotImplementedError(layout)
4729
4730        def generic_constructor(*args, **kwargs):
4731            if layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
4732                kwargs.update(layout=layout)
4733                return torch.sparse_compressed_tensor(*args, **kwargs)
4734            elif layout is torch.sparse_coo:
4735                return torch.sparse_coo_tensor(*args, **kwargs)
4736            else:
4737                raise NotImplementedError(layout)
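
        # Illustrative sketch (a small CSR example; not part of the
        # parametrized checks below): for compressed layouts the generic
        # constructor with an explicit layout= argument matches the
        # layout-specific constructor.
        _crow = torch.tensor([0, 1, 2])
        _col = torch.tensor([0, 1])
        _vals = torch.tensor([1.0, 2.0])
        _specific = torch.sparse_csr_tensor(_crow, _col, _vals, (2, 2))
        _generic = torch.sparse_compressed_tensor(_crow, _col, _vals, (2, 2),
                                                  layout=torch.sparse_csr)
        self.assertEqual(_specific.to_dense(), _generic.to_dense())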
4738
4739        if layout is torch.sparse_coo:
4740            constructors = (specific_constructor,)
4741        else:
4742            constructors = (specific_constructor, generic_constructor)
4743
4744        for args, kwargs in self.generate_simple_inputs(
4745                layout, device=device, dtype=torch.float64,
4746                enable_batch=False,  # TODO: remove after gh-104868 is resolved
4747                output_tensor=False):
4748            values_offset = 1 if layout is torch.sparse_coo else 2
4749
4750            for cnstr in constructors:
4751                for requires_grad in (False, True):
4752                    values = args[values_offset].detach().requires_grad_(requires_grad)
4753                    args = (*args[:values_offset], values, *args[values_offset + 1:])
4754                    kwargs_ = dict(kwargs)
4755                    args_ = args + (kwargs_.pop('size'),)
4756
4757                    sparse = cnstr(*args, **kwargs)
4758
4759                    self.assertEqual(sparse.requires_grad, requires_grad)
4760
4761                    if requires_grad:
4762                        for masked in (False, True):
4763                            if layout is torch.sparse_coo:
4764                                torch.autograd.gradcheck(
4765                                    lambda i, v: cnstr(i, v, **kwargs).to_dense(masked_grad=masked),
4766                                    args, masked=masked)
4767                                torch.autograd.gradcheck(
4768                                    lambda i, v, sz: cnstr(i, v, sz, **kwargs_).to_dense(masked_grad=masked),
4769                                    args_, masked=masked)
4770                            else:
4771                                if layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} and 0:
4772                                    # TODO: remove this if-block after gh-107370 is resolved
4773                                    continue
4774                                torch.autograd.gradcheck(
4775                                    lambda ci, pi, v: cnstr(ci, pi, v, **kwargs).to_dense(masked_grad=masked),
4776                                    args, masked=masked)
4777                                torch.autograd.gradcheck(
4778                                    lambda ci, pi, v, sz: cnstr(ci, pi, v, sz, **kwargs_).to_dense(masked_grad=masked),
4779                                    args_, masked=masked)
4780
4781    @all_sparse_layouts('from_layout', include_strided=False)
4782    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
4783    @parametrize("index_dtype", [torch.int32, torch.int64])
4784    def test_to_dense(self, from_layout, device, dtype, index_dtype):
4785        """
4786        Tests conversion from any layout to the strided layout.
4787        """
4788        for t in self.generate_simple_inputs(
4789                from_layout, device=device, dtype=dtype, index_dtype=index_dtype):
4790            r = t.to_dense()
4791            self.assertEqual(r.layout, torch.strided)
4792            self.assertEqual(r, t)
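
        # Tiny concrete sketch of the round trip exercised above, using an
        # explicit float64 CSR example (illustrative only).
        dense = torch.tensor([[0.0, 1.0], [2.0, 0.0]], dtype=torch.float64, device=device)
        self.assertEqual(dense.to_sparse_csr().to_dense(), dense)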
4793
4794    @all_sparse_layouts('from_layout', include_strided=False)
4795    @dtypes(torch.float64, torch.complex128)
4796    @parametrize("index_dtype", [torch.int64])
4797    @gradcheck_semantics()
4798    def test_gradcheck_to_dense(self, from_layout, device, dtype, index_dtype, gradcheck):
4799        for t in self.generate_simple_inputs(
4800                from_layout, device=device, dtype=dtype, index_dtype=index_dtype):
4801            batch_dim = t.dim() - t.dense_dim() - t.sparse_dim()
4802            if batch_dim > 0:
4803                # TODO: implement batch support in _convert_indices_from_csr_to_coo
4804                continue
4805            t = t.clone().detach().requires_grad_(True)
4806            r = gradcheck(lambda x: torch.Tensor.to_dense(x, masked_grad=gradcheck.masked), t)
4807            self.assertTrue(r)
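
        # Dimension bookkeeping used above (illustrative): for any sparse
        # tensor, dim() equals the number of batch dimensions plus
        # sparse_dim() plus dense_dim(); a plain 2D COO tensor has neither
        # batch nor dense dimensions.
        c = torch.eye(2, device=device, dtype=torch.float64).to_sparse()
        self.assertEqual(c.dim(), c.sparse_dim() + c.dense_dim())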
4808
4809    @all_sparse_layouts('from_layout', include_strided=True)
4810    @all_sparse_layouts('to_layout', include_strided=False)
4811    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
4812    @parametrize("index_dtype", [torch.int32, torch.int64])
4813    def test_to_sparse(self, from_layout, to_layout, device, dtype, index_dtype):
4814        """
4815        Tests conversion from any layout to any sparse layout.
4816        """
4817        for t in self.generate_simple_inputs(
4818                from_layout, device=device, dtype=dtype, index_dtype=index_dtype,
4819                enable_hybrid=(
4820                    # TODO: to support the strided->hybrid
4821                    # CSR/CSC/BSR/BSC conversion, to_sparse() requires an
4822                    # extra keyword argument, either nof_batch_dims or
4823                    # nof_dense_dims
4824                    not (from_layout is torch.strided and to_layout in
4825                         {torch.sparse_bsr, torch.sparse_bsc, torch.sparse_csr, torch.sparse_csc}))):
4826
4827            if to_layout in {torch.sparse_bsr, torch.sparse_bsc}:
4828                if from_layout == torch.sparse_bsr:
4829                    batch_ndim = t.crow_indices().dim() - 1
4830                    blocksize = t.values().shape[batch_ndim + 1:batch_ndim + 3]
4831                elif from_layout == torch.sparse_bsc:
4832                    batch_ndim = t.ccol_indices().dim() - 1
4833                    blocksize = t.values().shape[batch_ndim + 1:batch_ndim + 3]
4834                else:
4835                    blocksize = (1, 1)
4836            else:
4837                blocksize = None
4838
4839            if from_layout is torch.strided:
4840                is_batch = None
4841                is_hybrid = None
4842            else:
4843                is_batch = t.dim() > (t.sparse_dim() + t.dense_dim())
4844                is_hybrid = t.dense_dim() > 0
4845
4846            def explicit_to_sparse(x):
4847                # Used to check that the explicit conversion methods
4848                # are consistent with the `to_sparse(*, layout,
4849                # blocksize)` method.
4850                if to_layout is torch.sparse_coo:
4851                    return x.to_sparse_coo()
4852                elif to_layout is torch.sparse_csr:
4853                    return x.to_sparse_csr()
4854                elif to_layout is torch.sparse_csc:
4855                    return x.to_sparse_csc()
4856                elif to_layout is torch.sparse_bsr:
4857                    return x.to_sparse_bsr(blocksize)
4858                elif to_layout is torch.sparse_bsc:
4859                    return x.to_sparse_bsc(blocksize)
4860                else:
4861                    assert 0  # unreachable
4862
4863            # TODO: The following exception cases all correspond to
4864            # conversions that are not yet implemented
4865            if from_layout in {
4866                    torch.sparse_csr, torch.sparse_csc} and to_layout in {torch.sparse_bsr, torch.sparse_bsc} and is_batch:
4867                with self.assertRaisesRegex(
4868                        RuntimeError,
4869                        r"conversion from Sparse(Csr|Csc) to Sparse(Bsr|Bsc) for batched inputs is not supported"):
4870                    t.to_sparse(layout=to_layout, blocksize=blocksize)
4871                with self.assertRaisesRegex(
4872                        RuntimeError,
4873                        r"conversion from Sparse(Csr|Csc) to Sparse(Bsr|Bsc) for batched inputs is not supported"):
4874                    explicit_to_sparse(t)
4875                continue
4876            elif from_layout is torch.sparse_coo and to_layout in {
4877                    torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} and t.sparse_dim() != 2:
4878                with self.assertRaisesRegex(
4879                        RuntimeError,
4880                        r"conversion from Sparse to .* for input tensors with sparse_dim\(\)!=2 is not supported"):
4881                    t.to_sparse(layout=to_layout, blocksize=blocksize)
4882                with self.assertRaisesRegex(
4883                        RuntimeError,
4884                        r"conversion from Sparse to .* for input tensors with sparse_dim\(\)!=2 is not supported"):
4885                    explicit_to_sparse(t)
4886                continue
4887            elif (from_layout, to_layout) in {(torch.sparse_bsc, torch.sparse_csr), (torch.sparse_bsc, torch.sparse_csc),
4888                                              (torch.sparse_bsr, torch.sparse_csr), (torch.sparse_bsr, torch.sparse_csc)}:
4889                with self.assertRaisesRegex(
4890                        RuntimeError,
4891                        r"sparse_compressed_to_sparse_(csr|csc|bsr|bsc): expected\s*(Sparse(Csc|Csr)[,]|)\s*Sparse(Csr|Bsr)"
4892                        " or Sparse(Csc|Bsc) layout but got Sparse(Csr|Csc|Bsr|Bsc)"):
4893                    t.to_sparse(layout=to_layout, blocksize=blocksize)
4894                with self.assertRaisesRegex(
4895                        RuntimeError,
4896                        r"sparse_compressed_to_sparse_(csr|csc|bsr|bsc): expected\s*(Sparse(Csc|Csr)[,]|)\s*Sparse(Csr|Bsr)"
4897                        " or Sparse(Csc|Bsc) layout but got Sparse(Csr|Csc|Bsr|Bsc)"):
4898                    explicit_to_sparse(t)
4899                self.skipTest('NOT IMPL')
4900            else:
4901                r = t.to_sparse(layout=to_layout, blocksize=blocksize)
4902
4903                self.assertEqual(r.layout, to_layout)
4904
4905                # The to_sparse method uses unsafe construction of sparse
4906                # tensors. Here we explicitly validate the results to
4907                # make sure that the resulting sparse tensors satisfy
4908                # the corresponding sparse tensor invariants.
4909                if r.layout in {torch.sparse_csr, torch.sparse_bsr, torch.sparse_csc, torch.sparse_bsc}:
4910                    if r.layout in {torch.sparse_csr, torch.sparse_bsr}:
4911                        compressed_indices, plain_indices = r.crow_indices(), r.col_indices()
4912                    else:
4913                        compressed_indices, plain_indices = r.ccol_indices(), r.row_indices()
4914                    torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, r.values(),
4915                                                                  r.shape, r.layout)
4916                    if from_layout in {torch.strided, torch.sparse_coo}:
4917                        self.assertEqual(compressed_indices.dtype, torch.int64)
4918                        self.assertEqual(plain_indices.dtype, torch.int64)
4919                    else:
4920                        self.assertEqual(compressed_indices.dtype, index_dtype)
4921                        self.assertEqual(plain_indices.dtype, index_dtype)
4922                    self.assertEqual(r.values().dtype, dtype)
4923                elif r.layout is torch.sparse_coo:
4924                    if t.layout is torch.sparse_coo:
4925                        self.assertEqual(t.is_coalesced(), r.is_coalesced())
4926
4927                    # Check that r is truly coalesced when r.is_coalesced() returns True
4928                    if r.is_coalesced():
4929                        self.assertTrue(is_coalesced_indices(r))
4930
4931                    torch._validate_sparse_coo_tensor_args(r._indices(), r._values(), r.shape)
4932                    self.assertEqual(r._indices().dtype, torch.int64)
4933                    self.assertEqual(r._values().dtype, dtype)
4934                else:
4935                    assert 0  # unreachable
4936
4937                # Finally, we'll test tensor equality:
4938                self.assertEqual(r, t)
4939
4940                # Also, check consistency with explicit conversion methods:
4941                r2 = explicit_to_sparse(t)
4942                self.assertEqual(r2, r)
4943
4944                # Check inverse conversion from sparse compressed block tensors
4945                if from_layout == torch.sparse_bsr:
4946                    batch_ndim = t.crow_indices().dim() - 1
4947                    from_blocksize = t.values().shape[batch_ndim + 1:batch_ndim + 3]
4948                elif from_layout == torch.sparse_bsc:
4949                    batch_ndim = t.ccol_indices().dim() - 1
4950                    from_blocksize = t.values().shape[batch_ndim + 1:batch_ndim + 3]
4951                else:
4952                    continue
4953                if r.ndim != 2:
4954                    continue
4955
4956                t2 = r.to_sparse(layout=from_layout, blocksize=from_blocksize)
4957                self.assertEqual(t2, t)
4958
4959        # extra tests
4960        if (from_layout, to_layout) == (torch.sparse_csr, torch.sparse_bsr):
4961            # See gh-90910
4962            t = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0]], dtype=dtype, device=device).to_sparse_csr()
4963            r = t.to_sparse_bsr((2, 2))
4964            torch._validate_sparse_compressed_tensor_args(r.crow_indices(), r.col_indices(), r.values(), r.shape, r.layout)
4965            self.assertEqual(r, t)
4966
4967        if (from_layout, to_layout) in {(torch.sparse_csr, torch.sparse_csc),
4968                                        (torch.sparse_csc, torch.sparse_csr)}:
4969            # See gh-91007
4970            compressed_indices = torch.tensor([0, 4, 8, 8, 12, 16, 20], dtype=index_dtype, device=device)
4971            plain_indices = torch.tensor([0, 1, 2, 3] * 5, dtype=index_dtype, device=device)
4972            t = torch.sparse_compressed_tensor(compressed_indices, plain_indices, range(20),
4973                                               dtype=dtype, device=device, layout=from_layout)
4974            r = t.to_sparse(layout=to_layout)
4975            if r.layout in {torch.sparse_csr, torch.sparse_bsr}:
4976                compressed_indices, plain_indices = r.crow_indices(), r.col_indices()
4977            else:
4978                compressed_indices, plain_indices = r.ccol_indices(), r.row_indices()
4979            torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, r.values(), r.shape, r.layout)
4980            self.assertEqual(r, t)
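
        # Minimal sketch of the layout/blocksize keywords used throughout
        # this test (illustrative): to_sparse(layout=..., blocksize=...)
        # agrees with the corresponding explicit conversion method.
        d = torch.eye(4, device=device, dtype=torch.float64)
        self.assertEqual(d.to_sparse(layout=torch.sparse_csr), d.to_sparse_csr())
        self.assertEqual(d.to_sparse(layout=torch.sparse_bsr, blocksize=(2, 2)),
                         d.to_sparse_bsr((2, 2)))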
4981
4982    @onlyNativeDeviceTypes
4983    @suppress_warnings
4984    @ops(reduction_ops_with_sparse_support)
4985    @precisionOverride({torch.bfloat16: 5e-4, torch.float16: 5e-3})
4986    @all_sparse_layouts('layout', include_strided=False)
4987    def test_reductions(self, layout, device, dtype, op):
4988        count = 0
4989        for sample in op.sample_inputs_sparse(layout, device, dtype):
4990            count += 1
4991
4992            t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
4993            result = op.op(t_inp, *t_args, **t_kwargs)
4994
4995            # Check the invariant rop(inp, ...).to_dense() == rop(inp.to_dense(), ...)
4996            dense = op.op(t_inp.to_dense(), *t_args, **t_kwargs)
4997            self.assertEqual(result, dense)
4998
4999        if count == 0:
5000            # we count samples to avoid false-positive test reports
5001            self.skipTest('no sample inputs')
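
        # Concrete sketch of the invariant above with a plain sum over a
        # tiny COO tensor (illustrative; the loop exercises the registered
        # reduction ops exhaustively).
        s = torch.eye(3, device=device, dtype=torch.float64).to_sparse()
        self.assertEqual(torch.sum(s), torch.sum(s.to_dense()))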
5002
5003    @onlyNativeDeviceTypes
5004    @suppress_warnings
5005    @ops(reduction_ops_with_sparse_support, allowed_dtypes=(torch.float32, torch.float64, torch.complex64, torch.complex128))
5006    @all_sparse_layouts('layout', include_strided=False)
5007    def test_reductions_backward(self, layout, device, dtype, op):
5008        count = 0
5009        for sample in op.sample_inputs_sparse(layout, device, dtype, requires_grad=True):
5010            t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
5011            r = op.op(t_inp, *t_args, **t_kwargs)
5012            if r.numel() != 0:
5013                r = r.sum()
5014
5015            if op.name == 'sum':
5016                count += 1
5017                r.abs().backward()
5018                self.assertEqual(t_inp.grad, torch.ones(t_inp.shape, dtype=dtype, device=device) * torch.sgn(r))
5019            else:
5020                self.skipTest('NOT IMPL')
5021
5022        if count == 0:
5023            # we count samples to avoid false-positive test reports
5024            self.skipTest('no sample inputs')
5025
5026    @onlyNativeDeviceTypes
5027    @suppress_warnings
5028    @parametrize("mth", [subtest(mth, name=mth.__name__)
5029                         for mth in [torch.Tensor.is_coalesced,
5030                                     torch.Tensor.coalesce,
5031                                     torch.Tensor.indices,
5032                                     torch.Tensor.values,
5033                                     torch.Tensor.crow_indices,
5034                                     torch.Tensor.col_indices,
5035                                     torch.Tensor.ccol_indices,
5036                                     torch.Tensor.row_indices,
5037                                     ]])
5038    @all_sparse_layouts('layout', include_strided=True)
5039    def test_unsupported_backend_error_message(self, mth, layout, device):
5040        inp = torch.tensor([[1, 2], [3, 4]], device=device).to_sparse(
5041            layout=layout,
5042            blocksize=(1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None)
5043        assert inp.layout is layout
5044
5045        expected_behaviour = dict(
5046            # <mth name> = (<supported layouts>, <exception message on other layouts>)
5047            is_coalesced=({torch.sparse_coo},
5048                          "is_coalesced expected sparse coordinate tensor layout but got (Sparse(Csr|Csc|Bsr|Bsc)|Strided)"),
5049            coalesce=({torch.sparse_coo},
5050                      "coalesce expected sparse coordinate tensor layout but got (Sparse(Csr|Csc|Bsr|Bsc)|Strided)"),
5051            indices=({torch.sparse_coo},
5052                     "indices expected sparse coordinate tensor layout but got (Sparse(Csr|Csc|Bsr|Bsc)|Strided)"),
5053            values=({torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc},
5054                    "values expected sparse tensor layout but got Strided"),
5055            crow_indices=({torch.sparse_csr, torch.sparse_bsr},
5056                          "crow_indices expected sparse row compressed tensor layout but got (Sparse(Csc|Bsc|)|Strided)"),
5057            col_indices=({torch.sparse_csr, torch.sparse_bsr},
5058                         "col_indices expected sparse row compressed tensor layout but got (Sparse(Csc|Bsc|)|Strided)"),
5059            ccol_indices=({torch.sparse_csc, torch.sparse_bsc},
5060                          "ccol_indices expected sparse column compressed tensor layout but got (Sparse(Csr|Bsr|)|Strided)"),
5061            row_indices=({torch.sparse_csc, torch.sparse_bsc},
5062                         "row_indices expected sparse column compressed tensor layout but got (Sparse(Csr|Bsr|)|Strided)"),
5063        )[mth.__name__]
5064
5065        if layout in expected_behaviour[0]:
5066            mth(inp)
5067        else:
5068            with self.assertRaisesRegex(RuntimeError, expected_behaviour[1]):
5069                mth(inp)
5070
5071    @onlyNativeDeviceTypes
5072    @all_sparse_layouts('layout', include_strided=False)
5073    @dtypes(torch.float64, torch.cdouble)
5074    @parametrize("masked", [subtest(False, name='sparse'), subtest(True, name='masked')])
5075    @parametrize("fast_mode", [subtest(False, name='slow'), subtest(True, name='fast')])
5076    def test_gradcheck_mm(self, layout, dtype, device, masked, fast_mode):
5077        # This function does not check the following cases:
5078        # - batch or hybrid tensors because addmm does not support
5079        #   such inputs yet
5080        # - check_forward_ad=True because of the lack of sparse tensor
5081        #   support in aten::view_as_real, torch._VF._make_dual, etc.
5082
5083        ref_x = torch.tensor([[1, 2, 0, 0],
5084                              [0, 6, 0, 0],
5085                              [0, 0, 0, 0],
5086                              [13, 14, 0, 15]], dtype=dtype, device=device)
5087        ref_y = torch.tensor([[11, 12, 13, 14],
5088                              [21, 22, 23, 24],
5089                              [31, 32, 33, 34],
5090                              [41, 42, 43, 44]],
5091                             dtype=dtype, device=device)
5092
5093        mm = torch.sparse.mm if masked else torch.mm
5094
5095        blocksize = (2, 2) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None
5096        x = ref_x.to_sparse(layout=layout, blocksize=blocksize).requires_grad_(True)
5097        y = ref_y.requires_grad_(True)
5098
5099        if (layout is torch.sparse_bsr and not masked) or layout is torch.sparse_bsc:
5100            with self.assertRaisesRegex(
5101                    RuntimeError,
5102                    r"addmm: computation on (CPU|CUDA) is not implemented for Strided \+ Sparse(Bsr|Bsc) @ Strided"):
5103                torch.autograd.gradcheck(mm, (x, y), fast_mode=fast_mode, masked=masked)
5104            self.skipTest('NOT IMPL')
5105        elif layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} and masked:
5106            with self.assertRaisesRegex(
5107                    RuntimeError,
5108                    r"(sparse_addmm_sparse_backward: unsupported combination of layouts,"
5109                    r" grad: Strided, mat1: Sparse(Csc|Bsr|Bsc), mat2: Strided"
5110                    r"|addmm: computation on (CPU|CUDA) is not implemented for "
5111                    r"Strided \+ Sparse(Csc|Bsr|Bsc) @ Strided without MKL)"):
5112                torch.autograd.gradcheck(mm, (x, y), fast_mode=fast_mode, masked=masked)
5113            self.skipTest('NOT IMPL')
5114        else:
5115            torch.autograd.gradcheck(mm, (x, y), fast_mode=fast_mode, masked=masked)
5116
5117    @onlyNativeDeviceTypes
5118    @suppress_warnings
5119    @ops(binary_ufuncs_with_sparse_support)
5120    @all_sparse_layouts('layout', include_strided=False)
5121    def test_binary_operation(self, layout, device, dtype, op):
5122        if not op.supports_sparse_layout(layout):
5123            self.skipTest(f'{layout} is not supported in `{op.name}` OpInfo definition. Skipping!')
5124
5125        for sample in op.sample_inputs_sparse(layout, device, dtype):
5126            if validate_sample_input_sparse(op, sample, check_validate=False) is not sample:
5127                # that is, the validation returns the sparse sample
5128                # wrapped within an ErrorInput instance
5129                continue
5130            t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
5131            batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
5132            result = op.op(t_inp, *t_args, **t_kwargs)
5133
5134            # Check rop(inp, ...).shape == inp.shape
5135            self.assertEqual(result.shape, t_inp.shape)
5136
5137            # Check rop(inp, ...).sparse_dim() == inp.sparse_dim()
5138            self.assertEqual(result.sparse_dim(), t_inp.sparse_dim())
5139
5140            # Check rop(inp, ...).dense_dim() == inp.dense_dim()
5141            self.assertEqual(result.dense_dim(), t_inp.dense_dim())
5142
5143            # Check invariant rop(inp, ...).to_dense() == rop(inp.to_dense(), ...)
5144            try:
5145                dense = op.op(t_inp.to_dense(), *(t_args[0].to_dense(), *t_args[1:]), **t_kwargs)
5146            except Exception as msg:
5147                # this is a strided-op issue, so skip the sample silently here
5148                if "\"cpublas_axpy_impl\" not implemented for 'ComplexHalf'" in str(msg):
5149                    continue
5150                raise
5151            self.assertEqual(result, dense)
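
        # Tiny sketch of the to_dense consistency invariant with a plain
        # elementwise multiplication of two COO tensors (illustrative).
        a = torch.eye(2, device=device, dtype=torch.float64).to_sparse()
        b = (2 * torch.eye(2, device=device, dtype=torch.float64)).to_sparse()
        self.assertEqual(torch.mul(a, b).to_dense(),
                         torch.mul(a.to_dense(), b.to_dense()))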
5152
5153    @onlyCPU
5154    @all_sparse_layouts('layout', include_strided=True)
5155    @dtypes(torch.double)
5156    def test_to_sparse_identity(self, device, layout, dtype):
5157        for dense_dim in range(4):
5158            x_dense = torch.eye(dense_dim, dtype=dtype, device=device)
5159            for sparse_dim_in in range(1, dense_dim):
5160                x_sparse = x_dense.to_sparse(sparse_dim_in)
5161                for sparse_dim_out in range(0, dense_dim):
5162                    if sparse_dim_out == sparse_dim_in:
5163                        self.assertEqual(x_sparse.to_sparse(sparse_dim_out).sparse_dim(), sparse_dim_out)
5164                    else:
5165                        with self.assertRaisesRegex(
5166                                RuntimeError,
5167                                r"to_sparse: conversion from Sparse to Sparse with sparse_dim argument !=self.sparse_dim\(\)"
5168                                " is not supported"):
5169                            x_sparse.to_sparse(sparse_dim_out)
5170
5171
5172    @onlyNativeDeviceTypes
5173    @suppress_warnings
5174    @ops(like_fns_with_sparse_support)
5175    @all_sparse_layouts('layout', include_strided=False)
5176    def test_like_fns(self, layout, device, dtype, op):
5177
5178        for sample in op.sample_inputs_sparse(layout, device, dtype):
5179            t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
5180            batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
5181            if t_inp.layout in {torch.sparse_bsr, torch.sparse_bsc}:
5182                expected_blocksize = t_inp.values().shape[batch_dim + 1:batch_dim + 3]
5183            else:
5184                expected_blocksize = None
5185            expected_dtype = t_kwargs.get('dtype', dtype)
5186            expected_device = torch.device(t_kwargs.get('device', device))
5187            expected_layout = t_kwargs.get('layout', layout)
5188
5189            result = op.op(t_inp, *t_args, **t_kwargs)
5190
5191            self.assertEqual(result.dtype, expected_dtype)
5192            self.assertEqual(result.device.type, expected_device.type)
5193            self.assertEqual(result.layout, expected_layout)
5194
5195            if result.layout in {torch.sparse_bsr, torch.sparse_bsc}:
5196                result_batch_dim = result.dim() - result.dense_dim() - result.sparse_dim()
5197                blocksize = result.values().shape[result_batch_dim + 1:result_batch_dim + 3]
5198                self.assertEqual(blocksize, expected_blocksize)
5199
5200            # Check op(inp).shape == inp.shape
5201            self.assertEqual(result.shape, t_inp.shape)
5202
5203            if expected_layout is torch.strided:
5204                self.assertEqual(result.sparse_dim(), 0)
5205                # Check op(inp, layout=torch.strided).dense_dim() == inp.dim()
5206                self.assertEqual(result.dense_dim(), t_inp.dim())
5207            elif expected_layout is torch.sparse_coo:
5208                # Check op(inp, layout=torch.sparse_coo).sparse_dim() == batch_dim + inp.sparse_dim()
5209                self.assertEqual(result.sparse_dim(), batch_dim + t_inp.sparse_dim())
5210                # Check op(inp, layout=torch.sparse_coo).dense_dim() == inp.dense_dim()
5211                self.assertEqual(result.dense_dim(), t_inp.dense_dim())
5212
5213                torch._validate_sparse_coo_tensor_args(result._indices(), result._values(), result.shape)
5214            else:
5215                # Check op(inp).sparse_dim() == inp.sparse_dim()
5216                self.assertEqual(result.sparse_dim(), t_inp.sparse_dim())
5217                # Check op(inp).dense_dim() == inp.dense_dim()
5218                self.assertEqual(result.dense_dim(), t_inp.dense_dim())
5219
5220                if result.layout in {torch.sparse_csr, torch.sparse_bsr}:
5221                    compressed_indices, plain_indices = result.crow_indices(), result.col_indices()
5222                else:
5223                    compressed_indices, plain_indices = result.ccol_indices(), result.row_indices()
5224
5225                torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, result.values(),
5226                                                              result.shape, result.layout)
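
        # Minimal sketch (illustrative): *_like factory functions preserve
        # the sparse layout and shape of their input by default.
        csr = torch.eye(2, device=device, dtype=torch.float64).to_sparse_csr()
        like = torch.zeros_like(csr)
        self.assertEqual(like.layout, torch.sparse_csr)
        self.assertEqual(like.shape, csr.shape)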
5227
5228    @all_sparse_layouts('mask_layout', include_strided=False)
5229    @onlyNativeDeviceTypes
5230    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
5231    def test_sparse_mask(self, mask_layout, device, dtype):
5232        input_layout = torch.strided
5233        mask_dtype = torch.bool
5234        for mask in self.generate_simple_inputs(mask_layout, dtype=mask_dtype, device=device,
5235                                                enable_hybrid=False, enable_batch=False):
5236
5237            x = make_tensor(mask.shape, dtype=dtype, device=device).to_sparse(layout=input_layout)
5238
5239            result = x.sparse_mask(mask)
5240
5241            # Check invariant `x.sparse_mask(mask).<indices> == mask.<indices>`
5242            if mask_layout is torch.sparse_coo:
5243                self.assertEqual(result._indices(), mask._indices())
5244                ones = torch.sparse_coo_tensor(mask._indices(),
5245                                               torch.ones_like(mask._values(), dtype=x.dtype),
5246                                               mask.shape,
5247                                               is_coalesced=mask.is_coalesced())
5248            elif mask_layout in {torch.sparse_csr, torch.sparse_bsr}:
5249                self.assertEqual(result.crow_indices(), mask.crow_indices())
5250                self.assertEqual(result.col_indices(), mask.col_indices())
5251                ones = torch.sparse_compressed_tensor(mask.crow_indices(), mask.col_indices(),
5252                                                      torch.ones_like(mask.values(), dtype=x.dtype),
5253                                                      mask.shape, layout=mask.layout)
5254            else:
5255                self.assertEqual(result.ccol_indices(), mask.ccol_indices())
5256                self.assertEqual(result.row_indices(), mask.row_indices())
5257                ones = torch.sparse_compressed_tensor(mask.ccol_indices(), mask.row_indices(),
5258                                                      torch.ones_like(mask.values(), dtype=x.dtype),
5259                                                      mask.shape, layout=mask.layout)
5260
5261            # Check invariant:
5262            #  x.sparse_mask(mask).to_dense() == x.mul(sparse_xyz_tensor(<mask indices>,
5263            #                                          ones_like(<mask values>)).to_dense())
5264            expected = x.mul(ones.to_dense())
5265
5266            self.assertEqual(result.to_dense(), expected)
5267
5268            # Check invariant `mask.to_dense().sparse_mask(mask) == mask`
5269            result = mask.to_dense().sparse_mask(mask)
5270            self.assertEqual(result, mask)
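
        # Tiny concrete sketch of sparse_mask (illustrative): the result
        # keeps exactly the entries of the strided input at the mask's
        # specified indices.
        m = torch.eye(2, dtype=torch.bool, device=device).to_sparse()
        xd = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64, device=device)
        self.assertEqual(xd.sparse_mask(m).to_dense(),
                         torch.tensor([[1.0, 0.0], [0.0, 4.0]], dtype=torch.float64, device=device))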
5271
5272    @all_sparse_layouts('layout', include_strided=False)
5273    @parametrize("masked", [subtest(False, name='nonmasked'), subtest(True, name='masked')])
5274    @parametrize("fast_mode", [subtest(False, name='slow'), subtest(True, name='fast')])
5275    def test_as_sparse_gradcheck(self, layout, device, masked, fast_mode):
5276        gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)
5277        sparse_compressed_layouts = {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
5278
5279        def identity(x):
5280            return x
5281
5282        for func in (torch.Tensor.to_dense,
5283                     torch.Tensor.sum,
5284                     identity,
5285                     torch.Tensor.to_sparse,
5286                     torch.Tensor.values,
5287                     ):
5288            for x in self.generate_simple_inputs(
5289                    layout,
5290                    device=device,
5291                    dtype=torch.float64,
5292                    # TODO: fix gh-104868 to enable batched samples:
5293                    enable_batch=layout not in sparse_compressed_layouts,
5294                    enable_hybrid=not (
5295                        layout in sparse_compressed_layouts and (
5296                            # FIXME: RuntimeError: sparse_mask(): the
5297                            # number of sparse dimensions in `self`
5298                            # should match that of the `mask`. Got
5299                            # `self.sparse_dim() == 3` !=
5300                            # `mask.sparse_dim() == 2`
5301                            func.__name__ == 'sum'
5302                            # FIXME: RuntimeError: expected
5303                            # col_indices to be a contiguous tensor
5304                            # per batch
5305                            or func.__name__ == 'to_sparse'
5306                        ))):
5307                if layout is torch.sparse_coo and func.__name__ == 'values':
5308                    x = x.coalesce()
5309
5310                gradcheck(func, x.requires_grad_(True), masked=masked, fast_mode=fast_mode)
5311
5312    @onlyCPU
5313    @all_sparse_layouts('layout', include_strided=False)
5314    @dtypes(torch.double)
5315    def test_dataloader(self, device, layout, dtype):
5316
5317        data = list(self.generate_simple_inputs(layout, device=device, dtype=dtype))
5318
5319        dataset = _SparseDataset(data)
5320        loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=2)
5321
5322        loaded_data = list(loader)
5323        self.assertEqual(data, loaded_data)
5324
5325    @onlyCPU
5326    def test_invalid_blocksize(self):
5327        # Blocksize should be a tuple/list/torch.Size containing two values
5328        with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 1"):
5329            torch.randn(1).to_sparse(blocksize=(1,))
5330        with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 1"):
5331            torch.randn(1).to_sparse(blocksize=[1])
5332        with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 1"):
5333            torch.randn(1).to_sparse(blocksize=torch.Size((1,)))
5334        with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 3"):
5335            torch.randn(1).to_sparse(blocksize=(1, 1, 1))
5336        with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 3"):
5337            torch.randn(1).to_sparse(blocksize=[1, 1, 1])
5338        with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 3"):
5339            torch.randn(1).to_sparse(blocksize=torch.Size((1, 1, 1)))
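
        # For contrast (illustrative): a two-element blocksize is accepted
        # when converting a 2D tensor to a blocked sparse layout.
        ok = torch.eye(2).to_sparse(layout=torch.sparse_bsr, blocksize=(1, 1))
        self.assertEqual(ok.layout, torch.sparse_bsr)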
5340
5341    @unittest.skipIf(not torch.cuda.is_available(), 'requires cuda')
5342    @onlyCPU
5343    @all_sparse_layouts('layout', include_strided=True)
5344    def test_constructor_pin_memory(self, device, layout):
5345        """Tests sparse_xyz_tensor(indices, values, pin_memory=True)
5346        """
5347        self.assertEqual(device, "cpu")
5348        for t in self.generate_simple_inputs(
5349                layout, device=device, dtype=torch.float64,
5350                enable_zero_sized=False,  # pinning zero-sized tensors is a no-op
5351                pin_memory=True,
5352                enable_batch=False,  # TODO: remove after gh-104868 is resolved
5353        ):
5354            if layout is torch.sparse_coo:
5355                self.assertTrue(t._indices().is_pinned())
5356                self.assertTrue(t._values().is_pinned())
5357            elif layout in {torch.sparse_csr, torch.sparse_bsr}:
5358                self.assertTrue(t.crow_indices().is_pinned())
5359                self.assertTrue(t.col_indices().is_pinned())
5360                self.assertTrue(t.values().is_pinned())
5361            elif layout in {torch.sparse_csc, torch.sparse_bsc}:
5362                self.assertTrue(t.ccol_indices().is_pinned())
5363                self.assertTrue(t.row_indices().is_pinned())
5364                self.assertTrue(t.values().is_pinned())
5365            elif layout is torch.strided:
5366                pass
5367            else:
5368                assert 0  # unreachable
5369            self.assertTrue(t.is_pinned())
5370
5371    @unittest.skipIf(not torch.cuda.is_available(), 'requires cuda')
5372    @onlyCPU
5373    @all_sparse_layouts('layout', include_strided=True)
5374    def test_method_pin_memory(self, device, layout):
5375        """Tests sparse_xyz_tensor(indices, values, pin_memory=False).pin_memory()
5376        """
5377
5378        for t_ in self.generate_simple_inputs(
5379                layout, device=device, dtype=torch.float64,
5380                enable_zero_sized=False,  # pinning zero-sized tensors is a no-op
5381                pin_memory=False,         # no pinning
5382                enable_batch=False,  # TODO: remove after gh-104868 is resolved
5383        ):
5384            t = t_.pin_memory()
5385            self.assertTrue(t.is_pinned())
5386
5387            # pinning a non-pinned tensor returns a pinned copy and
5388            # leaves the original tensor unpinned
5389            self.assertFalse(t_.is_pinned())
5390
5391            # pinning an already-pinned tensor is an identity
5392            # operation:
5393            t2 = t.pin_memory()
5394            self.assertTrue(t2 is t)
5395
5396            if layout is torch.sparse_coo:
5397                self.assertTrue(t._indices().is_pinned())
5398                self.assertTrue(t._values().is_pinned())
5399                self.assertFalse(t_._indices().is_pinned())
5400                self.assertFalse(t_._values().is_pinned())
5401            elif layout in {torch.sparse_csr, torch.sparse_bsr}:
5402                self.assertTrue(t.crow_indices().is_pinned())
5403                self.assertTrue(t.col_indices().is_pinned())
5404                self.assertTrue(t.values().is_pinned())
5405                self.assertFalse(t_.crow_indices().is_pinned())
5406                self.assertFalse(t_.col_indices().is_pinned())
5407                self.assertFalse(t_.values().is_pinned())
5408            elif layout in {torch.sparse_csc, torch.sparse_bsc}:
5409                self.assertTrue(t.ccol_indices().is_pinned())
5410                self.assertTrue(t.row_indices().is_pinned())
5411                self.assertTrue(t.values().is_pinned())
5412                self.assertFalse(t_.ccol_indices().is_pinned())
5413                self.assertFalse(t_.row_indices().is_pinned())
5414                self.assertFalse(t_.values().is_pinned())
5415            elif layout is torch.strided:
5416                pass
5417            else:
5418                assert 0  # unreachable
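
        # The same copy semantics hold for a plain strided tensor
        # (illustrative): pin_memory() returns a pinned copy and leaves
        # the original unpinned.
        d = torch.zeros(2)
        dp = d.pin_memory()
        self.assertTrue(dp.is_pinned())
        self.assertFalse(d.is_pinned())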
5419
5420
5421    @unittest.skipIf(not torch.cuda.is_available(), 'requires cuda')
5422    @onlyCPU
5423    @all_sparse_layouts('layout', include_strided=True)
5424    def test_constructor_pinned_memory(self, device, layout):
5425        """Tests sparse_xyz_tensor(indices.pin_memory(device), values.pin_memory(device))
5426        """
5427        pin_memory_device = "cuda"
5428        for t in self.generate_simple_inputs(
5429                layout, device=device, dtype=torch.float64,
5430                enable_zero_sized=False,     # pinning zero-sized tensors is a no-op
5431                pin_memory=None,             # constructor does not specify pin_memory=...
5432                members_pin_memory=True,     # indices and values are pinned
5433                enable_batch=False,          # TODO: remove after gh-104868 is resolved
5434        ):
5435            if layout is torch.sparse_coo:
5436                self.assertTrue(t._indices().is_pinned())
5437                self.assertTrue(t._values().is_pinned())
5438            elif layout in {torch.sparse_csr, torch.sparse_bsr}:
5439                self.assertTrue(t.crow_indices().is_pinned())
5440                self.assertTrue(t.col_indices().is_pinned())
5441                self.assertTrue(t.values().is_pinned())
5442            elif layout in {torch.sparse_csc, torch.sparse_bsc}:
5443                self.assertTrue(t.ccol_indices().is_pinned())
5444                self.assertTrue(t.row_indices().is_pinned())
5445                self.assertTrue(t.values().is_pinned())
5446            elif layout is torch.strided:
5447                pass
5448            else:
5449                assert 0  # unreachable
5450            self.assertTrue(t.is_pinned())
5451
5452    @unittest.skipIf(not torch.cuda.is_available(), 'requires cuda')
5453    @onlyCPU
5454    @all_sparse_layouts('layout', include_strided=False)
5455    def test_constructor_mismatched_pinned_memory(self, device, layout):
5456        """Test the failure to construct sparse tensor from indices and values
5457        that have different pinning states.
5458        """
5459        def generic_constructor(*args, **kwargs):
5460            if layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
5461                kwargs.update(layout=layout)
5462                return torch.sparse_compressed_tensor(*args, **kwargs)
5463            elif layout is torch.sparse_coo:
5464                return torch.sparse_coo_tensor(*args, **kwargs)
5465            else:
5466                raise NotImplementedError(layout)
5467
5468        for args, kwargs in self.generate_simple_inputs(
5469                layout, device=device, dtype=torch.float64,
5470                enable_zero_sized=False,     # pinning zero-sized tensors is a no-op
5471                enable_batch=False,  # TODO: remove after gh-104868 is resolved
5472                output_tensor=False):
5473
5474            # indices are pinned, values is a non-pinned tensor
5475            args1 = (args[0].pin_memory(), *args[1:])
5476
5477            # indices are non-pinned, values is a pinned tensor
5478            args2 = (*args[:-1], args[-1].pin_memory())
5479
5480            with self.assertRaisesRegex(
5481                    RuntimeError, r"memory pinning of \w*indices \(=1\) must match memory pinning of values \(=0\)"):
5482                generic_constructor(*args1, **kwargs)
5483
5484            with self.assertRaisesRegex(
5485                    RuntimeError, r"memory pinning of \w*indices \(=0\) must match memory pinning of values \(=1\)"):
5486                generic_constructor(*args2, **kwargs)
5487
5488
5489# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
5490instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')
5491
5492instantiate_device_type_tests(TestSparseMaskedReductions, globals(), except_for='meta')
5493
5494# e.g., TestSparseCPU and TestSparseCUDA
5495instantiate_device_type_tests(TestSparse, globals(), except_for='meta')
5496
5497instantiate_device_type_tests(TestSparseAny, globals(), except_for='meta')
5498
5499instantiate_parametrized_tests(TestSparseMeta)
5500
5501instantiate_parametrized_tests(TestSparseLegacyAndDeprecation)
5502
5503if __name__ == '__main__':
5504    run_tests()
5505