# Owner(s): ["module: sparse"]

import torch
import itertools
import functools
import operator
import random
import unittest
from torch.testing import make_tensor
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
    load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
    DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM, skipIfTorchDynamo, \
    parametrize, subtest, is_coalesced_indices, suppress_warnings, instantiate_parametrized_tests, \
    skipIfCrossRef
from torch.testing._internal.common_cuda import TEST_CUDA
from numbers import Number
from typing import Dict, Any
from packaging import version
from torch.testing._internal.common_cuda import \
    (SM53OrLater, SM80OrLater, TEST_MULTIGPU)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
     deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes)
from torch.testing._internal.common_methods_invocations import \
    (op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
from torch.testing._internal.common_dtype import (
    all_types, all_types_and_complex, all_types_and_complex_and, floating_and_complex_types,
    floating_and_complex_types_and, integral_types, floating_types_and,
)
from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
from torch.testing._internal.opinfo.refs import (
    ElementwiseBinaryPythonRefInfo,
    ReductionPythonRefInfo
)

def _op_supports_any_sparse(op):
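    """Return True if the given OpInfo supports at least one sparse layout
    (COO, CSR, CSC, BSR, or BSC)."""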
    return (op.supports_sparse
            or op.supports_sparse_csr
            or op.supports_sparse_csc
            or op.supports_sparse_bsr
            or op.supports_sparse_bsc)


reduction_ops_with_sparse_support = [
    op for op in reduction_ops if 'masked.' not in op.name and
    _op_supports_any_sparse(op) and not isinstance(op, ReductionPythonRefInfo)]

binary_ufuncs_with_sparse_support = [
    op for op in binary_ufuncs if _op_supports_any_sparse(op) and
    not isinstance(op, ElementwiseBinaryPythonRefInfo)]

like_fns_with_sparse_support = [op for op in op_db if _op_supports_any_sparse(op) and '_like' in op.name]

if TEST_SCIPY:
    import scipy.sparse

# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

# batched grad doesn't support sparse
gradcheck = functools.partial(gradcheck, check_batched_grad=False)

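# Per the expression below: on Windows, cuSPARSE SpMM for complex128 requires a
# CUDA build newer than 11.2; on non-Windows builds it is taken to be available
# whenever the build is not ROCm.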
CUSPARSE_SPMM_COMPLEX128_SUPPORTED = (
    IS_WINDOWS and torch.version.cuda and version.parse(torch.version.cuda) > version.parse("11.2")
) or (not IS_WINDOWS and not TEST_WITH_ROCM)

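# torch.version.hip carries a build suffix after a dash (e.g. "6.0.xxxxx-<hash>",
# an illustrative value), so strip it before the release-number comparison below.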
HIPSPARSE_SPMM_COMPLEX128_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("6.0")

def all_sparse_layouts(test_name='layout', include_strided=False):
    return parametrize(test_name, [
        subtest(torch.strided, name='Strided'),
        subtest(torch.sparse_coo, name='SparseCOO'),
        subtest(torch.sparse_csr, name='SparseCSR'),
        subtest(torch.sparse_csc, name='SparseCSC'),
        subtest(torch.sparse_bsr, name='SparseBSR'),
        subtest(torch.sparse_bsc, name='SparseBSC'),
    ][(0 if include_strided else 1):])
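
# Illustrative usage (hypothetical test name): parametrize a test over every
# sparse layout, with each subtest named after its layout:
#
#   @all_sparse_layouts('layout', include_strided=False)
#   def test_foo(self, device, layout):
#       ...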

def gradcheck_semantics(test_name='gradcheck'):
    gradcheck_sparse = functools.partial(gradcheck, masked=False)
    gradcheck_masked = functools.partial(gradcheck, masked=True)
    gradcheck_sparse.masked = False
    gradcheck_masked.masked = True
    return parametrize(test_name, [
        subtest(gradcheck_sparse, name='sparse'),
        subtest(gradcheck_masked, name='masked')])
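
# Illustrative usage (hypothetical test name): the decorated test receives a
# gradcheck callable preconfigured for sparse (masked=False) or masked
# (masked=True) semantics and can branch on its `masked` attribute:
#
#   @gradcheck_semantics()
#   def test_bar(self, device, gradcheck):
#       gradcheck(lambda t: t.to_dense(masked_grad=gradcheck.masked), (x,))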


class CrossRefSparseFakeMode(torch._subclasses.CrossRefFakeMode):
    def __init__(self) -> None:
        super().__init__(
            self.ignore_op, check_strides=False,
            check_aliasing=False,
        )  # TODO: enable stride/alias checking

    # empty_like is excluded for now due to issues with sparse complex tensors;
    # aten._to_dense.default is excluded because it gets called with CSC inputs.
    @staticmethod
    def ignore_op(func):
        return func in (
            torch.ops.aten.empty_like.default,
            torch.ops.aten.set_.source_Storage_storage_offset,
            torch.ops.aten.sspaddmm.out,
            torch.ops.aten._spdiags.default,
            torch.ops.aten._to_dense.default,
            torch.ops.aten.indices.default,
            torch.ops.aten._indices.default,
            torch.ops.aten.values.default,
            torch.ops.aten._values.default,
        )


class TestSparseLegacyAndDeprecation(TestCase):

    @skipIfTorchDynamo("TorchDynamo fails for an unknown reason")
    def test_legacy_warnings(self):

        def f1():
            "torch.sparse.SparseTensor() is deprecated."\
                "  Please use torch.sparse_coo_tensor((0,), dtype=)"
            x_ref = torch.sparse_coo_tensor((0,), dtype=torch.float64)
            x = torch.sparse.DoubleTensor()
            self.assertEqual(x, x_ref)

        def f2():
            "torch.sparse.SparseTensor(cdata=x._cdata) is deprecated."\
                "  Please use torch.sparse_coo_tensor(x._indices(), x._values(), x.shape)"
            x_ref = torch.tensor([[1, 2], [3, 4]], dtype=torch.float64).to_sparse()
            x = torch.sparse.DoubleTensor(cdata=x_ref._cdata)
            y = torch.sparse_coo_tensor(x._indices(), x._values(), x.shape)
            self.assertEqual(x, x_ref)
            self.assertEqual(y, x_ref)

        def f3():
            "torch.sparse.SparseTensor(indices, values, *, device=) is deprecated."\
                "  Please use torch.sparse_coo_tensor(indices, values, dtype=, device=)"
            x_ref = torch.sparse_coo_tensor([[0, 0, 1, 1], [0, 1, 0, 1]], [1, 2, 3, 4], dtype=torch.float64)
            x = torch.sparse.DoubleTensor(torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]]),
                                          torch.tensor([1, 2, 3, 4], dtype=torch.float64))
            self.assertEqual(x, x_ref)

        def f4():
            "torch.sparse.SparseTensor(indices, values, shape, *, device=) is deprecated."\
                "  Please use torch.sparse_coo_tensor(indices, values, shape, dtype=, device=)"
            x_ref = torch.sparse_coo_tensor([[0, 0, 1, 1], [0, 1, 0, 1]], [1, 2, 3, 4], (2, 3), dtype=torch.float64)
            x = torch.sparse.DoubleTensor(torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]]),
                                          torch.tensor([1, 2, 3, 4], dtype=torch.float64), (2, 3))
            self.assertEqual(x, x_ref)

        def f5():
            "torch.sparse.SparseTensor(shape, *, device=) is deprecated."\
                "  Please use torch.sparse_coo_tensor(shape, dtype=, device=)"
            x_ref = torch.sparse_coo_tensor((2, 3), dtype=torch.float64)
            x = torch.sparse.DoubleTensor(2, 3)
            self.assertEqual(x, x_ref)

        for test_f in [f1, f2, f3, f4, f5]:

            with self.assertWarns(UserWarning, msg=test_f.__doc__) as cm:
                test_f()
                test_f()

            # Check warn-once:
            self.assertEqual(len(cm.warnings), 1)


class TestSparseBase(TestCase):
    def run(self, result=None):
        if TEST_WITH_CROSSREF:
            with CrossRefSparseFakeMode():
                return super().run(result)
        else:
            return super().run(result)


class TestSparse(TestSparseBase):

    def setUp(self):
        TestCase.setUp(self)

        self.index_tensor = lambda *args, **kwargs: torch.tensor(*args, **kwargs, dtype=torch.int64)

        def sparse_empty_factory(*args, **kwargs):
            kwargs['layout'] = kwargs.get('layout', torch.sparse_coo)
            return torch.empty(*args, **kwargs)
        self.sparse_empty = sparse_empty_factory

        def sparse_tensor_factory(*args, **kwargs):
            return torch.sparse_coo_tensor(*args, **kwargs)
        self.sparse_tensor = sparse_tensor_factory

    def _gen_sparse(self, sparse_dim, nnz, with_size, dtype, device, coalesced):
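        """Generate a random sparse COO tensor along with its indices and
        values, delegating to self.genSparseTensor.

        Illustrative example (CPU, coalesced): a 3x3 COO tensor with 4
        specified elements:

            x, i, v = self._gen_sparse(2, 4, [3, 3], torch.float64, 'cpu', True)
        """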
        if isinstance(with_size, Number):
            with_size = [with_size] * sparse_dim

        x, i, v = self.genSparseTensor(with_size, sparse_dim, nnz, not coalesced, dtype=dtype, device=device)

        if not coalesced:
            self.assert_uncoalesced(x)

        return x, i, v

    def assert_uncoalesced(self, x):
        """
        Test that a CPU tensor is uncoalesced: assert that the coalesced
        flag is unset, then scan the indices for a duplicate, returning
        True at the first one found.  This is used to ensure correctness
        of the uncoalesced tensor generation algorithm.
        """
        assert not x.is_coalesced()
        existing_indices = set()
        indices = x._indices()
        for i in range(x._nnz()):
            index = str(indices[:, i])
            if index in existing_indices:
                return True
            else:
                existing_indices.add(index)

    def randn(self, *args, **kwargs):
        """
        Variant of torch.randn that also works in the TEST_CUDA case.
        """
        # TODO: Put this in torch.cuda.randn
        return torch.empty(*args, **kwargs).normal_()

    @dtypes(torch.double)
    def test_print_coalesced(self, device, dtype):
        self._test_print(device, dtype, True)

    @dtypes(torch.double)
    def test_print_uncoalesced(self, device, dtype):
        self._test_print(device, dtype, False)

    def _test_print(self, device, dtype, coalesced):
        shape_sparse_dim_nnz = [
            ((), 0, 2),
            ((0,), 0, 10),
            ((2,), 0, 3),
            ((100, 3), 1, 3),
            ((100, 20, 3), 2, 0),
            ((10, 0, 3), 0, 3),
            ((10, 0, 3), 0, 0),
        ]
        printed = []
        for shape, sparse_dim, nnz in shape_sparse_dim_nnz:
            indices_shape = torch.Size((sparse_dim, nnz))
            values_shape = torch.Size((nnz,) + shape[sparse_dim:])
            printed.append(f"# shape: {torch.Size(shape)}")
            printed.append(f"# nnz: {nnz}")
            printed.append(f"# sparse_dim: {sparse_dim}")
            printed.append(f"# indices shape: {indices_shape}")
            printed.append(f"# values shape: {values_shape}")

            indices = torch.arange(indices_shape.numel(), dtype=self.index_tensor(0).dtype,
                                   device=device).view(indices_shape)
            for d in range(sparse_dim):
                indices[d].clamp_(max=(shape[d] - 1))  # make it valid index
            if not coalesced and indices.numel() > 0:
                indices[:, -1] = indices[:, 0]  # make it uncoalesced
            values_numel = values_shape.numel()
            values = torch.arange(values_numel, dtype=dtype,
                                  device=device).view(values_shape).div_(values_numel / 2.)
            sp_tensor = self.sparse_tensor(indices, values, shape, dtype=dtype, device=device)

            dtypes = [torch.int32]
            if values.dtype == torch.double:
                dtypes.append(torch.float)
            else:
                dtypes.append(torch.double)
            for dtype in dtypes:
                printed.append(f"########## {dtype} ##########")
                x = sp_tensor.detach().to(dtype)
                printed.append("# sparse tensor")
                printed.append(str(x))
                if x.dtype.is_floating_point:
                    printed.append("# after requires_grad_")
                    printed.append(str(x.requires_grad_()))
                    printed.append("# after addition")
                    printed.append(str(x + x))
                printed.append("# _indices")
                printed.append(str(x._indices()))
                printed.append("# _values")
                printed.append(str(x._values()))
            printed.append('')
        self.assertExpected('\n'.join(printed))

    @coalescedonoff
    @dtypes(torch.double, torch.cdouble)
    def test_basic(self, device, dtype, coalesced):
        def test_shape(sparse_dims, nnz, with_size):
            if isinstance(with_size, Number):
                with_size = [with_size] * sparse_dims
            x, i, v = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)
            self.assertEqual(i, x._indices())
            self.assertEqual(v, x._values())
            self.assertEqual(x.ndimension(), len(with_size))
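            # For even nnz, the uncoalesced generator duplicates half of the
            # index columns, so coalescing is expected to halve nnz, as the
            # assertion below checks.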
            self.assertEqual(x.coalesce()._nnz(), nnz if x.is_coalesced() else nnz // 2)
            self.assertEqual(list(x.size()), with_size)

            # Test .indices() and .values()
            if not coalesced:
                with self.assertRaisesRegex(RuntimeError, "Cannot get indices on an uncoalesced tensor"):
                    x.indices()
                with self.assertRaisesRegex(RuntimeError, "Cannot get values on an uncoalesced tensor"):
                    x.values()
            else:
                self.assertEqual(x.indices(), x._indices())
                self.assertEqual(x.values(), x._values())

        test_shape(3, 10, 100)
        test_shape(3, 10, [100, 100, 100])
        test_shape(3, 10, [100, 100, 100, 5, 5, 5, 0])
        test_shape(3, 0, [0, 0, 100, 5, 5, 5, 0])

        # Make sure that coalesce handles duplicate indices correctly
        i = self.index_tensor([[9, 0, 0, 0, 8, 1, 1, 1, 2, 7, 2, 2, 3, 4, 6, 9]], device=device)
        v = torch.tensor([[idx**2, idx] for idx in range(i.size(1))], dtype=dtype, device=device)
        x = self.sparse_tensor(i, v, torch.Size([10, 2]), dtype=dtype, device=device)
        self.assertEqual(x.coalesce()._nnz(), 9)

    @coalescedonoff
    @dtypes(torch.double, torch.cdouble, torch.bfloat16)
    @precisionOverride({torch.bfloat16: 1e-2})
    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_coalesce(self, device, dtype, coalesced):

        def _test_coalesce(t):
            tc = t.coalesce()
            self.assertEqual(tc.to_dense(), t.to_dense())
            self.assertTrue(tc.is_coalesced())
            # Our code below doesn't work when nnz is 0, because
            # then it's a 0D tensor, not a 2D tensor.
            if t._nnz() == 0:
                self.assertEqual(t._indices(), tc._indices())
                self.assertEqual(t._values(), tc._values())
                return tc

            value_map: Dict[Any, Any] = {}
            for idx, val in zip(t._indices().t(), t._values()):
                idx_tup = tuple(idx.tolist())
                if idx_tup in value_map:
                    value_map[idx_tup] += val
                else:
                    value_map[idx_tup] = val.clone() if isinstance(val, torch.Tensor) else val

            new_indices = sorted(value_map.keys())
            _new_values = [value_map[idx] for idx in new_indices]
            if t._values().ndimension() < 2:
                new_values = t._values().new(_new_values)
            else:
                new_values = torch.stack(_new_values)

            new_indices = t._indices().new(new_indices).t()
            tg = t.new(new_indices, new_values, t.size())

            self.assertEqual(tc._indices(), tg._indices())
            self.assertEqual(tc._values(), tg._values())

            if t.is_coalesced():
                self.assertEqual(tc._indices(), t._indices())
                self.assertEqual(tc._values(), t._values())

        for empty_i, empty_v, empty_nnz in itertools.product([True, False], repeat=3):
            sparse_size = [] if empty_i else [2, 1]
            dense_size = [1, 0, 2] if empty_v else [1, 2]
            nnz = 0 if empty_nnz else 5

            t, _, _ = self._gen_sparse(len(sparse_size), nnz, sparse_size + dense_size, dtype, device, coalesced)
            _test_coalesce(t)  # this tests correctness

    @dtypes(torch.double)
    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/89395")
    def test_coalesce_reference_cycle(self, device, dtype):
        # Test coalesce doesn't create autograd graph cycles (gh-52253)

        # Sanity check that the helper class works as expected
        t = torch.rand(2)
        t_ref = torch._C._WeakTensorRef(t)
        self.assertFalse(t_ref.expired())

        del t
        self.assertTrue(t_ref.expired())

        def test_sparse_sum():
            i = torch.tensor([[0], [4]], dtype=torch.long, device=device)
            v = torch.tensor([[[-0.4567, -1.8797, 0.0380, 1.4316]]],
                             dtype=dtype, device=device)
            S = torch.sparse_coo_tensor(i, v)
            S = S.coalesce()
            S.requires_grad_(True)
            S2 = S.coalesce()
            self.assertTrue(S2.is_coalesced())
            return torch._C._WeakTensorRef(S2)

        ref = test_sparse_sum()
        self.assertTrue(ref.expired())

    @dtypes(torch.double)
    def test_ctor_large_sizes(self, device, dtype):
        # Test that integer overflow is detected when computing numel
        # of a sparse tensor with large dimensions (gh-57416). Notice
        # that numel is computed internally when constructing a
        # tensor, hence the overflow may appear during the tensor
        # construction step.
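        # For example, with N = 100000 below, (N + 1) ** 4 is about 1.0e20,
        # which exceeds the int64 maximum of 2 ** 63 - 1 (about 9.2e18).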
        N = 100000
        indices = torch.tensor([[N, N - 1]] * 4, dtype=torch.int64, device=device)
        values = torch.tensor([1, 2], dtype=dtype, device=device)
        self.assertRaises(RuntimeError,
                          lambda: torch.sparse_coo_tensor(
                              indices, values, (N + 1,) * 4, device=device))

    @dtypes(torch.double, torch.cdouble)
    def test_ctor_size_checks(self, device, dtype):
        indices = self.index_tensor([
            [0, 0, 0],
            [0, 3, 0],
            [0, 0, 0],
            [0, 0, 0],
        ], device=device)
        values = torch.tensor([2, 1, 3, 4], dtype=dtype, device=device)

        # indices inconsistent with size
        self.assertRaises(
            RuntimeError,
            lambda: self.sparse_tensor(indices, values, torch.Size([2, 1, 1])))

        # values inconsistent with size
        values = torch.tensor([
            [2, 1, 2, 1],
            [1, 0, 5, 2],
        ], dtype=dtype, device=device)
        self.assertRaises(
            RuntimeError,
            lambda: self.sparse_tensor(indices, values, torch.Size([2, 4, 2, 1])))

    @coalescedonoff
    @dtypes(torch.double)
    def test_ctor_is_coalesced_with_gradcheck(self, device, dtype, coalesced):
        for sparse_size, nnz in (((3, 3), 5), ((2, 3, 1, 5), 11)):
            t, _, _ = self._gen_sparse(len(sparse_size), nnz, sparse_size, dtype, device, coalesced)
            self.assertEqual(t.is_coalesced(), coalesced)

            def func(indices, values, shape, is_coalesced):
                s = torch.sparse_coo_tensor(indices, values, shape, check_invariants=True, is_coalesced=is_coalesced)
                self.assertEqual(s.is_coalesced(), is_coalesced)
                return s.to_dense(masked_grad=False)

            if coalesced:
                torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, False))
                torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, True))
            else:
                torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, False))
                with self.assertRaisesRegex(RuntimeError,
                                            "cannot set is_coalesced to true if indices correspond to uncoalesced COO tensor"):
                    torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, True))

    @dtypes(*floating_and_complex_types_and(torch.float16, torch.bfloat16))
    @unittest.skipIf(TEST_WITH_CROSSREF, "unsupported generator triggers assertion error")
    @gradcheck_semantics()
    def test_to_dense_with_gradcheck(self, device, dtype, gradcheck):

        def test_tensor(x, res):
            x.to_dense()  # Tests triple to_dense for memory corruption
            x.to_dense()
            x.to_dense()
            dense_x = x.to_dense()
            safe_dense_x = self.safeToDense(x)
            dense_x = dense_x.to(res.dtype)
            safe_dense_x = safe_dense_x.to(res.dtype)
            self.assertEqual(res, dense_x)
            self.assertEqual(res, safe_dense_x)

            # Only run autograd test for float64
            if x.dtype != torch.float64:
                return

            def fn(x):
                return x.to_dense(masked_grad=gradcheck.masked)
            x.requires_grad_(True)
            gradcheck(fn, (x,))

        for value_type in [torch.double, torch.cdouble]:
            i = self.index_tensor([
                [0, 1, 2, 2],
                [0, 0, 0, 3],
                [0, 0, 1, 4],
            ], device=device)
            # we don't have to_dense for half types on CPU because it is implemented
            # with a slower add_ operation
            v = torch.tensor([2, 1, 3, 4], dtype=dtype, device=device)
            x = self.sparse_tensor(i, v, torch.Size([3, 4, 5]), dtype=value_type, device=device)
            res = torch.tensor([
                [[2, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0]],
                [[1, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0]],
                [[0, 3, 0, 0, 0],
                 [0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 4]],
            ], dtype=dtype, device=device)

            test_tensor(x, res)
            test_tensor(res, res)

            i = self.index_tensor([
                [0, 1, 2, 2],
                [0, 0, 0, 3],
                [0, 0, 1, 4],
            ], device=device)
            v = torch.empty(4, 0, dtype=dtype, device=device)
            x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 0]), dtype=value_type, device=device)
            res = torch.empty((3, 4, 5, 0), dtype=dtype, device=device)
            test_tensor(x, res)

    @coalescedonoff
    @dtypes(torch.float16, torch.bfloat16, torch.float64, torch.int, torch.cfloat, torch.cdouble)
    def test_to_sparse(self, device, dtype, coalesced):
        shape = [5, 2, 10, 4]
        max_nnz = 1
        for value_type in [torch.double, torch.cdouble]:
            for dim, dim_sz in enumerate(shape, 1):
                max_nnz *= dim_sz
                rnnz = torch.randint(2, max_nnz, (1,)).item()
                for nnz in [0, 1, rnnz]:
                    expected, _, _ = self._gen_sparse(dim, nnz, shape, dtype=value_type, device=device,
                                                      coalesced=coalesced)
                    expected = expected.to(dtype)

                    d = expected.to_dense()
                    result = d.to_sparse(dim)
                    self.assertEqual(d, result.to_dense())
                    self.assertEqual(expected.size(), result.size())
                    self.assertEqual(dim, result.sparse_dim())

    @dtypes(torch.double, torch.cdouble)
    def test_sparse_bool(self, device, dtype):
        a = torch.tensor([True, False], dtype=dtype, device=device).to(torch.bool)
        b = a.to_sparse().to_dense()
        self.assertEqual(a, b)

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/108667")
    @dtypes(torch.double, torch.cdouble)
    def test_scalar(self, device, dtype):
        # tensor with value
        a = self.sparse_tensor(self.index_tensor([], device=device).unsqueeze(1), 12.3, [], dtype=dtype, device=device)
        self.assertEqual(1, a._values().numel())
        self.assertEqual(a, a.clone())
        a_coalesced = a.coalesce()
        self.assertTrue(a_coalesced.is_coalesced())
        self.assertEqual(torch.tensor(12.3, dtype=dtype, device=device), a.to_dense())
        self.assertEqual(a, a.to_dense().to_sparse())

        # tensor with multiple values
        a = self.sparse_tensor(self.index_tensor([], device=device).unsqueeze(1).expand(0, 2),
                               [12.3, 12.3], [], dtype=dtype, device=device)
        self.assertEqual(2, a._values().numel())
        self.assertEqual(a, a.clone())
        a_coalesced = a.coalesce()
        self.assertTrue(a_coalesced.is_coalesced())
        self.assertEqual(torch.tensor(12.3 * 2, dtype=dtype, device=device), a.to_dense())
        self.assertEqual(a.coalesce(), a.coalesce().to_dense().to_sparse())

        # tensor without value
        a = self.sparse_empty((), dtype=dtype, device=device)
        self.assertEqual(0, a._values().numel())
        self.assertEqual(a, a.clone())
        a_coalesced = a.coalesce()
        self.assertTrue(a_coalesced.is_coalesced())
        self.assertEqual(torch.tensor(0, dtype=dtype, device=device), a.to_dense())
        self.assertEqual(a, a.to_dense().to_sparse())

    @dtypes(torch.double, torch.cdouble)
    def test_shared(self, device, dtype):
        i = self.index_tensor([[2]], device=device)
        v = torch.tensor([5], dtype=dtype, device=device)
        x = self.sparse_tensor(i, v, torch.Size([3]))
        v[0] = 6
        self.assertEqual(torch.tensor([0, 0, 6], dtype=dtype, device=device), self.safeToDense(x))
        i[0][0] = 0
        self.assertEqual(torch.tensor([6, 0, 0], dtype=dtype, device=device), self.safeToDense(x))

        i = self.index_tensor([[2]], device=device)
        v = torch.empty((1, 0), dtype=dtype, device=device)
        x = self.sparse_tensor(i, v, torch.Size([3, 0]))
        i[0][0] = 0
        self.assertEqual(torch.empty((3, 0), dtype=dtype, device=device), self.safeToDense(x))

    @dtypes(torch.double, torch.cdouble)
    @unittest.skipIf(TEST_WITH_CROSSREF, "unsupported generator triggers assertion error")
    @gradcheck_semantics()
    def test_to_dense_hybrid(self, device, dtype, gradcheck):

        def test_tensor(x, res):
            x.to_dense()  # Tests triple to_dense for memory corruption
            x.to_dense()
            x.to_dense()
            self.assertEqual(res, x.to_dense())
            self.assertEqual(res, self.safeToDense(x))

            def fn(x):
                return x.to_dense(masked_grad=gradcheck.masked)
            x.requires_grad_(True)
            gradcheck(fn, (x,))

        i = self.index_tensor([
            [0, 1, 2, 2],
            [0, 0, 0, 3],
        ], device=device)
        v = torch.tensor([[2, 3], [1, 2], [3, 4], [4, 5]], dtype=dtype, device=device)
        x = self.sparse_tensor(i, v, torch.Size([3, 4, 2]))
        res = torch.tensor([
            [[2, 3],
             [0, 0],
             [0, 0],
             [0, 0]],
            [[1, 2],
             [0, 0],
             [0, 0],
             [0, 0]],
            [[3, 4],
             [0, 0],
             [0, 0],
             [4, 5]],
        ], dtype=dtype, device=device)
        test_tensor(x, res)

        i = self.index_tensor([
            [0, 1, 2, 2],
            [0, 0, 0, 3],
        ], device=device)
        v = torch.empty((4, 2, 0), dtype=dtype, device=device)
        x = self.sparse_tensor(i, v, torch.Size([3, 4, 2, 0]))
        res = torch.empty((3, 4, 2, 0), dtype=dtype, device=device)
        test_tensor(x, res)

    @dtypes(torch.double, torch.cdouble)
    def test_contig(self, device, dtype):
        def test_tensor(x, exp_i, exp_v):
            x = x.coalesce()
            self.assertEqual(exp_i, x._indices())
            self.assertEqual(exp_v, x._values())

        i = self.index_tensor([
            [1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
            [92, 31, 62, 50, 22, 65, 89, 74, 56, 34],
        ], device=device)
        v = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype, device=device)
        x = self.sparse_tensor(i, v, torch.Size([100, 100]))
        exp_i = self.index_tensor([
            [0, 1, 6, 14, 27, 35, 39, 40, 66, 71],
            [31, 92, 65, 50, 34, 62, 22, 56, 74, 89],
        ], device=device)
        exp_v = torch.tensor([2, 1, 6, 4, 10, 3, 5, 9, 8, 7], dtype=dtype, device=device)
        test_tensor(x, exp_i, exp_v)

        i = self.index_tensor([
            [2, 0, 2, 1],
            [0, 0, 3, 0],
            [1, 0, 4, 0],
        ], device=device)
        v = torch.tensor([3, 2, 4, 1], dtype=dtype, device=device)
        x = self.sparse_tensor(i, v, torch.Size([3, 4, 5]))
        exp_i = self.index_tensor([
            [0, 1, 2, 2],
            [0, 0, 0, 3],
            [0, 0, 1, 4],
        ], device=device)
        exp_v = torch.tensor([2, 1, 3, 4], dtype=dtype, device=device)
        test_tensor(x, exp_i, exp_v)

        i = self.index_tensor([
            [2, 0, 2, 1],
            [0, 0, 3, 0],
            [1, 0, 4, 0],
        ], device=device)
        v = torch.empty([4, 0], dtype=dtype, device=device)
        x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 0]))
        exp_i = self.index_tensor([
            [0, 1, 2, 2],
            [0, 0, 0, 3],
            [0, 0, 1, 4],
        ], device=device)
        exp_v = torch.empty([4, 0], dtype=dtype, device=device)
        test_tensor(x, exp_i, exp_v)

        # Duplicate indices
        i = self.index_tensor([
            [0, 0, 2, 0],
            [0, 0, 3, 0],
            [0, 0, 4, 0],
        ], device=device)
        v = torch.tensor([3, 2, 4, 1], dtype=dtype, device=device)
        x = self.sparse_tensor(i, v, torch.Size([3, 4, 5]))
        exp_i = self.index_tensor([
            [0, 2],
            [0, 3],
            [0, 4],
        ], device=device)
        exp_v = torch.tensor([6, 4], dtype=dtype, device=device)
        test_tensor(x, exp_i, exp_v)
705*da0073e9SAndroid Build Coastguard Worker
706*da0073e9SAndroid Build Coastguard Worker        i = self.index_tensor([
707*da0073e9SAndroid Build Coastguard Worker            [0, 0, 2, 0],
708*da0073e9SAndroid Build Coastguard Worker            [0, 0, 3, 0],
709*da0073e9SAndroid Build Coastguard Worker            [0, 0, 4, 0],
710*da0073e9SAndroid Build Coastguard Worker        ], device=device)
711*da0073e9SAndroid Build Coastguard Worker        v = torch.empty([4, 0], dtype=dtype, device=device)
712*da0073e9SAndroid Build Coastguard Worker        x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 0]))
713*da0073e9SAndroid Build Coastguard Worker        exp_i = self.index_tensor([
714*da0073e9SAndroid Build Coastguard Worker            [0, 2],
715*da0073e9SAndroid Build Coastguard Worker            [0, 3],
716*da0073e9SAndroid Build Coastguard Worker            [0, 4],
717*da0073e9SAndroid Build Coastguard Worker        ], device=device)
718*da0073e9SAndroid Build Coastguard Worker        exp_v = torch.empty([2, 0], dtype=dtype, device=device)
719*da0073e9SAndroid Build Coastguard Worker        test_tensor(x, exp_i, exp_v)
720*da0073e9SAndroid Build Coastguard Worker
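    # Minimal illustrative sketch (helper name is made up): coalesce() sorts
    # the indices lexicographically and sums the values at duplicate indices,
    # which is the contract test_contig checks via _indices()/_values().
    @staticmethod
    def _sketch_coalesce():
        i = torch.tensor([[0, 0, 2], [1, 1, 0]])
        v = torch.tensor([1., 2., 3.])
        s = torch.sparse_coo_tensor(i, v, (3, 3)).coalesce()
        return s._indices(), s._values()   # ([[0, 2], [1, 0]], [3., 3.])
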
721*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
722*da0073e9SAndroid Build Coastguard Worker    def test_contig_hybrid(self, device, dtype):
723*da0073e9SAndroid Build Coastguard Worker        def test_tensor(x, exp_i, exp_v):
724*da0073e9SAndroid Build Coastguard Worker            x = x.coalesce()
725*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(exp_i, x._indices())
726*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(exp_v, x._values())
727*da0073e9SAndroid Build Coastguard Worker
728*da0073e9SAndroid Build Coastguard Worker        i = self.index_tensor([
729*da0073e9SAndroid Build Coastguard Worker            [1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
730*da0073e9SAndroid Build Coastguard Worker            [92, 31, 62, 50, 22, 65, 89, 74, 56, 34],
731*da0073e9SAndroid Build Coastguard Worker        ], device=device)
732*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([
733*da0073e9SAndroid Build Coastguard Worker            [1, 2], [2, 3], [3, 4], [4, 5], [5, 6],
734*da0073e9SAndroid Build Coastguard Worker            [6, 7], [7, 8], [8, 9], [9, 10], [10, 11],
735*da0073e9SAndroid Build Coastguard Worker        ], dtype=dtype, device=device)
736*da0073e9SAndroid Build Coastguard Worker        x = self.sparse_tensor(i, v, torch.Size([100, 100, 2]))
737*da0073e9SAndroid Build Coastguard Worker        exp_i = self.index_tensor([
738*da0073e9SAndroid Build Coastguard Worker            [0, 1, 6, 14, 27, 35, 39, 40, 66, 71],
739*da0073e9SAndroid Build Coastguard Worker            [31, 92, 65, 50, 34, 62, 22, 56, 74, 89],
740*da0073e9SAndroid Build Coastguard Worker        ], device=device)
741*da0073e9SAndroid Build Coastguard Worker        exp_v = torch.tensor([
742*da0073e9SAndroid Build Coastguard Worker            [2, 3], [1, 2], [6, 7], [4, 5], [10, 11],
743*da0073e9SAndroid Build Coastguard Worker            [3, 4], [5, 6], [9, 10], [8, 9], [7, 8],
744*da0073e9SAndroid Build Coastguard Worker        ], dtype=dtype, device=device)
745*da0073e9SAndroid Build Coastguard Worker        test_tensor(x, exp_i, exp_v)
746*da0073e9SAndroid Build Coastguard Worker
747*da0073e9SAndroid Build Coastguard Worker        i = self.index_tensor([
748*da0073e9SAndroid Build Coastguard Worker            [2, 0, 2, 1],
749*da0073e9SAndroid Build Coastguard Worker            [0, 0, 3, 0],
750*da0073e9SAndroid Build Coastguard Worker            [1, 0, 4, 0],
751*da0073e9SAndroid Build Coastguard Worker        ], device=device)
752*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([[3, 3, 3], [2, 2, 2], [4, 4, 4], [1, 1, 1]], dtype=dtype, device=device)
753*da0073e9SAndroid Build Coastguard Worker        x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 3]))
754*da0073e9SAndroid Build Coastguard Worker        exp_i = self.index_tensor([
755*da0073e9SAndroid Build Coastguard Worker            [0, 1, 2, 2],
756*da0073e9SAndroid Build Coastguard Worker            [0, 0, 0, 3],
757*da0073e9SAndroid Build Coastguard Worker            [0, 0, 1, 4],
758*da0073e9SAndroid Build Coastguard Worker        ], device=device)
759*da0073e9SAndroid Build Coastguard Worker        exp_v = torch.tensor([[2, 2, 2], [1, 1, 1], [3, 3, 3], [4, 4, 4]], dtype=dtype, device=device)
760*da0073e9SAndroid Build Coastguard Worker        test_tensor(x, exp_i, exp_v)
761*da0073e9SAndroid Build Coastguard Worker
762*da0073e9SAndroid Build Coastguard Worker        i = self.index_tensor([
763*da0073e9SAndroid Build Coastguard Worker            [2, 0, 2, 1],
764*da0073e9SAndroid Build Coastguard Worker            [0, 0, 3, 0],
765*da0073e9SAndroid Build Coastguard Worker            [1, 0, 4, 0],
766*da0073e9SAndroid Build Coastguard Worker        ], device=device)
767*da0073e9SAndroid Build Coastguard Worker        v = torch.empty([4, 3, 0], dtype=dtype, device=device)
768*da0073e9SAndroid Build Coastguard Worker        x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 3, 0]))
769*da0073e9SAndroid Build Coastguard Worker        exp_i = self.index_tensor([
770*da0073e9SAndroid Build Coastguard Worker            [0, 1, 2, 2],
771*da0073e9SAndroid Build Coastguard Worker            [0, 0, 0, 3],
772*da0073e9SAndroid Build Coastguard Worker            [0, 0, 1, 4],
773*da0073e9SAndroid Build Coastguard Worker        ], device=device)
774*da0073e9SAndroid Build Coastguard Worker        exp_v = torch.empty([4, 3, 0], dtype=dtype, device=device)
775*da0073e9SAndroid Build Coastguard Worker        test_tensor(x, exp_i, exp_v)
776*da0073e9SAndroid Build Coastguard Worker
777*da0073e9SAndroid Build Coastguard Worker        # Duplicate indices
778*da0073e9SAndroid Build Coastguard Worker        i = self.index_tensor([
779*da0073e9SAndroid Build Coastguard Worker            [0, 0, 2, 0],
780*da0073e9SAndroid Build Coastguard Worker            [0, 0, 3, 0],
781*da0073e9SAndroid Build Coastguard Worker            [0, 0, 4, 0],
782*da0073e9SAndroid Build Coastguard Worker        ], device=device)
783*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([[3, 2, 3], [2, 1, 1], [4, 3, 4], [1, 1, 1]], dtype=dtype, device=device)
784*da0073e9SAndroid Build Coastguard Worker        x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 3]))
785*da0073e9SAndroid Build Coastguard Worker        exp_i = self.index_tensor([
786*da0073e9SAndroid Build Coastguard Worker            [0, 2],
787*da0073e9SAndroid Build Coastguard Worker            [0, 3],
788*da0073e9SAndroid Build Coastguard Worker            [0, 4],
789*da0073e9SAndroid Build Coastguard Worker        ], device=device)
790*da0073e9SAndroid Build Coastguard Worker        exp_v = torch.tensor([[6, 4, 5], [4, 3, 4]], dtype=dtype, device=device)
791*da0073e9SAndroid Build Coastguard Worker        test_tensor(x, exp_i, exp_v)
792*da0073e9SAndroid Build Coastguard Worker
793*da0073e9SAndroid Build Coastguard Worker        i = self.index_tensor([
794*da0073e9SAndroid Build Coastguard Worker            [0, 0, 2, 0],
795*da0073e9SAndroid Build Coastguard Worker            [0, 0, 3, 0],
796*da0073e9SAndroid Build Coastguard Worker            [0, 0, 4, 0],
797*da0073e9SAndroid Build Coastguard Worker        ], device=device)
798*da0073e9SAndroid Build Coastguard Worker        v = torch.empty([4, 3, 0], dtype=dtype, device=device)
799*da0073e9SAndroid Build Coastguard Worker        x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 3, 0]))
800*da0073e9SAndroid Build Coastguard Worker        exp_i = self.index_tensor([
801*da0073e9SAndroid Build Coastguard Worker            [0, 2],
802*da0073e9SAndroid Build Coastguard Worker            [0, 3],
803*da0073e9SAndroid Build Coastguard Worker            [0, 4],
804*da0073e9SAndroid Build Coastguard Worker        ], device=device)
805*da0073e9SAndroid Build Coastguard Worker        exp_v = torch.empty([2, 3, 0], dtype=dtype, device=device)
806*da0073e9SAndroid Build Coastguard Worker        test_tensor(x, exp_i, exp_v)
807*da0073e9SAndroid Build Coastguard Worker
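    # Minimal illustrative sketch (helper name is made up): for hybrid
    # tensors, coalescing duplicate indices sums the whole dense values
    # slices elementwise, as the duplicate-indices cases above verify.
    @staticmethod
    def _sketch_hybrid_coalesce():
        i = torch.tensor([[0, 0]])
        v = torch.tensor([[1., 2.], [10., 20.]])
        s = torch.sparse_coo_tensor(i, v, (2, 2)).coalesce()
        return s._values()                 # [[11., 22.]]
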
808*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
809*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
810*da0073e9SAndroid Build Coastguard Worker    def test_clone(self, device, dtype, coalesced):
811*da0073e9SAndroid Build Coastguard Worker        def test_shape(sparse_dims, nnz, with_size):
812*da0073e9SAndroid Build Coastguard Worker            x = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
813*da0073e9SAndroid Build Coastguard Worker            if not coalesced:
814*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(x.is_coalesced())
815*da0073e9SAndroid Build Coastguard Worker                y = x.clone()
816*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(y.is_coalesced())
817*da0073e9SAndroid Build Coastguard Worker            x = x.coalesce()
818*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.is_coalesced())
819*da0073e9SAndroid Build Coastguard Worker            y = x.clone()
820*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(y.is_coalesced())
821*da0073e9SAndroid Build Coastguard Worker
822*da0073e9SAndroid Build Coastguard Worker        test_shape(4, 20, 5)
823*da0073e9SAndroid Build Coastguard Worker        test_shape(3, 10, [100, 100, 100, 5, 5, 5, 0])
824*da0073e9SAndroid Build Coastguard Worker        test_shape(3, 0, [0, 0, 100, 5, 5, 5, 0])
825*da0073e9SAndroid Build Coastguard Worker
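    # Minimal illustrative sketch (helper name is made up): clone() preserves
    # the coalesced flag, so cloning never changes whether indices are
    # guaranteed unique and sorted.
    @staticmethod
    def _sketch_clone_keeps_coalesced():
        i = torch.tensor([[0, 0]])
        v = torch.tensor([1., 2.])
        s = torch.sparse_coo_tensor(i, v, (2,))
        assert not s.clone().is_coalesced()
        assert s.coalesce().clone().is_coalesced()
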
826*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
827*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble, torch.bfloat16)
828*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.bfloat16: 2e-2})
829*da0073e9SAndroid Build Coastguard Worker    def test_Sparse_to_Sparse_copy_(self, device, dtype, coalesced):
830*da0073e9SAndroid Build Coastguard Worker        # This is for testing torch.copy_(SparseTensor, SparseTensor)
831*da0073e9SAndroid Build Coastguard Worker        sparse_dims = 3
832*da0073e9SAndroid Build Coastguard Worker        nnz = 10
833*da0073e9SAndroid Build Coastguard Worker        sizes = [2, 3, 4, 5]  # hybrid sparse
834*da0073e9SAndroid Build Coastguard Worker        x1, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
835*da0073e9SAndroid Build Coastguard Worker        x2, _, _ = self._gen_sparse(sparse_dims, nnz + 10, sizes, dtype, device, coalesced)
836*da0073e9SAndroid Build Coastguard Worker
837*da0073e9SAndroid Build Coastguard Worker        # test copy
838*da0073e9SAndroid Build Coastguard Worker        x2_dense = x2.to_dense()
839*da0073e9SAndroid Build Coastguard Worker        x1.copy_(x2)
840*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x2_dense, x1.to_dense())
841*da0073e9SAndroid Build Coastguard Worker
842*da0073e9SAndroid Build Coastguard Worker        # test type conversion (when x1.copy_(x2), x1.dtype should stay the same)
843*da0073e9SAndroid Build Coastguard Worker        x1 = x1.to(torch.float32)
844*da0073e9SAndroid Build Coastguard Worker
845*da0073e9SAndroid Build Coastguard Worker        x2 = x2.to(torch.float16)
846*da0073e9SAndroid Build Coastguard Worker        x1_dtype = x1.dtype
847*da0073e9SAndroid Build Coastguard Worker        x1.copy_(x2)
848*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x1_dtype, x1.dtype)
849*da0073e9SAndroid Build Coastguard Worker
850*da0073e9SAndroid Build Coastguard Worker        x2 = x2.to(torch.float64)
851*da0073e9SAndroid Build Coastguard Worker        x1_dtype = x1.dtype
852*da0073e9SAndroid Build Coastguard Worker        x1.copy_(x2)
853*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x1_dtype, x1.dtype)
854*da0073e9SAndroid Build Coastguard Worker
855*da0073e9SAndroid Build Coastguard Worker        # test no broadcast
856*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: x1.copy_(x2.narrow_copy(0, 0, 1)))
857*da0073e9SAndroid Build Coastguard Worker
858*da0073e9SAndroid Build Coastguard Worker        # test raise error on copy_() between dense and sparse Tensors
859*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: x1.copy_(torch.randn(5, 5)))
860*da0073e9SAndroid Build Coastguard Worker
861*da0073e9SAndroid Build Coastguard Worker        # test autograd
862*da0073e9SAndroid Build Coastguard Worker        x1, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
863*da0073e9SAndroid Build Coastguard Worker        x2, _, _ = self._gen_sparse(sparse_dims, nnz + 10, sizes, dtype, device, coalesced)
864*da0073e9SAndroid Build Coastguard Worker        x2.requires_grad_(True)
865*da0073e9SAndroid Build Coastguard Worker        x1.copy_(x2)
866*da0073e9SAndroid Build Coastguard Worker        y = x1 * 2
867*da0073e9SAndroid Build Coastguard Worker        x2_clone = x2.clone()
868*da0073e9SAndroid Build Coastguard Worker        y.backward(x2_clone)
869*da0073e9SAndroid Build Coastguard Worker        expected_grad = x2_clone * 2
870*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_grad.to_dense(), x2.grad.to_dense())
871*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(None, x1.grad)
872*da0073e9SAndroid Build Coastguard Worker
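    # Minimal illustrative sketch (helper name is made up): copy_ between
    # sparse tensors converts the values to the destination's dtype, so the
    # destination keeps its dtype, as asserted above.
    @staticmethod
    def _sketch_copy_keeps_dtype():
        a = torch.tensor([[1., 0.], [0., 2.]]).to_sparse()  # float32
        b = torch.tensor([[0., 3.], [4., 0.]], dtype=torch.float64).to_sparse()
        a.copy_(b)
        assert a.dtype == torch.float32
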
873*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
874*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
875*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
876*da0073e9SAndroid Build Coastguard Worker    def test_Sparse_to_Sparse_copy_multi_gpu(self, device, dtype, coalesced):
877*da0073e9SAndroid Build Coastguard Worker        # This is for testing torch.copy_(SparseTensor, SparseTensor) across GPU devices
878*da0073e9SAndroid Build Coastguard Worker        sparse_dims = 3
879*da0073e9SAndroid Build Coastguard Worker        nnz = 10
880*da0073e9SAndroid Build Coastguard Worker        sizes = [2, 3, 4, 5]  # hybrid sparse
881*da0073e9SAndroid Build Coastguard Worker        x1, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
882*da0073e9SAndroid Build Coastguard Worker        x2, _, _ = self._gen_sparse(sparse_dims, nnz + 10, sizes, dtype, device, coalesced)
883*da0073e9SAndroid Build Coastguard Worker        x1 = x1.to('cuda:0')
884*da0073e9SAndroid Build Coastguard Worker
885*da0073e9SAndroid Build Coastguard Worker        def test_cross_device(x1, x2):
886*da0073e9SAndroid Build Coastguard Worker            x1_device = x1.device
887*da0073e9SAndroid Build Coastguard Worker            x1.copy_(x2)
888*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x2.to('cuda:0').to_dense(), x1.to_dense())
889*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x1_device, x1.device)
890*da0073e9SAndroid Build Coastguard Worker
891*da0073e9SAndroid Build Coastguard Worker        test_cross_device(x1, x2.to('cuda:1'))  # test across gpu devices
892*da0073e9SAndroid Build Coastguard Worker        test_cross_device(x1, x2.to('cpu'))  # test between cpu and gpu
893*da0073e9SAndroid Build Coastguard Worker
894*da0073e9SAndroid Build Coastguard Worker        # test autograd
895*da0073e9SAndroid Build Coastguard Worker        x2 = x2.to('cuda:1')
896*da0073e9SAndroid Build Coastguard Worker        x2.requires_grad_(True)
897*da0073e9SAndroid Build Coastguard Worker        x1.copy_(x2)
898*da0073e9SAndroid Build Coastguard Worker        y = x1 * 2
899*da0073e9SAndroid Build Coastguard Worker        x2_clone = x2.clone().to('cuda:0')
900*da0073e9SAndroid Build Coastguard Worker        y.backward(x2_clone)
901*da0073e9SAndroid Build Coastguard Worker        expected_grad = x2_clone * 2
902*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_grad.to_dense(), x2.grad.to('cuda:0').to_dense())
903*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(None, x1.grad)
904*da0073e9SAndroid Build Coastguard Worker
905*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
906*da0073e9SAndroid Build Coastguard Worker    def test_cuda_empty(self, device):
907*da0073e9SAndroid Build Coastguard Worker        def test_tensor(x):
908*da0073e9SAndroid Build Coastguard Worker            y = x.to(device)
909*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.sparse_dim(), y.sparse_dim())
910*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.dense_dim(), y.dense_dim())
911*da0073e9SAndroid Build Coastguard Worker            x = y.cpu()
912*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y.sparse_dim(), x.sparse_dim())
913*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y.dense_dim(), x.dense_dim())
914*da0073e9SAndroid Build Coastguard Worker
915*da0073e9SAndroid Build Coastguard Worker        x = torch.sparse_coo_tensor((2, 3, 4), dtype=torch.float32)
916*da0073e9SAndroid Build Coastguard Worker        test_tensor(x)
917*da0073e9SAndroid Build Coastguard Worker
918*da0073e9SAndroid Build Coastguard Worker        x = torch.sparse_coo_tensor((2, 3, 4), dtype=torch.float16)
919*da0073e9SAndroid Build Coastguard Worker        test_tensor(x)
920*da0073e9SAndroid Build Coastguard Worker
921*da0073e9SAndroid Build Coastguard Worker        x = torch.sparse_coo_tensor((2, 3, 4), dtype=torch.bfloat16)
922*da0073e9SAndroid Build Coastguard Worker        test_tensor(x)
923*da0073e9SAndroid Build Coastguard Worker
924*da0073e9SAndroid Build Coastguard Worker        x = torch.sparse_coo_tensor((2, 3, 4, 0), dtype=torch.float32)
925*da0073e9SAndroid Build Coastguard Worker        test_tensor(x)
926*da0073e9SAndroid Build Coastguard Worker
927*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
928*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
929*da0073e9SAndroid Build Coastguard Worker    def test_transpose(self, device, dtype, coalesced):
930*da0073e9SAndroid Build Coastguard Worker        def test_shape(sparse_dims, nnz, with_size):
931*da0073e9SAndroid Build Coastguard Worker            x = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
932*da0073e9SAndroid Build Coastguard Worker            y = self.safeToDense(x)
933*da0073e9SAndroid Build Coastguard Worker
934*da0073e9SAndroid Build Coastguard Worker            for i, j in itertools.combinations(range(4), 2):
935*da0073e9SAndroid Build Coastguard Worker                x = x.transpose_(i, j)
936*da0073e9SAndroid Build Coastguard Worker                y = y.transpose(i, j)
937*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(self.safeToDense(x), y)
938*da0073e9SAndroid Build Coastguard Worker
939*da0073e9SAndroid Build Coastguard Worker                x = x.transpose(i, j)
940*da0073e9SAndroid Build Coastguard Worker                y = y.transpose(i, j)
941*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(self.safeToDense(x), y)
942*da0073e9SAndroid Build Coastguard Worker
943*da0073e9SAndroid Build Coastguard Worker        test_shape(4, 6, 3)
944*da0073e9SAndroid Build Coastguard Worker        test_shape(4, 3, [7, 7, 7, 3, 3, 3, 0])
945*da0073e9SAndroid Build Coastguard Worker        test_shape(4, 0, [0, 0, 7, 3, 3, 3, 0])
946*da0073e9SAndroid Build Coastguard Worker
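    # Minimal illustrative sketch (helper name is made up): transposing two
    # sparse dims only swaps the corresponding rows of indices(), which the
    # dense comparison above verifies end to end.
    @staticmethod
    def _sketch_transpose():
        s = torch.tensor([[0., 1.], [2., 0.]]).to_sparse()
        assert torch.equal(s.transpose(0, 1).to_dense(),
                           torch.tensor([[0., 2.], [1., 0.]]))
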
947*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
948*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
949*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_CROSSREF, "unsupported generator triggers assertion error")
950*da0073e9SAndroid Build Coastguard Worker    @gradcheck_semantics()
951*da0073e9SAndroid Build Coastguard Worker    def test_permute(self, device, dtype, coalesced, gradcheck):
952*da0073e9SAndroid Build Coastguard Worker        # trivial checks
953*da0073e9SAndroid Build Coastguard Worker        s = torch.rand(3, 3, 3, device=device, dtype=dtype).to_sparse()
954*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "does not match the length"):
955*da0073e9SAndroid Build Coastguard Worker            s.permute(dims=(1, 0))
956*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "duplicate dims"):
957*da0073e9SAndroid Build Coastguard Worker            s.permute(dims=(1, 1, 1))
958*da0073e9SAndroid Build Coastguard Worker        # Calling permute on a sparse tensor with an empty tuple used to segfault,
959*da0073e9SAndroid Build Coastguard Worker        # see https://github.com/pytorch/pytorch/issues/116325
960*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((), device=device, dtype=dtype).to_sparse()
961*da0073e9SAndroid Build Coastguard Worker        x.permute(())
962*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(x.values()), 1)
963*da0073e9SAndroid Build Coastguard Worker
964*da0073e9SAndroid Build Coastguard Worker        def test_shape(sparse_dims, nnz, with_size):
965*da0073e9SAndroid Build Coastguard Worker            ndim = len(with_size)
966*da0073e9SAndroid Build Coastguard Worker            valid_sparse_dims = torch.arange(-ndim, -ndim + sparse_dims)
967*da0073e9SAndroid Build Coastguard Worker            valid_dense_dims = torch.arange(-ndim + sparse_dims, 0)
968*da0073e9SAndroid Build Coastguard Worker
969*da0073e9SAndroid Build Coastguard Worker            for dims in itertools.permutations(range(-ndim, 0)):
970*da0073e9SAndroid Build Coastguard Worker                s = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
971*da0073e9SAndroid Build Coastguard Worker                d = self.safeToDense(s)
972*da0073e9SAndroid Build Coastguard Worker
973*da0073e9SAndroid Build Coastguard Worker                dims_sparse, _ = torch.tensor(dims[:sparse_dims]).sort()
974*da0073e9SAndroid Build Coastguard Worker                dims_dense, _ = torch.tensor(dims[sparse_dims:]).sort()
975*da0073e9SAndroid Build Coastguard Worker
976*da0073e9SAndroid Build Coastguard Worker                if (valid_sparse_dims == dims_sparse).all() and (valid_dense_dims == dims_dense).all():
977*da0073e9SAndroid Build Coastguard Worker                    # if valid permutation, test for correctness
978*da0073e9SAndroid Build Coastguard Worker                    s_permuted = s.permute(dims)
979*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(s_permuted, d.permute(dims))
980*da0073e9SAndroid Build Coastguard Worker
981*da0073e9SAndroid Build Coastguard Worker                    # if s is coalesced, and perm does not touch 0-dim,
982*da0073e9SAndroid Build Coastguard Worker                    # the result has to be coalesced as well
983*da0073e9SAndroid Build Coastguard Worker                    if dims[0] == 0:
984*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(s_permuted.is_coalesced(), s.is_coalesced())
985*da0073e9SAndroid Build Coastguard Worker                    else:
986*da0073e9SAndroid Build Coastguard Worker                        self.assertFalse(s_permuted.is_coalesced())
987*da0073e9SAndroid Build Coastguard Worker
988*da0073e9SAndroid Build Coastguard Worker                    gradcheck(lambda t: t.permute(dims).to_dense(masked_grad=gradcheck.masked), s.requires_grad_())
989*da0073e9SAndroid Build Coastguard Worker                else:
990*da0073e9SAndroid Build Coastguard Worker                    # otherwise check if exception is thrown
991*da0073e9SAndroid Build Coastguard Worker                    fail_message = "transpositions between sparse and dense dimensions are not allowed"
992*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, fail_message):
993*da0073e9SAndroid Build Coastguard Worker                        s.permute(dims)
994*da0073e9SAndroid Build Coastguard Worker
995*da0073e9SAndroid Build Coastguard Worker        test_shape(2, 3, [2, 3, 4, 5])
996*da0073e9SAndroid Build Coastguard Worker        test_shape(2, 3, [2, 2, 0])
997*da0073e9SAndroid Build Coastguard Worker        # if nnz=0, it is not true that t == t.to_dense().to_sparse()
998*da0073e9SAndroid Build Coastguard Worker        # unless t.sparse_dim == t.dim (i.e. t is not hybrid)
999*da0073e9SAndroid Build Coastguard Worker        test_shape(3, 0, [0, 0, 2])
1000*da0073e9SAndroid Build Coastguard Worker
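    # Minimal illustrative sketch (helper name is made up): permutations may
    # reorder dims within the sparse block and within the dense block, but
    # moving a dense dim across the sparse/dense boundary raises, as the
    # fail_message branch above expects.
    @staticmethod
    def _sketch_permute_sparse_dense_boundary():
        i = torch.tensor([[0]])
        v = torch.tensor([[1., 2.]])
        s = torch.sparse_coo_tensor(i, v, (2, 2))  # 1 sparse dim + 1 dense dim
        try:
            s.permute((1, 0))                      # dense dim before sparse dim
        except RuntimeError:
            pass  # transpositions between sparse and dense dims are not allowed
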
1001*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
1002*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1003*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
1004*da0073e9SAndroid Build Coastguard Worker    def test_coalesce_transpose_mm(self, device, dtype, coalesced):
1005*da0073e9SAndroid Build Coastguard Worker        def test_shape(di, dj, dk, nnz):
1006*da0073e9SAndroid Build Coastguard Worker            x, _, _ = self._gen_sparse(2, nnz, [dj, di], dtype, device, coalesced)
1007*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(dj, dk, dtype=dtype, device=device)
1008*da0073e9SAndroid Build Coastguard Worker
1009*da0073e9SAndroid Build Coastguard Worker            x_coalesced = x.coalesce()
1010*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x_coalesced.is_coalesced())
1011*da0073e9SAndroid Build Coastguard Worker
1012*da0073e9SAndroid Build Coastguard Worker            x_coalesced_t = x_coalesced.t()
1013*da0073e9SAndroid Build Coastguard Worker            # Transpose is `coalesced`-preserving if the indices tensor is empty.
1014*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_coalesced_t.is_coalesced(), di * nnz == 0)
1015*da0073e9SAndroid Build Coastguard Worker
1016*da0073e9SAndroid Build Coastguard Worker            res = torch.mm(x_coalesced_t, y)
1017*da0073e9SAndroid Build Coastguard Worker            expected = torch.mm(self.safeToDense(x_coalesced_t), y)
1018*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, expected)
1019*da0073e9SAndroid Build Coastguard Worker
1020*da0073e9SAndroid Build Coastguard Worker        test_shape(10, 20, 30, 20)
1021*da0073e9SAndroid Build Coastguard Worker        test_shape(0, 20, 30, 0)
1022*da0073e9SAndroid Build Coastguard Worker        test_shape(10, 0, 30, 0)
1023*da0073e9SAndroid Build Coastguard Worker        test_shape(10, 20, 0, 0)
1024*da0073e9SAndroid Build Coastguard Worker        test_shape(10, 20, 0, 20)
1025*da0073e9SAndroid Build Coastguard Worker
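    # Minimal illustrative sketch (helper name is made up): t() on a nonempty
    # coalesced tensor swaps the index rows without re-sorting them, so the
    # result is marked uncoalesced, matching the assertion above.
    @staticmethod
    def _sketch_transpose_drops_coalesced():
        s = torch.tensor([[0., 1.], [2., 0.]]).to_sparse().coalesce()
        assert s.is_coalesced()
        assert not s.t().is_coalesced()
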
1026*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1166")
1027*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
1028*da0073e9SAndroid Build Coastguard Worker    def test_t_empty(self, device, dtype):
1029*da0073e9SAndroid Build Coastguard Worker        def test_in_place(x):
1030*da0073e9SAndroid Build Coastguard Worker            shape_original = x.shape
1031*da0073e9SAndroid Build Coastguard Worker            x.t_()
1032*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.Size([shape_original[1], shape_original[0]]), x.size())
1033*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(0, x._indices().numel())
1034*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(0, x._values().numel())
1035*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.sparse_dim(), 2)
1036*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.dense_dim(), 0)
1037*da0073e9SAndroid Build Coastguard Worker
1038*da0073e9SAndroid Build Coastguard Worker        def test_not_in_place(x):
1039*da0073e9SAndroid Build Coastguard Worker            shape_original = x.shape
1040*da0073e9SAndroid Build Coastguard Worker            y = x.t()
1041*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.Size([shape_original[1], shape_original[0]]), y.size())
1042*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(0, y._indices().numel())
1043*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(0, y._values().numel())
1044*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.sparse_dim(), 2)
1045*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.dense_dim(), 0)
1046*da0073e9SAndroid Build Coastguard Worker
1047*da0073e9SAndroid Build Coastguard Worker        x = self.sparse_empty(2, 3, dtype=dtype, device=device)
1048*da0073e9SAndroid Build Coastguard Worker        test_in_place(x)
1049*da0073e9SAndroid Build Coastguard Worker        test_not_in_place(x)
1050*da0073e9SAndroid Build Coastguard Worker
1051*da0073e9SAndroid Build Coastguard Worker        x = self.sparse_empty(2, 0, dtype=dtype, device=device)
1052*da0073e9SAndroid Build Coastguard Worker        test_in_place(x)
1053*da0073e9SAndroid Build Coastguard Worker        test_not_in_place(x)
1054*da0073e9SAndroid Build Coastguard Worker
1055*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
1056*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
1057*da0073e9SAndroid Build Coastguard Worker    def test_add_zeros(self, device, dtype, coalesced):
1058*da0073e9SAndroid Build Coastguard Worker        def test_shape(sparse_dims, nnz, sizes):
1059*da0073e9SAndroid Build Coastguard Worker            x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
1060*da0073e9SAndroid Build Coastguard Worker            zeros = torch.sparse_coo_tensor(sizes, device=x.device)
1061*da0073e9SAndroid Build Coastguard Worker            r1 = zeros + x
1062*da0073e9SAndroid Build Coastguard Worker            r2 = x + zeros
1063*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(r1, x)
1064*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(r2, x)
1065*da0073e9SAndroid Build Coastguard Worker
1066*da0073e9SAndroid Build Coastguard Worker        test_shape(1, 20, [1])
1067*da0073e9SAndroid Build Coastguard Worker        test_shape(4, 20, [3, 17, 19, 5])
1068*da0073e9SAndroid Build Coastguard Worker        test_shape(2, 20, [3, 17, 19, 5])
1069*da0073e9SAndroid Build Coastguard Worker        test_shape(2, 20, [3, 17, 19, 0])
1070*da0073e9SAndroid Build Coastguard Worker
1071*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
1072*da0073e9SAndroid Build Coastguard Worker    def test_add_sub_nnz(self, device, dtype):
1073*da0073e9SAndroid Build Coastguard Worker        # nnz should not grow unbounded (gh-34964)
1074*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, dtype=dtype, device=device).to_sparse()
1075*da0073e9SAndroid Build Coastguard Worker        x.add_(x)
1076*da0073e9SAndroid Build Coastguard Worker        x.add_(x)
1077*da0073e9SAndroid Build Coastguard Worker        self.assertLessEqual(x._nnz(), 10)
1078*da0073e9SAndroid Build Coastguard Worker
1079*da0073e9SAndroid Build Coastguard Worker        x.sub_(2 * x)
1080*da0073e9SAndroid Build Coastguard Worker        x.sub_(2 * x)
1081*da0073e9SAndroid Build Coastguard Worker        self.assertLessEqual(x._nnz(), 10)
1082*da0073e9SAndroid Build Coastguard Worker
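    # Minimal illustrative sketch (helper name is made up): in-place add of a
    # sparse tensor with itself should keep nnz bounded by the number of
    # specified elements instead of doubling on every call (gh-34964).
    @staticmethod
    def _sketch_add_nnz_bounded():
        x = torch.ones(4).to_sparse()
        x.add_(x)
        x.add_(x)
        assert x._nnz() <= 4
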
1083*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
1084*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
1085*da0073e9SAndroid Build Coastguard Worker    def test_cat(self, device, dtype, coalesced):
1086*da0073e9SAndroid Build Coastguard Worker        # shapes: list of tuples (sparse_dims, nnz, sizes)
1087*da0073e9SAndroid Build Coastguard Worker        def test_shapes(shapes, dim, fail_message=None):
1088*da0073e9SAndroid Build Coastguard Worker            inputs = [self._gen_sparse(shape[0], shape[1], shape[2], dtype, device, coalesced)[0]
1089*da0073e9SAndroid Build Coastguard Worker                      for shape in shapes]
1090*da0073e9SAndroid Build Coastguard Worker            if fail_message:
1091*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, fail_message):
1092*da0073e9SAndroid Build Coastguard Worker                    torch.cat(inputs, dim)
1093*da0073e9SAndroid Build Coastguard Worker            else:
1094*da0073e9SAndroid Build Coastguard Worker                result = torch.cat(inputs, dim)
1095*da0073e9SAndroid Build Coastguard Worker                dense_result = torch.cat([t.to_dense() for t in inputs], dim)
1096*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(dense_result, result.to_dense())
1097*da0073e9SAndroid Build Coastguard Worker
1098*da0073e9SAndroid Build Coastguard Worker        test_shapes(
1099*da0073e9SAndroid Build Coastguard Worker            [(3, 10, [2, 3, 4]), (3, 10, [2, 1, 4]), (3, 10, [2, 4, 4])], 1)
1100*da0073e9SAndroid Build Coastguard Worker
1101*da0073e9SAndroid Build Coastguard Worker        # mismatched sizes
1102*da0073e9SAndroid Build Coastguard Worker        test_shapes([(3, 10, [2, 3, 4]), (3, 10, [2, 1, 4])], 0,
1103*da0073e9SAndroid Build Coastguard Worker                    "All tensors must have the same shape: \\[2, 3, 4].*\\[2, 1, 4]")
1104*da0073e9SAndroid Build Coastguard Worker        # hybrid sparse/dense
1105*da0073e9SAndroid Build Coastguard Worker        test_shapes(
1106*da0073e9SAndroid Build Coastguard Worker            [(2, 10, [2, 3, 4]), (2, 10, [2, 1, 4]), (2, 10, [2, 4, 4])], 1)
1107*da0073e9SAndroid Build Coastguard Worker        # cat along dense dim
1108*da0073e9SAndroid Build Coastguard Worker        test_shapes([(2, 10, [2, 3, 4]), (2, 10, [2, 3, 7])], 2)
1109*da0073e9SAndroid Build Coastguard Worker        test_shapes([(1, 10, [2, 3, 4]), (1, 10, [2, 3, 4])], 1)
1110*da0073e9SAndroid Build Coastguard Worker        test_shapes([(1, 10, [2, 3, 4]), (1, 10, [2, 3, 4])], 2)
1111*da0073e9SAndroid Build Coastguard Worker        # mismatched dimensions
1112*da0073e9SAndroid Build Coastguard Worker        test_shapes([(2, 10, [2, 3, 4]), (3, 10, [2, 3, 4])], 0,
1113*da0073e9SAndroid Build Coastguard Worker                    "All tensors must have the same.*2, 1, but tensor at position 1 has 3, 0.")
1114*da0073e9SAndroid Build Coastguard Worker        # wrapped dimension
1115*da0073e9SAndroid Build Coastguard Worker        test_shapes(
1116*da0073e9SAndroid Build Coastguard Worker            [(3, 10, [2, 3, 4]), (3, 10, [2, 1, 4]), (3, 10, [2, 4, 4])], -2)
1117*da0073e9SAndroid Build Coastguard Worker
1118*da0073e9SAndroid Build Coastguard Worker        # sparse with dense
1119*da0073e9SAndroid Build Coastguard Worker        sp = self._gen_sparse(3, 10, [2, 3, 4], dtype, device, coalesced)[0]
1120*da0073e9SAndroid Build Coastguard Worker        dn = sp.to_dense()
1121*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
1122*da0073e9SAndroid Build Coastguard Worker                                    "Concatenating sparse tensors, but a dense tensor was found at position 1."):
1123*da0073e9SAndroid Build Coastguard Worker            torch.cat((sp, dn))
1124*da0073e9SAndroid Build Coastguard Worker
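    # Minimal illustrative sketch (helper name is made up): torch.cat on
    # sparse COO inputs returns a sparse tensor whose dense form matches
    # concatenating the dense equivalents, as checked above.
    @staticmethod
    def _sketch_cat():
        a = torch.eye(2).to_sparse()
        b = torch.eye(2).to_sparse()
        c = torch.cat([a, b], dim=0)
        assert c.is_sparse and c.shape == (4, 2)
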
1125*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
1126*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
1127*da0073e9SAndroid Build Coastguard Worker    def test_unsqueeze(self, device, dtype, coalesced):
1128*da0073e9SAndroid Build Coastguard Worker        def test_shape(sparse_dims, nnz, sizes, unsqueeze_dim, fail_message=None):
1129*da0073e9SAndroid Build Coastguard Worker            x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
1130*da0073e9SAndroid Build Coastguard Worker            if fail_message:
1131*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(IndexError, fail_message):
1132*da0073e9SAndroid Build Coastguard Worker                    torch.unsqueeze(x, unsqueeze_dim)
1133*da0073e9SAndroid Build Coastguard Worker            else:
1134*da0073e9SAndroid Build Coastguard Worker                result = torch.unsqueeze(x, unsqueeze_dim)
1135*da0073e9SAndroid Build Coastguard Worker                dense_result = torch.unsqueeze(x.to_dense(), unsqueeze_dim)
1136*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(dense_result, result.to_dense())
1137*da0073e9SAndroid Build Coastguard Worker
1138*da0073e9SAndroid Build Coastguard Worker        # basic case
1139*da0073e9SAndroid Build Coastguard Worker        test_shape(3, 10, [5, 7, 11], 0)
1140*da0073e9SAndroid Build Coastguard Worker
1141*da0073e9SAndroid Build Coastguard Worker        # hybrid sparse/dense, unsqueeze along sparse dim
1142*da0073e9SAndroid Build Coastguard Worker        test_shape(3, 10, [5, 7, 11, 13, 17], 0)
1143*da0073e9SAndroid Build Coastguard Worker        test_shape(3, 10, [5, 7, 11, 13, 17], 3)
1144*da0073e9SAndroid Build Coastguard Worker
1145*da0073e9SAndroid Build Coastguard Worker        # unsqueeze along dense dimensions
1146*da0073e9SAndroid Build Coastguard Worker        test_shape(3, 10, [5, 7, 11, 13, 17], 4)
1147*da0073e9SAndroid Build Coastguard Worker        test_shape(3, 10, [5, 7, 11, 13, 17], 5)
1148*da0073e9SAndroid Build Coastguard Worker
1149*da0073e9SAndroid Build Coastguard Worker        # wrapped dimensions
1150*da0073e9SAndroid Build Coastguard Worker        test_shape(3, 10, [5, 7, 11, 13, 17], -1)
1151*da0073e9SAndroid Build Coastguard Worker        test_shape(3, 10, [5, 7, 11, 13, 17], -6)
1152*da0073e9SAndroid Build Coastguard Worker
1153*da0073e9SAndroid Build Coastguard Worker        # bounds
1154*da0073e9SAndroid Build Coastguard Worker        test_shape(3, 10, [5, 7, 11, 13, 17], -7, "Dimension out of range")
1155*da0073e9SAndroid Build Coastguard Worker        test_shape(3, 10, [5, 7, 11, 13, 17], 6, "Dimension out of range")
1156*da0073e9SAndroid Build Coastguard Worker
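    # Minimal illustrative sketch (helper name is made up): unsqueeze inserts
    # a size-1 dim; within the sparse block the new dim is sparse, and the
    # result matches the dense equivalent, as the test above checks.
    @staticmethod
    def _sketch_unsqueeze():
        s = torch.eye(2).to_sparse()       # sparse_dim=2
        u = s.unsqueeze(0)                 # shape (1, 2, 2)
        assert u.shape == (1, 2, 2) and u.sparse_dim() == 3
        assert torch.equal(u.to_dense(), s.to_dense().unsqueeze(0))
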
1157*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
1158*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
1159*da0073e9SAndroid Build Coastguard Worker    def test_select(self, device, dtype, coalesced):
1160*da0073e9SAndroid Build Coastguard Worker        def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=None):
1161*da0073e9SAndroid Build Coastguard Worker            x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
1162*da0073e9SAndroid Build Coastguard Worker            if fail_message:
1163*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(IndexError, fail_message):
1164*da0073e9SAndroid Build Coastguard Worker                    torch.select(x, select_dim, select_index)
1165*da0073e9SAndroid Build Coastguard Worker            else:
1166*da0073e9SAndroid Build Coastguard Worker                result = torch.select(x, select_dim, select_index)
1167*da0073e9SAndroid Build Coastguard Worker                if result.is_sparse:
1168*da0073e9SAndroid Build Coastguard Worker                    result = result.to_dense()
1169*da0073e9SAndroid Build Coastguard Worker                dense_result = torch.select(x.to_dense(), select_dim, select_index)
1170*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(dense_result, result)
1171*da0073e9SAndroid Build Coastguard Worker
1173*da0073e9SAndroid Build Coastguard Worker        sizes = [5, 7, 11, 13, 17]
1174*da0073e9SAndroid Build Coastguard Worker        # hybrid sparse/dense, select sparse dim, result is dense
1175*da0073e9SAndroid Build Coastguard Worker        for i in range(sizes[0]):
1176*da0073e9SAndroid Build Coastguard Worker            test_shape(1, 10, sizes, 0, i)
1177*da0073e9SAndroid Build Coastguard Worker        test_shape(1, 10, sizes, 0, sizes[0] + 1, r'select[(][)][:] index \d out of range.*')
1178*da0073e9SAndroid Build Coastguard Worker
1179*da0073e9SAndroid Build Coastguard Worker        # hybrid sparse/dense, select sparse dim, result is sparse
1180*da0073e9SAndroid Build Coastguard Worker        for d in range(3):
1181*da0073e9SAndroid Build Coastguard Worker            for i in range(sizes[d]):
1182*da0073e9SAndroid Build Coastguard Worker                test_shape(3, 10, sizes, d, i)
1183*da0073e9SAndroid Build Coastguard Worker
1184*da0073e9SAndroid Build Coastguard Worker        # hybrid sparse/dense, select dense dim, result is sparse
1185*da0073e9SAndroid Build Coastguard Worker        for d in range(1, 3):
1186*da0073e9SAndroid Build Coastguard Worker            for i in range(sizes[d]):
1187*da0073e9SAndroid Build Coastguard Worker                test_shape(1, 10, sizes, d, i)
1188*da0073e9SAndroid Build Coastguard Worker
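    # Minimal illustrative sketch (helper name is made up): select() along a
    # sparse dim of a 2-D sparse tensor returns a 1-D row that matches the
    # dense equivalent, mirroring the dense comparison above.
    @staticmethod
    def _sketch_select():
        s = torch.tensor([[1., 0.], [0., 2.]]).to_sparse()
        row = s.select(0, 1)
        assert torch.equal(row.to_dense(), torch.tensor([0., 2.]))
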
1189*da0073e9SAndroid Build Coastguard Worker    @dtypes(*integral_types())
1190*da0073e9SAndroid Build Coastguard Worker    def test_select_no_type_promotion(self, device, dtype):
1191*da0073e9SAndroid Build Coastguard Worker        # see https://github.com/pytorch/pytorch/issues/82150
1192*da0073e9SAndroid Build Coastguard Worker        idx = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 1]])
1193*da0073e9SAndroid Build Coastguard Worker        val = torch.ones(6, dtype=dtype)
1194*da0073e9SAndroid Build Coastguard Worker        s = torch.sparse_coo_tensor(idx, val, size=(3, 3))
1195*da0073e9SAndroid Build Coastguard Worker
1196*da0073e9SAndroid Build Coastguard Worker        for t in (s, s * torch.tensor(0, dtype=dtype)):
1197*da0073e9SAndroid Build Coastguard Worker            # empty checks
1198*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t.dtype, t[2].dtype)
1199*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t.dtype, t[0, 1].dtype)
1200*da0073e9SAndroid Build Coastguard Worker            # sum should not promote
1201*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t.dtype, t[0, 0].dtype)
1202*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t.dtype, t[1, 1].dtype)
1203*da0073e9SAndroid Build Coastguard Worker
1204*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
1205*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
1206*da0073e9SAndroid Build Coastguard Worker    def test_index_select(self, device, dtype, coalesced):
1207*da0073e9SAndroid Build Coastguard Worker        def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=None):
1208*da0073e9SAndroid Build Coastguard Worker            if isinstance(select_index, int):
1209*da0073e9SAndroid Build Coastguard Worker                select_index = [select_index]
1210*da0073e9SAndroid Build Coastguard Worker            if isinstance(select_index, list):
1211*da0073e9SAndroid Build Coastguard Worker                select_index = torch.tensor(select_index, device=device, dtype=torch.long)
1212*da0073e9SAndroid Build Coastguard Worker            x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
1213*da0073e9SAndroid Build Coastguard Worker            if fail_message:
1214*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(IndexError, fail_message):
1215*da0073e9SAndroid Build Coastguard Worker                    torch.index_select(x, select_dim, select_index)
1216*da0073e9SAndroid Build Coastguard Worker            else:
1217*da0073e9SAndroid Build Coastguard Worker                result = torch.index_select(x, select_dim, select_index)
1218*da0073e9SAndroid Build Coastguard Worker                if result.is_sparse:
1219*da0073e9SAndroid Build Coastguard Worker                    result = result.to_dense()
1220*da0073e9SAndroid Build Coastguard Worker                dense_result = torch.index_select(x.to_dense(), select_dim, select_index)
1221*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(dense_result, result)
1222*da0073e9SAndroid Build Coastguard Worker
1223*da0073e9SAndroid Build Coastguard Worker        sizes = [5, 7, 11, 13, 17]
1224*da0073e9SAndroid Build Coastguard Worker        for d in range(len(sizes)):
1225*da0073e9SAndroid Build Coastguard Worker            for index in [0, sizes[d] - 1, [0, sizes[d] // 2, sizes[d] - 1]]:
1226*da0073e9SAndroid Build Coastguard Worker                test_shape(1, 10, sizes, d, index)
1227*da0073e9SAndroid Build Coastguard Worker                test_shape(len(sizes) // 2, 10, sizes, d, index)
1228*da0073e9SAndroid Build Coastguard Worker                test_shape(len(sizes), 10, sizes, d, index)
1229*da0073e9SAndroid Build Coastguard Worker
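    # Minimal illustrative sketch (helper name is made up): index_select on a
    # sparse tensor gathers rows by index, repeats allowed, and matches the
    # dense result, which is the property the tests above exercise.
    @staticmethod
    def _sketch_index_select():
        s = torch.tensor([[1., 0.], [0., 2.], [3., 0.]]).to_sparse()
        idx = torch.tensor([2, 0, 2])
        assert torch.equal(s.index_select(0, idx).to_dense(),
                           s.to_dense().index_select(0, idx))
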
1230*da0073e9SAndroid Build Coastguard Worker    def _test_index_select_exhaustive_index(self, sizes, dims, device, dtype, coalesced):
1231*da0073e9SAndroid Build Coastguard Worker        t = make_tensor(sizes, dtype=dtype, device=device)
1232*da0073e9SAndroid Build Coastguard Worker        t_sparse = t.to_sparse().coalesce() if coalesced else t.to_sparse()
1233*da0073e9SAndroid Build Coastguard Worker        t_small_sparse, _, _ = self._gen_sparse(len(sizes), 2, sizes, dtype, device, coalesced)
1234*da0073e9SAndroid Build Coastguard Worker        t_small = t_small_sparse.to_dense()
1235*da0073e9SAndroid Build Coastguard Worker        for d in dims:
1236*da0073e9SAndroid Build Coastguard Worker            # NOTE: indices are negative
1237*da0073e9SAndroid Build Coastguard Worker            idx_dim_d_range = list(range(-sizes[d], 0))
1238*da0073e9SAndroid Build Coastguard Worker            for idx_len in range(sizes[d], sizes[d] + 1):
1239*da0073e9SAndroid Build Coastguard Worker                # creates all possible valid indices into dim d of length idx_len
1240*da0073e9SAndroid Build Coastguard Worker                for idx in itertools.product(*itertools.repeat(idx_dim_d_range, idx_len)):
1241*da0073e9SAndroid Build Coastguard Worker                    t_idx = torch.tensor(idx, dtype=torch.long, device=device)
1242*da0073e9SAndroid Build Coastguard Worker
1243*da0073e9SAndroid Build Coastguard Worker                    # NOTE: index_select for dense does not support negative indices,
1244*da0073e9SAndroid Build Coastguard Worker                    # hence + sizes[d]. See https://github.com/pytorch/pytorch/issues/76347
1245*da0073e9SAndroid Build Coastguard Worker
1246*da0073e9SAndroid Build Coastguard Worker                    # tests the nnz > sizes[d] branch
1247*da0073e9SAndroid Build Coastguard Worker                    dense_result = t.index_select(d, t_idx + sizes[d])
1248*da0073e9SAndroid Build Coastguard Worker                    sparse_result = t_sparse.index_select(d, t_idx)
1249*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(dense_result, sparse_result)
1250*da0073e9SAndroid Build Coastguard Worker
1251*da0073e9SAndroid Build Coastguard Worker                    # tests the nnz <= sizes[d] branch
1252*da0073e9SAndroid Build Coastguard Worker                    small_dense_result = t_small.index_select(d, t_idx + sizes[d])
1253*da0073e9SAndroid Build Coastguard Worker                    small_sparse_result = t_small_sparse.index_select(d, t_idx)
1254*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(small_dense_result, small_sparse_result)
1255*da0073e9SAndroid Build Coastguard Worker
1256*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
1257*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
1258*da0073e9SAndroid Build Coastguard Worker    def test_index_select_exhaustive_index_small(self, device, dtype, coalesced):
1259*da0073e9SAndroid Build Coastguard Worker        # will trigger brute-force algo
1260*da0073e9SAndroid Build Coastguard Worker        self._test_index_select_exhaustive_index((3, 3, 4), range(3), device, dtype, coalesced)
1261*da0073e9SAndroid Build Coastguard Worker
1262*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
1263*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
1264*da0073e9SAndroid Build Coastguard Worker    def test_index_select_exhaustive_index_large(self, device, dtype, coalesced):
1265*da0073e9SAndroid Build Coastguard Worker        # will trigger more sophisticated algos
1266*da0073e9SAndroid Build Coastguard Worker        self._test_index_select_exhaustive_index((100, 50, 3, 3), (2, 3), device, dtype, coalesced)
1267*da0073e9SAndroid Build Coastguard Worker
1268*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
1269*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
1270*da0073e9SAndroid Build Coastguard Worker    def test_index_select_empty_and_non_contiguous_index(self, device, dtype, coalesced):
1271*da0073e9SAndroid Build Coastguard Worker        # empty index
1272*da0073e9SAndroid Build Coastguard Worker        idx_empty = torch.tensor([], dtype=torch.long, device=device)
1273*da0073e9SAndroid Build Coastguard Worker        t = make_tensor((5, 5), dtype=dtype, device=device)
1274*da0073e9SAndroid Build Coastguard Worker        res_dense = t.index_select(0, idx_empty)
1275*da0073e9SAndroid Build Coastguard Worker        res_sparse = t.to_sparse().index_select(0, idx_empty)
1276*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_dense, res_sparse)
1277*da0073e9SAndroid Build Coastguard Worker
1278*da0073e9SAndroid Build Coastguard Worker        # non-contiguous index
1279*da0073e9SAndroid Build Coastguard Worker        idx = torch.randint(low=0, high=5, size=(10, 2), device=device)[:, 0]
1280*da0073e9SAndroid Build Coastguard Worker
1281*da0073e9SAndroid Build Coastguard Worker        def run_test(sizes):
1282*da0073e9SAndroid Build Coastguard Worker            # case nnz > size[d]
1283*da0073e9SAndroid Build Coastguard Worker            t = make_tensor(sizes, dtype=dtype, device=device)
1284*da0073e9SAndroid Build Coastguard Worker            res_dense = t.index_select(0, idx)
1285*da0073e9SAndroid Build Coastguard Worker            res_sparse = t.to_sparse().index_select(0, idx)
1286*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_dense, res_sparse)
1287*da0073e9SAndroid Build Coastguard Worker
1288*da0073e9SAndroid Build Coastguard Worker            # case nnz <= size[d]
1289*da0073e9SAndroid Build Coastguard Worker            t_small_sparse, _, _ = self._gen_sparse(len(sizes), 2, sizes, dtype, device, coalesced)
1290*da0073e9SAndroid Build Coastguard Worker            res_sparse = t_small_sparse.index_select(0, idx)
1291*da0073e9SAndroid Build Coastguard Worker            res_dense = t_small_sparse.to_dense().index_select(0, idx)
1292*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_dense, res_sparse)
1293*da0073e9SAndroid Build Coastguard Worker
1294*da0073e9SAndroid Build Coastguard Worker        # brute-force
1295*da0073e9SAndroid Build Coastguard Worker        run_test((10, 10))
1296*da0073e9SAndroid Build Coastguard Worker        # more sophisticated algos
1297*da0073e9SAndroid Build Coastguard Worker        run_test((10, 100, 100))
1298*da0073e9SAndroid Build Coastguard Worker
1299*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1300*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
1301*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
1302*da0073e9SAndroid Build Coastguard Worker    def test_index_select_parallelization(self, device, dtype, coalesced):
1303*da0073e9SAndroid Build Coastguard Worker        """
1304*da0073e9SAndroid Build Coastguard Worker        Test with sizes that will trigger parallelization (i.e. with sizes
1305*da0073e9SAndroid Build Coastguard Worker        that are >= at::internal::GRAIN_SIZE)
1306*da0073e9SAndroid Build Coastguard Worker        """
1307*da0073e9SAndroid Build Coastguard Worker        def run_test(nnz, size):
1308*da0073e9SAndroid Build Coastguard Worker            t_sparse, _, _ = self._gen_sparse(1, nnz, (size,), dtype, device, coalesced)
1309*da0073e9SAndroid Build Coastguard Worker            t_dense = t_sparse.to_dense()
1310*da0073e9SAndroid Build Coastguard Worker
1311*da0073e9SAndroid Build Coastguard Worker            # idx_small: the (sort and) binary search is done over t_sparse
1312*da0073e9SAndroid Build Coastguard Worker            idx_small = torch.randint(size, (nnz // 2,), device=device)
1313*da0073e9SAndroid Build Coastguard Worker            # idx_large: the (sort and) binary search is done over idx_large
1314*da0073e9SAndroid Build Coastguard Worker            # NOTE: when coalesced=True, the (binary) search will be
1315*da0073e9SAndroid Build Coastguard Worker            # done over t_sparse anyway, as it is already sorted.
1316*da0073e9SAndroid Build Coastguard Worker            idx_large = torch.randint(size, (nnz * 2,), device=device)
            for idx in (idx_small, idx_large):
                res_dense = t_dense.index_select(0, idx)
                res_sparse = t_sparse.index_select(0, idx)
                self.assertEqual(res_dense, res_sparse)

        # NOTE: GRAIN_SIZE = 32768
        # case nnz <= size[d]
        tlen = 70000  # > 2 * GRAIN_SIZE
        run_test(tlen, tlen)

        # case nnz > size[d]
        run_test(tlen, tlen // 2)

    @onlyCPU
    @coalescedonoff
    @dtypes(torch.double, torch.cdouble)
    def test_mm(self, device, dtype, coalesced):
        def test_shape(di, dj, dk, nnz):
            x, _, _ = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced)
            t = torch.randn(di, dk, dtype=dtype, device=device)
            y = torch.randn(dj, dk, dtype=dtype, device=device)
            alpha = random.random()
            beta = random.random()

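            # addmm computes beta * t + alpha * (x @ y); with a sparse x the
            # result must match the dense reference.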
            res = torch.addmm(t, x, y, beta=beta, alpha=alpha)
            expected = torch.addmm(t, self.safeToDense(x), y, beta=beta, alpha=alpha)
            self.assertEqual(res, expected)

            res = torch.addmm(t, x, y)
            expected = torch.addmm(t, self.safeToDense(x), y)
            self.assertEqual(res, expected)

            res = torch.mm(x, y)
            expected = torch.mm(self.safeToDense(x), y)
            self.assertEqual(res, expected)

        test_shape(10, 100, 100, 20)
        test_shape(100, 1000, 200, 20)
        test_shape(64, 10000, 300, 20)
        test_shape(0, 100, 100, 0)
        test_shape(10, 0, 100, 0)
        test_shape(10, 100, 0, 0)
        test_shape(10, 100, 0, 20)

    @unittest.skipIf(
        IS_WINDOWS and TEST_CUDA,
        "bmm sparse-dense CUDA is not yet supported on Windows, at least up to CUDA 10.1"
    )
    @coalescedonoff
    @dtypes(torch.double)
    def test_bmm(self, device, dtype, coalesced):
        def test_shape(num_mats, dim_i, dim_j, dim_k, nnz):
            a_list = []
            b_list = []
            for mat_idx in range(num_mats):
                a_mat = self._gen_sparse(2, nnz, [dim_i, dim_j], dtype, device, coalesced)[0]
                b_mat = torch.randn([dim_j, dim_k], dtype=dtype, device=device)
                a_list.append(a_mat)
                b_list.append(b_mat)

            a = torch.stack(a_list)
            b = torch.stack(b_list)
            ab = a.bmm(b)

            # Compare each matrix against result from mm()
            for mat_idx in range(num_mats):
                a_mat = a_list[mat_idx]
                b_mat = b_list[mat_idx]
                ab_mat_bmm = ab[mat_idx]
                ab_mat_mm = a_mat.mm(b_mat)
                self.assertEqual(ab_mat_bmm, ab_mat_mm)

        test_shape(10, 10, 100, 99, 20)
        test_shape(10, 100, 1000, 200, 20)
        test_shape(10, 64, 10000, 300, 20)
        test_shape(10, 0, 100, 99, 0)
        test_shape(10, 10, 0, 100, 0)
        test_shape(10, 10, 100, 0, 0)
        test_shape(10, 10, 100, 0, 20)
        test_shape(10, 10, 100, 0, 20)

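        # Exercise bmm when some batch entries are all zeros: after
        # to_sparse() those slices become empty sparse matrices.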
        a = torch.rand([10, 23, 32], dtype=dtype, device=device)
        a[3] = torch.zeros(23, 32, dtype=dtype, device=device)
        a[6] = torch.zeros(23, 32, dtype=dtype, device=device)
        a = a.to_sparse()
        b = torch.rand([10, 32, 10], dtype=dtype, device=device)
        b[4] = torch.zeros(32, 10, dtype=dtype, device=device)
        b[6] = torch.zeros(32, 10, dtype=dtype, device=device)
        ab = a.bmm(b)
        for mat_idx in range(ab.size(0)):
            ab_mat = ab[mat_idx]
            ab_mat_check = a[mat_idx].mm(b[mat_idx])
            self.assertEqual(ab_mat, ab_mat_check)

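        # Since (A @ B).mT == B.mT @ A.mT, the batched product can also be
        # verified via the transposed computation with the operand roles swapped.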
        ab_transpose_check = b.transpose(1, 2).to_sparse().bmm(
            a.transpose(1, 2).to_dense()
        ).transpose(1, 2)
        self.assertEqual(ab, ab_transpose_check)

    @onlyCUDA
    @coalescedonoff
    @dtypes(torch.double)
    @unittest.skipIf(
        IS_WINDOWS,
        "bmm sparse-dense CUDA is not yet supported on Windows, at least up to CUDA 10.1"
    )
    def test_bmm_deterministic(self, device, dtype, coalesced):
        def test_shape(num_mats, dim_i, dim_j, dim_k, nnz):
            a_list = []
            b_list = []
            for mat_idx in range(num_mats):
                a_list.append(self._gen_sparse(2, nnz, [dim_i, dim_j], dtype, device, coalesced)[0])
                b_list.append(torch.randn([dim_j, dim_k], dtype=dtype, device=device))

            a = torch.stack(a_list).cuda()
            b = torch.stack(b_list).cuda()
            with DeterministicGuard(torch.are_deterministic_algorithms_enabled()):
                torch.use_deterministic_algorithms(False)
                ab_nondeterministic = torch.bmm(a, b)
                torch.use_deterministic_algorithms(True)
                ab_deterministic = torch.bmm(a, b)
            diff_abs = (ab_deterministic - ab_nondeterministic).abs()
            diff_rel = diff_abs / ab_deterministic.abs()
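            # 0/0 yields NaN exactly where both results are zero; count those
            # entries as equal.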
            diff_rel[torch.isnan(diff_rel)] = 0

            # deterministic and non-deterministic results should either be
            # equal or within a small relative difference
            equal_abs_or_rel = diff_abs.eq(0).logical_or(diff_rel.lt(0.001))
            self.assertTrue(equal_abs_or_rel.all())

        test_shape(10, 10, 100, 99, 20)
        test_shape(10, 100, 1000, 200, 20)
        test_shape(10, 64, 10000, 300, 20)
        test_shape(10, 0, 100, 99, 0)
        test_shape(10, 10, 0, 100, 0)
        test_shape(10, 10, 100, 0, 0)
        test_shape(10, 10, 100, 0, 20)
        test_shape(10, 10, 100, 0, 20)

    @onlyCUDA
    @unittest.skipIf(
        not IS_WINDOWS or TEST_WITH_ROCM,
        "this test ensures bmm sparse-dense CUDA gives an error when run on Windows with CUDA < 11.0"
    )
    @dtypes(torch.double)
    def test_bmm_windows_error(self, device, dtype):
        a = torch.rand(2, 2, 2, dtype=dtype).to_sparse().cuda()
        b = torch.rand(2, 2, 2, dtype=dtype).cuda()
        with self.assertRaisesRegex(
                RuntimeError,
                "bmm sparse-dense CUDA is not supported on Windows with cuda before 11.0"):
            ab = a.bmm(b)

    @onlyCPU
    @coalescedonoff
    @dtypes(torch.double, torch.cdouble)
    def test_saddmm(self, device, dtype, coalesced):
        def test_shape(di, dj, dk, nnz):
            x = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced)[0]
            t = self._gen_sparse(2, nnz, [di, dk], dtype, device, coalesced)[0]
            y = torch.randn(dj, dk, dtype=dtype, device=device)
            alpha = random.random()
            beta = random.random()

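            # saddmm is the sparse-output analogue of addmm: it computes
            # beta * t + alpha * (x @ y) with sparse t and x and a dense y.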
            res = torch.saddmm(t, x, y, beta=beta, alpha=alpha)
            expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y, beta=beta, alpha=alpha)
            self.assertEqual(self.safeToDense(res), expected)

            res = torch.saddmm(t, x, y)
            expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y)
            self.assertEqual(self.safeToDense(res), expected)

            res = torch.smm(x, y)
            expected = torch.mm(self.safeToDense(x), y)
            self.assertEqual(self.safeToDense(res), expected)

        test_shape(7, 5, 3, 20)
        test_shape(1000, 100, 100, 20)
        test_shape(3000, 64, 300, 20)
        test_shape(0, 100, 100, 0)
        test_shape(1000, 0, 100, 0)
        test_shape(1000, 100, 0, 0)

    @onlyCPU
    @coalescedonoff
    # adding a graph break before self.assertFalse(weight._indices().is_contiguous())
    # makes the test pass, which points to an existing sparse-related bug
    @skipIfTorchDynamo("skip")
    @dtypes(torch.double, torch.cdouble)
    def test_sspaddmm(self, device, dtype, coalesced):

        def test_shape(di, dj, dk, nnz):
            x = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced)[0]
            t = self._gen_sparse(2, nnz, [di, dk], dtype, device, coalesced)[0]
            y = torch.randn(dj, dk, dtype=dtype, device=device)
            alpha = random.random()
            beta = random.random()

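            # sspaddmm computes beta * t + alpha * (x @ y) and returns a
            # sparse tensor; compare against the dense addmm reference.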
            res = t.sspaddmm(x, y, beta=beta, alpha=alpha)
            expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y, beta=beta, alpha=alpha)
            self.assertEqual(self.safeToDense(res), expected)

            res = t.sspaddmm(x, y)
            expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y)
            self.assertEqual(self.safeToDense(res), expected)

        test_shape(7, 5, 3, 20)
        test_shape(1000, 100, 100, 20)
        test_shape(3000, 64, 300, 20)
        test_shape(0, 100, 100, 0)
        test_shape(1000, 0, 100, 0)
        test_shape(1000, 100, 0, 0)

        # Test code from issue https://github.com/pytorch/pytorch/issues/45113
        batch_size, input_size, hidden_size = 5, 3, 7

        # Create coalesced sparse tensor with non-contiguous indices
        weight = torch.randn(hidden_size, input_size, dtype=dtype, device=device).to_sparse()
        self.assertTrue(weight.is_coalesced())
        non_contig_indices = weight.indices().mT.contiguous().mT
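        # The mT -> contiguous -> mT round-trip keeps the index values but
        # leaves the resulting view non-contiguous.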
        weight = torch.sparse_coo_tensor(
            indices=non_contig_indices, values=weight.values(), size=weight.shape)
        weight._coalesced_(True)
        self.assertFalse(weight._indices().is_contiguous())
        # Create un/coalesced sparse tensor
        bias = torch.randn((hidden_size, 1), dtype=dtype, device=device).to_sparse()
        bias = torch.cat([bias] * batch_size, dim=1)

        if coalesced:
            bias = bias.coalesce()

        x = torch.randn(input_size, batch_size, dtype=dtype, device=device)
        res = bias.sspaddmm(weight, x)

        true_result = (bias.to_dense() + torch.matmul(weight.to_dense(), x)).to_sparse()
        self.assertEqual(self.safeToDense(res), self.safeToDense(true_result))

    @coalescedonoff
    @precisionOverride({torch.bfloat16: 5e-2, torch.float16: 5e-2})
    @dtypes(torch.double, torch.cdouble, torch.bfloat16, torch.float16)
    def test_sparse_addmm(self, device, dtype, coalesced):
        if (dtype is torch.bfloat16 or dtype is torch.float16) and device.startswith("cuda"):
            self.skipTest('addmm_sparse_cuda is not implemented for BFloat16 and Half')

        def test_shape(m, n, p, nnz, broadcast, alpha_beta=None):
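            # When broadcast=True, D1 is a 0-dim tensor that broadcasts
            # against the [n, p] result of S @ D2.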
            if alpha_beta is None:
                alpha = random.random()
                beta = random.random()
            else:
                alpha, beta = alpha_beta
            if broadcast:
                D1 = make_tensor((), dtype=dtype, device=device, requires_grad=True)
            else:
                D1 = make_tensor([n, p], dtype=dtype, device=device, requires_grad=True)
            D2 = make_tensor([m, p], dtype=dtype, device=device, requires_grad=True)
            S = self._gen_sparse(2, nnz, [n, m], dtype, device, coalesced)[0]
            S_dense = S.to_dense().requires_grad_(True)
            S.requires_grad_(True)
            Y = torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)
            Y_dense = torch.addmm(D1, S_dense, D2, beta=beta, alpha=alpha)
            self.assertEqual(Y, Y_dense)

            if dtype not in {torch.double, torch.cdouble}:
                # gradcheck will likely fail with low-precision input dtypes.
                return

            def fn(S, D1, D2, beta=beta, alpha=alpha):
                return torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)
            gradcheck(fn, (S, D1, D2), masked=True)

        test_shape(7, 8, 9, 20, False, None)
        test_shape(7, 8, 9, 20, True, None)
        test_shape(7, 8, 9, 20, False, (1, 0))
        test_shape(7, 8, 9, 20, True, (1, 0))
        test_shape(7, 8, 9, 20, False, (1, 1))
        test_shape(7, 8, 9, 20, True, (1, 1))

    @coalescedonoff
    @dtypes(torch.double)
    @unittest.skipIf(TEST_WITH_CROSSREF, "unsupported generator triggers an assertion error")
    def test_sparse_mm(self, device, dtype, coalesced):
        def test_shape(d1, d2, d3, nnz, transposed):
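            # transposed=True makes D a non-contiguous (transposed) view, so
            # both memory layouts of the dense operand are covered.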
            if transposed:
                D = torch.randn(d3, d2, dtype=dtype,
                                device=device).t_().requires_grad_(True)
            else:
                D = torch.randn(d2, d3, dtype=dtype, device=device).requires_grad_(True)
            S = self._gen_sparse(2, nnz, [d1, d2], dtype, device, coalesced)[0]
            S_dense = S.to_dense().requires_grad_(True)
            S.requires_grad_(True)
            self.assertEqual(torch.sparse.mm(S, D), torch.mm(S_dense, D))

            def fn(S, D):
                return torch.sparse.mm(S, D)
            gradcheck(fn, (S, D), masked=True)

        test_shape(7, 8, 9, 20, False)
        test_shape(7, 8, 9, 20, True)

    @coalescedonoff
    @dtypes(torch.double)
    @unittest.skipIf(TEST_WITH_CROSSREF, "unsupported generator triggers an assertion error")
    @gradcheck_semantics()
    def test_sparse_mul(self, device, dtype, coalesced, gradcheck):
        # https://github.com/pytorch/pytorch/issues/79914
        a = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True)
        b = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True)
        gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(masked_grad=gradcheck.masked), [a, b])

        def test_shape(sparse_dims, nnz, with_shape):
            a = self._gen_sparse(sparse_dims, nnz, with_shape, dtype, device, coalesced)[0].requires_grad_(True)
            b = self._gen_sparse(sparse_dims, nnz, with_shape, dtype, device, coalesced)[0].requires_grad_(True)

            self.assertEqual((a * b).to_dense(), a.to_dense() * b.to_dense(), masked=True)
            gradcheck(lambda x, y: (x * y).to_dense(), [a, b])
            # Issues with 0-dim indices/values
            gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(), [a, b], masked=True)

        # TODO: Re-enable these
        # test_shape(2, 3, [2, 3, 4, 5])
        # test_shape(2, 3, [2, 2, 0])

    @coalescedonoff
    @dtypes(torch.double)
    def test_dsmm(self, device, dtype, coalesced):
        def test_shape(di, dj, dk, nnz):
            x = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced)[0]
            y = self.randn(dj, dk, dtype=dtype, device=device)

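            # torch.dsmm multiplies sparse x by dense y and, unlike hsmm
            # below, returns a dense result, so it is compared directly
            # against torch.mm.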
            res = torch.dsmm(x, y)
            expected = torch.mm(self.safeToDense(x), y)
            self.assertEqual(res, expected)

        test_shape(7, 5, 3, 20)
        test_shape(1000, 100, 100, 20)
        test_shape(3000, 64, 300, 20)
        test_shape(0, 100, 100, 0)
        test_shape(1000, 0, 100, 0)
        test_shape(1000, 100, 0, 0)
        test_shape(1000, 100, 0, 20)

    @coalescedonoff
    @dtypes(torch.double)
    def test_hsmm(self, device, dtype, coalesced):
        def test_shape(di, dj, dk, nnz):
            x = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced)[0]
            y = self.randn(dj, dk, dtype=dtype, device=device)

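            # torch.hsmm returns a sparse (hybrid) result, hence the
            # to_dense() before comparing against the dense product.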
            res = torch.hsmm(x, y)
            expected = torch.mm(self.safeToDense(x), y)
            self.assertEqual(res.to_dense(), expected)

        test_shape(7, 5, 3, 20)
        test_shape(1000, 100, 100, 20)
        test_shape(3000, 64, 300, 20)
        test_shape(0, 100, 100, 0)
        test_shape(1000, 0, 100, 0)
        test_shape(1000, 100, 0, 0)
        test_shape(1000, 100, 0, 20)

    @coalescedonoff
    @dtypes(torch.double)
    def test_spadd(self, device, dtype, coalesced):

        def _test_spadd_shape(nnz, shape_i, shape_v=None):
            shape = shape_i + (shape_v or [])
            x, _, _ = self._gen_sparse(len(shape_i), nnz, shape, dtype, device, coalesced)
            y = self.randn(*shape, dtype=dtype, device=device)
            r = random.random()

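            # torch.add(dense, sparse, alpha=r) must match the all-dense
            # computation y + r * x.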
            res = torch.add(y, x, alpha=r)
            expected = y + r * self.safeToDense(x)

            self.assertEqual(res, expected)

            # Non-contiguous dense tensor
            s = list(shape)
            s[0] = shape[-1]
            s[-1] = shape[0]
            y = self.randn(*s, dtype=dtype, device=device)
            y.transpose_(0, len(s) - 1)
            r = random.random()

            res = torch.add(y, x, alpha=r)
            expected = y + r * self.safeToDense(x)

            self.assertEqual(res, expected)

            x, i, v = self._gen_sparse(len(shape_i), nnz, shape, dtype, device, coalesced)
            nnz = i.size(1)

            # Non-contiguous sparse indices tensor
            x_ = self.sparse_tensor(i[:, ::2], v[:(nnz + 1) // 2], x.shape, dtype=dtype, device=device)
            res = torch.add(y, x_, alpha=r)
            expected = y + r * self.safeToDense(x_)
            self.assertEqual(res, expected)

            # Non-contiguous sparse values tensor
            x_ = self.sparse_tensor(i[:, :(nnz + 1) // 2], v[::2], x.shape, dtype=dtype, device=device)
            res = torch.add(y, x_, alpha=r)
            expected = y + r * self.safeToDense(x_)
            self.assertEqual(res, expected)

            # Non-contiguous sparse indices and values tensors
            x_ = self.sparse_tensor(i[:, 1::2], v[1::2], x.shape, dtype=dtype, device=device)
            res = torch.add(y, x_, alpha=r)
            expected = y + r * self.safeToDense(x_)
            self.assertEqual(res, expected)

        def _test_spadd():
            _test_spadd_shape(10, [5, 6])
            _test_spadd_shape(10, [10, 10, 10])
            _test_spadd_shape(10, [50, 30, 20])
            _test_spadd_shape(10, [5, 5, 5, 5, 5, 5])
            _test_spadd_shape(0, [0, 30, 20])
            _test_spadd_shape(0, [50, 0, 20])
            _test_spadd_shape(0, [50, 30, 0])

        def _test_spadd_hybrid():
            _test_spadd_shape(10, [5, 6], [2, 3])
            _test_spadd_shape(10, [10, 10, 10], [3])
            _test_spadd_shape(10, [50, 30, 20], [2])
            _test_spadd_shape(10, [5, 5, 5, 5, 5, 5], [2])
            _test_spadd_shape(0, [0, 30, 20], [2, 0])
            _test_spadd_shape(0, [50, 0, 20], [2, 0])
            _test_spadd_shape(0, [50, 30, 0], [2, 0])
            _test_spadd_shape(10, [50, 30, 20], [2, 0])

        _test_spadd()
        _test_spadd_hybrid()

    @coalescedonoff
    @dtypes(torch.float)
    def test_sparse_add_out_bfloat16(self, device, dtype, coalesced):
        # fp32
        x, _, _ = self._gen_sparse(3, 5, 10, dtype, device, coalesced)
        y, _, _ = self._gen_sparse(3, 5, 10, dtype, device, coalesced)
        res_fp32 = torch.add(x, y)

        # bfloat16
        x = x.bfloat16()
        y = y.bfloat16()
        res_bf16 = torch.add(x, y)
        res_bf16 = res_bf16.float()  # to compare with reference
        self.assertEqual(res_fp32, res_bf16, atol=1e-2, rtol=0)

    @coalescedonoff
    @dtypes(torch.double, torch.cdouble)
    def test_norm(self, device, dtype, coalesced):
        def test_shape(sparse_dims, nnz, with_size):
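            # For a coalesced tensor the norm equals the norm of its values,
            # since every (index, value) pair is unique.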
            x, _, _ = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)
            y = x.coalesce()
            self.assertEqual(x.norm(), y._values().norm())

        test_shape(3, 10, 100)
        test_shape(4, 10, [100, 100, 100, 5, 5, 5, 0])
        test_shape(4, 0, [0, 0, 100, 5, 5, 5, 0])

        # Unsupported arguments should error
        kwarg_error_pairs = [
            ({'keepdim': True},
             RuntimeError, r'norm_sparse currently does not support keepdim=True'),
            ({'dim': 0},
             RuntimeError, r'norm_sparse currently only supports full reductions'),
            ({'dtype': torch.double, 'p': 'fro'},
             ValueError, r'dtype argument is not supported in frobenius norm'),
            ({'dtype': torch.double, 'p': 0},
             RuntimeError, r"norm_sparse currently does not support 'dtype' argument")
        ]
        x = self._gen_sparse(3, 10, 100, dtype, device, coalesced)[0]
        for kwargs, err, msg in kwarg_error_pairs:
            with self.assertRaisesRegex(err, msg):
                x.norm(**kwargs)

    @coalescedonoff
    @dtypes(torch.double)
    @unittest.skipIf(TEST_WITH_CROSSREF, "fallback triggers cuda device error")
    def test_sparse_sum(self, device, dtype, coalesced):

        def run_tests(S, td=None):
            D = S.coalesce().to_dense().detach().requires_grad_(True)
            if td is None:
                S_sum = torch.sparse.sum(S)
                D_sum = D.sum()
                self.assertEqual(S_sum.item(), D_sum.item())

                def fn(S):
                    return torch.sparse.sum(S)
                gradcheck(fn, (S,), masked=True)
            else:
                S_sum = torch.sparse.sum(S, td)
                D_sum = D.sum(td)
                self.assertEqual(S_sum.to_dense() if S_sum.is_sparse else S_sum, D_sum)

                def fn(S):
                    res = torch.sparse.sum(S, td)
                    return res.to_dense(masked_grad=True)
                gradcheck(fn, (S,), masked=True)

        nnz = 10
        sparse_dims = 2
        with_size = [5, 5, 1, 4]  # use a dense dim = 1 to test for squeeze
        test_dims = []
        for i in range(1, 5):
            test_dims += itertools.combinations(range(len(with_size)), i)
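        # test_dims now holds every non-empty subset of the dims of with_size.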

        # https://github.com/pytorch/pytorch/issues/16501
        x = torch.tensor([[1., 0., 0., 1.],
                          [0., 1., 0., 0.],
                          [0., 1., 1., 0.],
                          [0., 1., 0., 2.]], dtype=dtype, device=device).to_sparse()
        self.assertEqual(torch.sparse.sum(x, dim=0), torch.sparse.sum(x, dim=-2))
        self.assertEqual(torch.sum(x.to_dense(), dim=0), torch.sparse.sum(x, dim=0).to_dense())

        S = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]

        # dim out of range
        self.assertRaises(IndexError, lambda: torch.sparse.sum(S, 5))

        # dim 0 appears multiple times in the list of dims
        self.assertRaises(RuntimeError, lambda: torch.sparse.sum(S, [0, 0]))

        # sum an empty tensor
        empty_S = torch.sparse_coo_tensor(size=with_size, dtype=dtype, device=device)
        self.assertEqual(torch.sparse.sum(empty_S, [0]).to_dense(), torch.sum(empty_S.to_dense(), [0]))
        self.assertEqual(torch.sparse.sum(empty_S), torch.tensor(0, dtype=dtype, device=device))
        empty_S.requires_grad_(True)
        empty_S_sum = torch.sparse.sum(empty_S)
        empty_S_sum.backward()
        self.assertEqual(empty_S.grad.to_dense(), empty_S.clone().detach().to_dense())

        # test values().sum()
        S = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
        run_tests(S.requires_grad_(True))

        for test_dim in test_dims:
            S = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
            run_tests(S.requires_grad_(True), test_dim)

    def _test_basic_ops_shape(self, nnz_x1, nnz_x2, shape_i, shape_v, dtype, device, coalesced):
        shape = shape_i + shape_v
        x1, _, _ = self._gen_sparse(len(shape_i), nnz_x1, shape, dtype, device, coalesced)
        x2, _, _ = self._gen_sparse(len(shape_i), nnz_x2, shape, dtype, device, coalesced)

        y1 = x1 + x2
        y2 = x1.clone()
        y2.add_(x2)
        expected = self.safeToDense(x1) + self.safeToDense(x2)
        self.assertEqual(self.safeToDense(y1), expected)
        self.assertEqual(self.safeToDense(y2), expected)

        y1 = x1 - x2
        y2 = x1.clone()
        y2.sub_(x2)
        expected = self.safeToDense(x1) - self.safeToDense(x2)
        self.assertEqual(self.safeToDense(y1), expected)
        self.assertEqual(self.safeToDense(y2), expected)

        y1 = x1 * x2
        y2 = x1.clone()
        y2.mul_(x2)
        expected = self.safeToDense(x1) * self.safeToDense(x2)
        self.assertEqual(self.safeToDense(y1), expected)
        self.assertEqual(self.safeToDense(y2), expected)

        y1 = x1 * 37.5
        y2 = x1.clone()
        y2.mul_(37.5)
        expected = self.safeToDense(x1) * 37.5
        self.assertEqual(self.safeToDense(y1), expected)
        self.assertEqual(self.safeToDense(y2), expected)

        y1 = x1 / 37.5
        y2 = x1.clone()
        y2.div_(37.5)
        expected = self.safeToDense(x1) / 37.5
        self.assertEqual(self.safeToDense(y1), expected)
        self.assertEqual(self.safeToDense(y2), expected)

        y1 = x1 // 37.5
        y2 = x1.clone()
        y2.floor_divide_(37.5)
        expected = self.safeToDense(x1) // 37.5
        self.assertEqual(self.safeToDense(y1), expected)
        self.assertEqual(self.safeToDense(y2), expected)

        # TODO: add back inplace support
        y1 = x1 ** 2
        y2 = x1.clone()
        y2 = y2.pow(2)
        expected = self.safeToDense(x1) ** 2
        self.assertEqual(self.safeToDense(y1), expected)
        self.assertEqual(self.safeToDense(y2), expected)

        y = x1.clone()
        y.zero_()
        expected = torch.zeros(x1.size(), dtype=dtype, device=device)
        self.assertEqual(self.safeToDense(y), expected)

        self.assertEqual(x1.is_coalesced(), coalesced)
        y = x1.coalesce()
        z = x1.coalesce()
        self.assertEqual(x1.is_coalesced(), coalesced)
        self.assertTrue(y.is_coalesced())
        y._values().add_(1)
        if not x1.is_coalesced():
            # check that coalesce is out of place if the original tensor is not
            # coalesced.
            self.assertEqual(z._values() + 1, y._values())
        else:
            # check that coalesce is in-place if the original tensor is
            # coalesced.
            self.assertEqual(z._values(), y._values())

    @coalescedonoff
    @dtypes(torch.double)
    def test_basic_ops(self, device, dtype, coalesced):

        def _test_basic_ops():
            self._test_basic_ops_shape(9, 12, [5, 6], [], dtype, device, coalesced)
            self._test_basic_ops_shape(9, 12, [10, 10, 10], [], dtype, device, coalesced)
            self._test_basic_ops_shape(9, 12, [50, 30, 20], [], dtype, device, coalesced)
            self._test_basic_ops_shape(9, 12, [5, 5, 5, 5, 5, 5], [], dtype, device, coalesced)
            self._test_basic_ops_shape(0, 12, [10, 10, 10], [], dtype, device, coalesced)
            self._test_basic_ops_shape(9, 0, [10, 10, 10], [], dtype, device, coalesced)
            self._test_basic_ops_shape(0, 0, [10, 10, 10], [], dtype, device, coalesced)
            self._test_basic_ops_shape(0, 0, [10, 10, 0], [], dtype, device, coalesced)
            self._test_basic_ops_shape(0, 0, [], [], dtype, device, coalesced)

        def _test_basic_ops_hybrid():
            self._test_basic_ops_shape(9, 12, [5, 6], [2, 3], dtype, device, coalesced)
            self._test_basic_ops_shape(9, 12, [10, 10, 10], [3], dtype, device, coalesced)
            self._test_basic_ops_shape(9, 12, [50, 30, 20], [2], dtype, device, coalesced)
            self._test_basic_ops_shape(9, 12, [5, 5, 5, 5, 5, 5], [2], dtype, device, coalesced)
            self._test_basic_ops_shape(0, 12, [10, 10, 10], [2], dtype, device, coalesced)
            self._test_basic_ops_shape(9, 0, [10, 10, 10], [2], dtype, device, coalesced)
            self._test_basic_ops_shape(0, 0, [10, 10, 10], [2], dtype, device, coalesced)
            self._test_basic_ops_shape(9, 12, [10, 10, 10], [2, 0], dtype, device, coalesced)
            self._test_basic_ops_shape(0, 12, [10, 10, 10], [2, 0], dtype, device, coalesced)
            self._test_basic_ops_shape(9, 0, [10, 10, 10], [2, 0], dtype, device, coalesced)
            self._test_basic_ops_shape(0, 0, [10, 10, 10], [2, 0], dtype, device, coalesced)
            self._test_basic_ops_shape(0, 0, [10, 10, 0], [2, 0], dtype, device, coalesced)

        _test_basic_ops()
        _test_basic_ops_hybrid()

    @dtypes(torch.double, torch.cdouble)
    def test_add_dense_sparse_mismatch(self, device, dtype):
        def test_shape(dense_size, sparse_dims_shape, dense_dims_shape, sparse_size):
            x = torch.zeros(dense_size, dtype=dtype, device=device)
            sparse_y = self.sparse_tensor(torch.zeros(sparse_dims_shape, dtype=torch.int64, device=device),
                                          torch.randn(dense_dims_shape, dtype=dtype, device=device),
                                          torch.Size(sparse_size))
            with self.assertRaisesRegex(
                    RuntimeError,
                    "add: expected 'self' and 'other' to have same size"):
                x + sparse_y

        test_shape([3, 4], [1, 4], [4, 4, 4], [3, 4, 4])
        test_shape([3, 4, 0], [1, 4], [4, 4, 4, 0], [3, 4, 4, 0])

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    @dtypes(torch.double, torch.cdouble)
    def test_add_noncontiguous(self, device, dtype):
        indices = self.index_tensor([[1, 2], [0, 2]], device=device)
        values = torch.tensor([1.], dtype=dtype, device=device).expand(2, 3, 4, 5)
1985*da0073e9SAndroid Build Coastguard Worker        x = self.sparse_tensor(indices, values, dtype=dtype, device=device)
1986*da0073e9SAndroid Build Coastguard Worker        assert not x._values().is_contiguous()
1987*da0073e9SAndroid Build Coastguard Worker        y = x + x
1988*da0073e9SAndroid Build Coastguard Worker        expected = self.safeToDense(x) + self.safeToDense(x)
1989*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(self.safeToDense(y), expected)
1990*da0073e9SAndroid Build Coastguard Worker
1991*da0073e9SAndroid Build Coastguard Worker    def _test_sparse_mask_shape(self, nnz_x1, nnz_x2, shape_i, shape_v, dtype, device, coalesced):
1992*da0073e9SAndroid Build Coastguard Worker        shape = shape_i + (shape_v or [])
1993*da0073e9SAndroid Build Coastguard Worker        x1, _, _ = self._gen_sparse(len(shape_i), nnz_x1, shape, dtype, device, coalesced)
1994*da0073e9SAndroid Build Coastguard Worker        x2, _, _ = self._gen_sparse(len(shape_i), nnz_x2, shape, dtype, device, coalesced)
1995*da0073e9SAndroid Build Coastguard Worker
1996*da0073e9SAndroid Build Coastguard Worker        y1 = x1 + x2
1997*da0073e9SAndroid Build Coastguard Worker        y2 = x1.clone()
1998*da0073e9SAndroid Build Coastguard Worker        y2.add_(x2)
1999*da0073e9SAndroid Build Coastguard Worker        expected = self.safeToDense(x1) + self.safeToDense(x2)
2000*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(self.safeToDense(y1), expected)
2001*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(self.safeToDense(y2), expected)
2002*da0073e9SAndroid Build Coastguard Worker
2003*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
2004*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
2005*da0073e9SAndroid Build Coastguard Worker    def test_sparse_mask(self, device, dtype, coalesced):
2006*da0073e9SAndroid Build Coastguard Worker        def _test_sparse_mask_fixed():
2007*da0073e9SAndroid Build Coastguard Worker            i = self.index_tensor([
2008*da0073e9SAndroid Build Coastguard Worker                [1, 3, 0, 4],
2009*da0073e9SAndroid Build Coastguard Worker                [2, 1, 2, 3],
2010*da0073e9SAndroid Build Coastguard Worker            ], device=device)
2011*da0073e9SAndroid Build Coastguard Worker            v = torch.tensor([1, 2, 3, 4], dtype=dtype, device=device)
2012*da0073e9SAndroid Build Coastguard Worker            x = self.sparse_tensor(i, v, torch.Size([5, 4]), dtype=dtype, device=device).coalesce()
2013*da0073e9SAndroid Build Coastguard Worker            dense = torch.tensor([
2014*da0073e9SAndroid Build Coastguard Worker                [1, 2, 3, 4],
2015*da0073e9SAndroid Build Coastguard Worker                [5, 6, 7, 8],
2016*da0073e9SAndroid Build Coastguard Worker                [9, 10, 11, 12],
2017*da0073e9SAndroid Build Coastguard Worker                [13, 14, 15, 16],
2018*da0073e9SAndroid Build Coastguard Worker                [17, 18, 19, 20],
2019*da0073e9SAndroid Build Coastguard Worker            ], dtype=dtype, device=device)
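            # sparse_mask keeps the dense entries at x's indices, here
            # dense[1, 2] = 7, dense[3, 1] = 14, dense[0, 2] = 3, dense[4, 3] = 20: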
2020*da0073e9SAndroid Build Coastguard Worker            exp_v = torch.tensor([7, 14, 3, 20], dtype=dtype, device=device)
2021*da0073e9SAndroid Build Coastguard Worker            res_dense_lhs = dense.sparse_mask(x)
2022*da0073e9SAndroid Build Coastguard Worker            sparse = dense.to_sparse()
2023*da0073e9SAndroid Build Coastguard Worker            res_sparse_lhs = sparse.sparse_mask(x)
2024*da0073e9SAndroid Build Coastguard Worker            expected = self.sparse_tensor(i, exp_v, torch.Size([5, 4]), dtype=dtype, device=device)
2025*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_dense_lhs.coalesce(), expected.coalesce())
2026*da0073e9SAndroid Build Coastguard Worker            # check no side effects for the coalesce flag.
2027*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(sparse.is_coalesced())
2028*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_sparse_lhs.coalesce(), expected.coalesce())
2029*da0073e9SAndroid Build Coastguard Worker
2030*da0073e9SAndroid Build Coastguard Worker            i = self.index_tensor([
2031*da0073e9SAndroid Build Coastguard Worker                [1, 3, 0, 4],
2032*da0073e9SAndroid Build Coastguard Worker                [2, 1, 2, 3],
2033*da0073e9SAndroid Build Coastguard Worker            ], device=device)
2034*da0073e9SAndroid Build Coastguard Worker            v = torch.empty([4, 0], dtype=dtype, device=device)
2035*da0073e9SAndroid Build Coastguard Worker            x = self.sparse_tensor(i, v, torch.Size([5, 4, 0])).coalesce()
2036*da0073e9SAndroid Build Coastguard Worker            dense = torch.empty([5, 4, 0], dtype=dtype, device=device)
2037*da0073e9SAndroid Build Coastguard Worker            exp_v = torch.empty([4, 0], dtype=dtype, device=device)
2038*da0073e9SAndroid Build Coastguard Worker            res_dense_lhs = dense.sparse_mask(x)
2039*da0073e9SAndroid Build Coastguard Worker            sparse = dense.to_sparse(2)
2040*da0073e9SAndroid Build Coastguard Worker            res_sparse_lhs = sparse.sparse_mask(x)
2041*da0073e9SAndroid Build Coastguard Worker            expected = self.sparse_tensor(i, exp_v, torch.Size([5, 4, 0]), dtype=dtype, device=device)
2042*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_dense_lhs.coalesce(), expected.coalesce())
2043*da0073e9SAndroid Build Coastguard Worker            # check no side effects for the coalesce flag.
2044*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(sparse.is_coalesced())
2045*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_sparse_lhs.coalesce(), expected.coalesce())
2046*da0073e9SAndroid Build Coastguard Worker
2047*da0073e9SAndroid Build Coastguard Worker        _test_sparse_mask_fixed()
2048*da0073e9SAndroid Build Coastguard Worker
2049*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(9, 12, [5, 6], [], dtype, device, coalesced)
2050*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(9, 12, [10, 10, 10], [], dtype, device, coalesced)
2051*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(9, 12, [50, 30, 20], [], dtype, device, coalesced)
2052*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(9, 12, [5, 5, 5, 5, 5, 5], [], dtype, device, coalesced)
2053*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(0, 12, [10, 10, 10], [], dtype, device, coalesced)
2054*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(9, 0, [10, 10, 10], [], dtype, device, coalesced)
2055*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(0, 0, [10, 10, 10], [], dtype, device, coalesced)
2056*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(0, 0, [10, 10, 0], [], dtype, device, coalesced)
2057*da0073e9SAndroid Build Coastguard Worker
2058*da0073e9SAndroid Build Coastguard Worker        # check repeated and matching indices in the intersection
2059*da0073e9SAndroid Build Coastguard Worker        lhs = torch.randint(0, 5, (100,), device=device)
2060*da0073e9SAndroid Build Coastguard Worker        rhs = torch.randint(0, 5, (100,), device=device).to_sparse()
2061*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(lhs.to_sparse().sparse_mask(rhs), lhs.sparse_mask(rhs))
2062*da0073e9SAndroid Build Coastguard Worker
2063*da0073e9SAndroid Build Coastguard Worker        # check coalesce
2064*da0073e9SAndroid Build Coastguard Worker        sparse_c = torch.rand(3, 3, device=device).to_sparse()
2065*da0073e9SAndroid Build Coastguard Worker        sparse_unc = torch.rand(3, 3, device=device).to_sparse()._coalesced_(False)
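        # _coalesced_(False) only lowers the flag on already-unique data; the
        # sparse-lhs path must match the dense-lhs path, and the result's
        # coalesced flag must follow rhs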
2066*da0073e9SAndroid Build Coastguard Worker        for lhs, rhs in [(sparse_c, sparse_unc), (sparse_unc, sparse_c)]:
2067*da0073e9SAndroid Build Coastguard Worker            res_all_sparse = lhs.sparse_mask(rhs)
2068*da0073e9SAndroid Build Coastguard Worker            res_dense_sparse = lhs.to_dense().sparse_mask(rhs)
2069*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_all_sparse.coalesce(), res_dense_sparse.coalesce())
2070*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rhs.is_coalesced(), res_all_sparse.is_coalesced())
2071*da0073e9SAndroid Build Coastguard Worker
2072*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
2073*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
2074*da0073e9SAndroid Build Coastguard Worker    def test_sparse_mask_hybrid(self, device, dtype, coalesced):
2075*da0073e9SAndroid Build Coastguard Worker        def _test_sparse_mask_hybrid_fixed():
2076*da0073e9SAndroid Build Coastguard Worker            i = self.index_tensor([
2077*da0073e9SAndroid Build Coastguard Worker                [1, 3, 0, 4],
2078*da0073e9SAndroid Build Coastguard Worker                [2, 1, 2, 3],
2079*da0073e9SAndroid Build Coastguard Worker            ], device=device)
2080*da0073e9SAndroid Build Coastguard Worker            v = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5]], dtype=dtype, device=device)
2081*da0073e9SAndroid Build Coastguard Worker            # TODO: This is also testing that, if coalesce is a no-op,
2082*da0073e9SAndroid Build Coastguard Worker            # the indices don't get permuted. I don't know if we actually
2083*da0073e9SAndroid Build Coastguard Worker            # want to guarantee this invariant.
2084*da0073e9SAndroid Build Coastguard Worker            x = self.sparse_tensor(i, v, torch.Size([5, 4, 2])).coalesce()
2085*da0073e9SAndroid Build Coastguard Worker            dense = torch.tensor([
2086*da0073e9SAndroid Build Coastguard Worker                [[1, 3], [2, 2], [3, 3], [4, 2]],
2087*da0073e9SAndroid Build Coastguard Worker                [[5, 7], [6, 7], [7, 9], [8, 9]],
2088*da0073e9SAndroid Build Coastguard Worker                [[9, 2], [10, 4], [11, 1], [12, 3]],
2089*da0073e9SAndroid Build Coastguard Worker                [[13, 5], [14, 1], [15, 1], [16, 6]],
2090*da0073e9SAndroid Build Coastguard Worker                [[17, 7], [18, 2], [19, 7], [20, 1]],
2091*da0073e9SAndroid Build Coastguard Worker            ], dtype=dtype, device=device)
2092*da0073e9SAndroid Build Coastguard Worker            res_dense_lhs = dense.sparse_mask(x)
2093*da0073e9SAndroid Build Coastguard Worker            sparse = dense.to_sparse(2)
2094*da0073e9SAndroid Build Coastguard Worker            res_sparse_lhs = sparse.sparse_mask(x)
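            # with one dense dim the mask selects whole value blocks:
            # dense[1][2] = [7, 9], dense[3][1] = [14, 1], dense[0][2] = [3, 3], dense[4][3] = [20, 1]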
2095*da0073e9SAndroid Build Coastguard Worker            exp_v = torch.tensor([[7, 9], [14, 1], [3, 3], [20, 1]], dtype=dtype, device=device)
2096*da0073e9SAndroid Build Coastguard Worker            expected = self.sparse_tensor(i, exp_v, torch.Size([5, 4, 2]))
2097*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_dense_lhs.coalesce(), expected.coalesce())
2098*da0073e9SAndroid Build Coastguard Worker            # check no side effects for the coalesce flag
2099*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(sparse.is_coalesced())
2100*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_sparse_lhs.coalesce(), expected.coalesce())
2101*da0073e9SAndroid Build Coastguard Worker
2102*da0073e9SAndroid Build Coastguard Worker            i = self.index_tensor([
2103*da0073e9SAndroid Build Coastguard Worker                [1, 3, 0, 4],
2104*da0073e9SAndroid Build Coastguard Worker                [2, 1, 2, 3],
2105*da0073e9SAndroid Build Coastguard Worker            ], device=device)
2106*da0073e9SAndroid Build Coastguard Worker            v = torch.empty(4, 2, 0, dtype=dtype, device=device)
2107*da0073e9SAndroid Build Coastguard Worker            x = self.sparse_tensor(i, v, torch.Size([5, 4, 2, 0])).coalesce()
2108*da0073e9SAndroid Build Coastguard Worker            dense = torch.empty(5, 4, 2, 0, dtype=dtype, device=device)
2109*da0073e9SAndroid Build Coastguard Worker            res_dense_lhs = dense.sparse_mask(x)
2110*da0073e9SAndroid Build Coastguard Worker            sparse = dense.to_sparse(2)
2111*da0073e9SAndroid Build Coastguard Worker            res_sparse_lhs = sparse.sparse_mask(x)
2112*da0073e9SAndroid Build Coastguard Worker            exp_v = torch.empty(4, 2, 0, dtype=dtype, device=device)
2113*da0073e9SAndroid Build Coastguard Worker            expected = self.sparse_tensor(i, exp_v, torch.Size([5, 4, 2, 0]))
2114*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_dense_lhs.coalesce(), expected.coalesce())
2115*da0073e9SAndroid Build Coastguard Worker            # check no side effects for the coalesce flag
2116*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(sparse.is_coalesced())
2117*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_sparse_lhs.coalesce(), expected.coalesce())
2118*da0073e9SAndroid Build Coastguard Worker
2119*da0073e9SAndroid Build Coastguard Worker        _test_sparse_mask_hybrid_fixed()
2120*da0073e9SAndroid Build Coastguard Worker
2121*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(9, 12, [5, 6], [2, 3], dtype, device, coalesced)
2122*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(9, 12, [10, 10, 10], [3], dtype, device, coalesced)
2123*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(9, 12, [50, 30, 20], [2], dtype, device, coalesced)
2124*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(9, 12, [5, 5, 5, 5, 5, 5], [2], dtype, device, coalesced)
2125*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(0, 12, [10, 10, 10], [2], dtype, device, coalesced)
2126*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(9, 0, [10, 10, 10], [2], dtype, device, coalesced)
2127*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(0, 0, [10, 10, 10], [2], dtype, device, coalesced)
2128*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(9, 12, [10, 10, 10], [2, 0], dtype, device, coalesced)
2129*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(0, 12, [10, 10, 10], [2, 0], dtype, device, coalesced)
2130*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(9, 0, [10, 10, 10], [2, 0], dtype, device, coalesced)
2131*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(0, 0, [10, 10, 10], [2, 0], dtype, device, coalesced)
2132*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_mask_shape(0, 0, [10, 10, 0], [2, 0], dtype, device, coalesced)
2133*da0073e9SAndroid Build Coastguard Worker
2134*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
2135*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
2136*da0073e9SAndroid Build Coastguard Worker    def test_sparse_mask_backward(self, device, dtype):
2137*da0073e9SAndroid Build Coastguard Worker        from itertools import product, repeat
2138*da0073e9SAndroid Build Coastguard Worker
2139*da0073e9SAndroid Build Coastguard Worker        shape = (5, 5)
2140*da0073e9SAndroid Build Coastguard Worker        sparse_dims = len(shape)
2141*da0073e9SAndroid Build Coastguard Worker        nnzs = (0, 5, 15, 25)
2142*da0073e9SAndroid Build Coastguard Worker
2143*da0073e9SAndroid Build Coastguard Worker        lhs_data = torch.arange(1, 26, device=device).reshape(shape).to(dtype).to_sparse(sparse_dims)
2144*da0073e9SAndroid Build Coastguard Worker        rhs_data = lhs_data.clone()
2145*da0073e9SAndroid Build Coastguard Worker
2146*da0073e9SAndroid Build Coastguard Worker        for nnz in nnzs:
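            # all four (lhs_is_coalesced, rhs_is_coalesced) flag combinations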
2147*da0073e9SAndroid Build Coastguard Worker            for lhs_is_coalesced, rhs_is_coalesced in product(*repeat((True, False), 2)):
2148*da0073e9SAndroid Build Coastguard Worker                lhs = torch.sparse_coo_tensor(
2149*da0073e9SAndroid Build Coastguard Worker                    lhs_data._indices()[:, :nnz],
2150*da0073e9SAndroid Build Coastguard Worker                    lhs_data._values()[:nnz],
2151*da0073e9SAndroid Build Coastguard Worker                    lhs_data.shape
2152*da0073e9SAndroid Build Coastguard Worker                ).clone()._coalesced_(lhs_is_coalesced).requires_grad_(True)
2153*da0073e9SAndroid Build Coastguard Worker
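                # note: when nnz == 0, the slice [-nnz:] is [0:], so rhs keeps the full pattern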
2154*da0073e9SAndroid Build Coastguard Worker                rhs = torch.sparse_coo_tensor(
2155*da0073e9SAndroid Build Coastguard Worker                    lhs_data._indices()[:, -nnz:],
2156*da0073e9SAndroid Build Coastguard Worker                    lhs_data._values()[-nnz:],
2157*da0073e9SAndroid Build Coastguard Worker                    lhs_data.shape
2158*da0073e9SAndroid Build Coastguard Worker                ).clone()._coalesced_(rhs_is_coalesced)
2159*da0073e9SAndroid Build Coastguard Worker
2160*da0073e9SAndroid Build Coastguard Worker                # To test masked semantics we need to make sure that
2161*da0073e9SAndroid Build Coastguard Worker                # sparsity_pattern(lhs) == sparsity_pattern(lhs.grad).
2162*da0073e9SAndroid Build Coastguard Worker                # lhs.sparse_mask(lhs_mask) accomplishes that.
2163*da0073e9SAndroid Build Coastguard Worker                lhs_mask = lhs.detach().clone()
2164*da0073e9SAndroid Build Coastguard Worker                gradcheck(lambda x: x.sparse_mask(lhs_mask).sparse_mask(rhs).to_dense(masked_grad=True), (lhs,), masked=True)
2165*da0073e9SAndroid Build Coastguard Worker                gradcheck(lambda x: x.sparse_mask(rhs).to_dense(masked_grad=False), (lhs,), masked=False)
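                # masked=True restricts the gradient check to lhs's sparsity pattern;
                # masked=False uses regular semantics with unspecified elements as zeros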
2166*da0073e9SAndroid Build Coastguard Worker
2167*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
2168*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
2169*da0073e9SAndroid Build Coastguard Worker    def test_zeros(self, device, dtype, coalesced):
2170*da0073e9SAndroid Build Coastguard Worker        def _test_zeros(nnzs, shape, out_shape_i, out_shape_v=None):
2171*da0073e9SAndroid Build Coastguard Worker            out_shape = out_shape_i + (out_shape_v or [])
2172*da0073e9SAndroid Build Coastguard Worker            for nnz in nnzs:
2173*da0073e9SAndroid Build Coastguard Worker                out, _, _ = self._gen_sparse(len(out_shape_i), nnz, out_shape, dtype, device, coalesced)
2174*da0073e9SAndroid Build Coastguard Worker                torch.zeros(*shape, out=out, dtype=dtype, device=device)
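                # zeros(out=sparse) rewrites the tensor in place: the requested
                # shape becomes all-sparse dims and nnz drops to zero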
2175*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(tuple(out.size()), tuple(shape))
2176*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(out._indices().numel() == out._values().numel() == 0)
2177*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out._nnz(), 0)
2178*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out.sparse_dim(), len(shape))
2179*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out.dense_dim(), 0)
2180*da0073e9SAndroid Build Coastguard Worker
2181*da0073e9SAndroid Build Coastguard Worker        def test_shape(i_shapes, v_shapes, shape, nnzs):
2182*da0073e9SAndroid Build Coastguard Worker            for i_dim in range(1, len(i_shapes) + 1):
2183*da0073e9SAndroid Build Coastguard Worker                for v_dim in range(len(v_shapes) + 1):
2184*da0073e9SAndroid Build Coastguard Worker                    _test_zeros(nnzs, shape, i_shapes[:i_dim], v_shapes[:v_dim])
2185*da0073e9SAndroid Build Coastguard Worker        test_shape([2, 3, 4], [3, 4, 5, 6], [2, 3, 4], [9, 12])
2186*da0073e9SAndroid Build Coastguard Worker        test_shape([0, 3, 4], [3, 4, 5, 6], [2, 3, 4], [0])
2187*da0073e9SAndroid Build Coastguard Worker        test_shape([2, 3, 4], [0, 4, 5, 6], [2, 3, 4], [9, 12])
2188*da0073e9SAndroid Build Coastguard Worker        test_shape([2, 3, 4], [3, 4, 5, 6], [2, 3, 0], [9, 12])
2189*da0073e9SAndroid Build Coastguard Worker        test_shape([0, 3, 4], [3, 4, 5, 6], [2, 3, 0], [0])
2190*da0073e9SAndroid Build Coastguard Worker        test_shape([2, 3, 4], [0, 4, 5, 6], [2, 3, 0], [9, 12])
2191*da0073e9SAndroid Build Coastguard Worker
2192*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
2193*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
2194*da0073e9SAndroid Build Coastguard Worker    def test_zeros_like(self, device, dtype, coalesced):
2195*da0073e9SAndroid Build Coastguard Worker        def _test_zeros_like(nnzs, template_shape_i, template_shape_v=None):
2196*da0073e9SAndroid Build Coastguard Worker            template_shape_v = template_shape_v or []
2197*da0073e9SAndroid Build Coastguard Worker            template_shape = template_shape_i + template_shape_v
2198*da0073e9SAndroid Build Coastguard Worker            for nnz in nnzs:
2199*da0073e9SAndroid Build Coastguard Worker                t, _, _ = self._gen_sparse(len(template_shape_i), nnz, template_shape, dtype, device, coalesced)
2200*da0073e9SAndroid Build Coastguard Worker                res = torch.zeros_like(t)
2201*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(tuple(res.size()), tuple(template_shape))
2202*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(res._indices().numel() == res._values().numel() == 0)
2203*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res._nnz(), 0)
2204*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res.sparse_dim(), len(template_shape_i))
2205*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res.dense_dim(), len(template_shape_v))
2206*da0073e9SAndroid Build Coastguard Worker
2207*da0073e9SAndroid Build Coastguard Worker        def test_shape(i_shapes, v_shapes, nnzs):
2208*da0073e9SAndroid Build Coastguard Worker            for i_dim in range(1, len(i_shapes) + 1):
2209*da0073e9SAndroid Build Coastguard Worker                for v_dim in range(len(v_shapes) + 1):
2210*da0073e9SAndroid Build Coastguard Worker                    _test_zeros_like(nnzs, i_shapes[:i_dim], v_shapes[:v_dim])
2211*da0073e9SAndroid Build Coastguard Worker        test_shape([2, 3, 4], [3, 4, 5, 6], [9, 12])
2212*da0073e9SAndroid Build Coastguard Worker        test_shape([0, 3, 4], [3, 4, 5, 6], [0])
2213*da0073e9SAndroid Build Coastguard Worker        test_shape([2, 3, 4], [0, 4, 5, 6], [9, 12])
2217*da0073e9SAndroid Build Coastguard Worker
2218*da0073e9SAndroid Build Coastguard Worker        sparse_tensor, _, _ = self._gen_sparse(len([2, 3]), 9, [2, 3] + [5, 6], dtype, device, coalesced)
2219*da0073e9SAndroid Build Coastguard Worker        data = (sparse_tensor, sparse_tensor, sparse_tensor, sparse_tensor.unsqueeze(0))
2220*da0073e9SAndroid Build Coastguard Worker        mem_formats = [torch.channels_last, torch.contiguous_format, torch.preserve_format, torch.channels_last_3d]
2221*da0073e9SAndroid Build Coastguard Worker        for x, mem_format in zip(data, mem_formats):
2223*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "memory format option is only supported by strided tensors"):
2224*da0073e9SAndroid Build Coastguard Worker                result = torch.zeros_like(x, memory_format=mem_format)
2225*da0073e9SAndroid Build Coastguard Worker
2226*da0073e9SAndroid Build Coastguard Worker            result = torch.zeros_like(x, layout=torch.strided, memory_format=mem_format)
2227*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(result.layout == torch.strided)
2228*da0073e9SAndroid Build Coastguard Worker
2229*da0073e9SAndroid Build Coastguard Worker        dense_tensor = sparse_tensor.to_dense()
2230*da0073e9SAndroid Build Coastguard Worker        result = torch.zeros_like(dense_tensor, layout=torch.sparse_coo)
2231*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dense_tensor.shape, result.shape)
2232*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.layout, torch.sparse_coo)
2233*da0073e9SAndroid Build Coastguard Worker
2234*da0073e9SAndroid Build Coastguard Worker        sparse_zeros = torch.sparse_coo_tensor(dense_tensor.shape)
2235*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result._indices().shape, sparse_zeros._indices().shape)
2236*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result._values().shape, sparse_zeros._values().shape)
2237*da0073e9SAndroid Build Coastguard Worker
2238*da0073e9SAndroid Build Coastguard Worker    def _assert_sparse_invars(self, t):
2239*da0073e9SAndroid Build Coastguard Worker        # SparseTensor has the following invariants:
2240*da0073e9SAndroid Build Coastguard Worker        # - sparse_dim + dense_dim = len(SparseTensor.shape)
2241*da0073e9SAndroid Build Coastguard Worker        # - SparseTensor._indices().shape = (sparse_dim, nnz)
2242*da0073e9SAndroid Build Coastguard Worker        # - SparseTensor._values().shape = (nnz, *SparseTensor.shape[sparse_dim:])
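        # e.g. torch.sparse_coo_tensor(torch.tensor([[0, 1]]), torch.ones(2, 3), (2, 3))
        # has sparse_dim=1, dense_dim=1, _indices().shape == (1, 2), and
        # _values().shape == (2, 3) == (nnz, *shape[1:])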
2243*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.sparse_dim() + t.dense_dim(), len(t.shape))
2244*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tuple(t._indices().shape), (t.sparse_dim(), t._nnz()))
2245*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tuple(t._values().shape), (t._nnz(), ) + t.shape[t.sparse_dim():])
2246*da0073e9SAndroid Build Coastguard Worker
2247*da0073e9SAndroid Build Coastguard Worker    def _test_empty_like(self, sparse_tensor, dtype, device, coalesced):
2249*da0073e9SAndroid Build Coastguard Worker        result = torch.empty_like(sparse_tensor)
2250*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(result.is_sparse)
2251*da0073e9SAndroid Build Coastguard Worker        self._assert_sparse_invars(result)
2252*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.shape, sparse_tensor.shape)
2253*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.dtype, sparse_tensor.dtype)
2254*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.device, sparse_tensor.device)
2255*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.sparse_dim(), sparse_tensor.sparse_dim())
2256*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.dense_dim(), sparse_tensor.dense_dim())
2257*da0073e9SAndroid Build Coastguard Worker
2258*da0073e9SAndroid Build Coastguard Worker        sparse_tensor, _, _ = self._gen_sparse(len([2, 3]), 9, [2, 3] + [5, 6], dtype, device, coalesced)
2259*da0073e9SAndroid Build Coastguard Worker        data = (sparse_tensor, sparse_tensor, sparse_tensor, sparse_tensor.unsqueeze(0))
2260*da0073e9SAndroid Build Coastguard Worker        mem_formats = [torch.channels_last, torch.contiguous_format, torch.preserve_format, torch.channels_last_3d]
2261*da0073e9SAndroid Build Coastguard Worker        for x, mem_format in zip(data, mem_formats):
2263*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "memory format option is only supported by strided tensors"):
2264*da0073e9SAndroid Build Coastguard Worker                result = torch.empty_like(x, memory_format=mem_format)
2265*da0073e9SAndroid Build Coastguard Worker
2266*da0073e9SAndroid Build Coastguard Worker            result = torch.empty_like(x, layout=torch.strided, memory_format=mem_format)
2267*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(result.layout == torch.strided)
2268*da0073e9SAndroid Build Coastguard Worker
2269*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2270*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"Could not run 'aten::empty_strided' with arguments from the 'Sparse(CPU|CUDA)' backend"
2271*da0073e9SAndroid Build Coastguard Worker        ):
2272*da0073e9SAndroid Build Coastguard Worker            dense_tensor = sparse_tensor.to_dense()
2273*da0073e9SAndroid Build Coastguard Worker            result = torch.empty_like(dense_tensor, layout=torch.sparse_coo)
2274*da0073e9SAndroid Build Coastguard Worker
2275*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
2276*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
2277*da0073e9SAndroid Build Coastguard Worker    def test_empty_like(self, device, dtype, coalesced):
2278*da0073e9SAndroid Build Coastguard Worker        # tests https://github.com/pytorch/pytorch/issues/43699
2279*da0073e9SAndroid Build Coastguard Worker
2280*da0073e9SAndroid Build Coastguard Worker        if coalesced:
2281*da0073e9SAndroid Build Coastguard Worker            input_coalesced = torch.sparse_coo_tensor(
2282*da0073e9SAndroid Build Coastguard Worker                indices=torch.tensor([[0, 1, 2]]),
2283*da0073e9SAndroid Build Coastguard Worker                values=torch.tensor([3.0, -4.0, 5.0]),
2284*da0073e9SAndroid Build Coastguard Worker                size=[3, ],
2285*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
2286*da0073e9SAndroid Build Coastguard Worker                device=device
2287*da0073e9SAndroid Build Coastguard Worker            ).coalesce()
2288*da0073e9SAndroid Build Coastguard Worker            self._test_empty_like(input_coalesced, dtype, device, coalesced)
2289*da0073e9SAndroid Build Coastguard Worker
2290*da0073e9SAndroid Build Coastguard Worker            # hybrid sparse input
2291*da0073e9SAndroid Build Coastguard Worker            input_coalesced = torch.sparse_coo_tensor(
2292*da0073e9SAndroid Build Coastguard Worker                indices=torch.tensor([[1, 3], [2, 4]]),
2293*da0073e9SAndroid Build Coastguard Worker                values=torch.tensor([[-1.0, 3.0], [-5.0, 7.0]]),
2294*da0073e9SAndroid Build Coastguard Worker                size=[4, 5, 2],
2295*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
2296*da0073e9SAndroid Build Coastguard Worker                device=device
2297*da0073e9SAndroid Build Coastguard Worker            ).coalesce()
2298*da0073e9SAndroid Build Coastguard Worker            self._test_empty_like(input_coalesced, dtype, device, coalesced)
2299*da0073e9SAndroid Build Coastguard Worker
2300*da0073e9SAndroid Build Coastguard Worker        if not coalesced:
2301*da0073e9SAndroid Build Coastguard Worker            # test uncoalesced input
2302*da0073e9SAndroid Build Coastguard Worker            input_uncoalesced = torch.sparse_coo_tensor(
2303*da0073e9SAndroid Build Coastguard Worker                indices=torch.tensor([[0], [1], [2], [0], [1], [2]]).transpose(1, 0),
2304*da0073e9SAndroid Build Coastguard Worker                values=torch.tensor([2.0, -3.0, -4.0, 1.0, -1.0, 1.5]),
2305*da0073e9SAndroid Build Coastguard Worker                size=[3, ],
2306*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
2307*da0073e9SAndroid Build Coastguard Worker                device=device
2308*da0073e9SAndroid Build Coastguard Worker            )
2309*da0073e9SAndroid Build Coastguard Worker            self._test_empty_like(input_uncoalesced, dtype, device, coalesced)
2310*da0073e9SAndroid Build Coastguard Worker
2311*da0073e9SAndroid Build Coastguard Worker            # test on empty sparse tensor
2312*da0073e9SAndroid Build Coastguard Worker            input_uncoalesced = torch.sparse_coo_tensor(
2313*da0073e9SAndroid Build Coastguard Worker                indices=torch.zeros([2, 0]),
2314*da0073e9SAndroid Build Coastguard Worker                values=torch.zeros([0, 5, 5, 5, 5, 5, 5, 0]),
2315*da0073e9SAndroid Build Coastguard Worker                size=[0, 0, 5, 5, 5, 5, 5, 5, 0],
2316*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
2317*da0073e9SAndroid Build Coastguard Worker                device=device
2318*da0073e9SAndroid Build Coastguard Worker            )
2319*da0073e9SAndroid Build Coastguard Worker            self._test_empty_like(input_uncoalesced, dtype, device, coalesced)
2320*da0073e9SAndroid Build Coastguard Worker
2321*da0073e9SAndroid Build Coastguard Worker    def _test_narrow(self, input, narrow_args):
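        # checks narrow_copy on the sparse input against narrow on its dense equivalent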
2322*da0073e9SAndroid Build Coastguard Worker        expected = input.to_dense().narrow(*narrow_args)
2323*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, input.narrow_copy(*narrow_args).to_dense())
2324*da0073e9SAndroid Build Coastguard Worker
2325*da0073e9SAndroid Build Coastguard Worker    def _all_narrow_combs(self, shape):
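        # yields every valid [dim, start, length] triple, e.g. for shape [2]:
        # [0, 0, 0], [0, 0, 1], [0, 1, 0]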
2326*da0073e9SAndroid Build Coastguard Worker        for dim, dim_sz in enumerate(shape):
2327*da0073e9SAndroid Build Coastguard Worker            for start in range(dim_sz):
2328*da0073e9SAndroid Build Coastguard Worker                for length in range(dim_sz - start):
2329*da0073e9SAndroid Build Coastguard Worker                    yield [dim, start, length]
2330*da0073e9SAndroid Build Coastguard Worker
2331*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
2332*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
2333*da0073e9SAndroid Build Coastguard Worker    def test_narrow(self, device, dtype, coalesced):
2334*da0073e9SAndroid Build Coastguard Worker        shape = [3, 3, 4, 2]
2335*da0073e9SAndroid Build Coastguard Worker        input, _, _ = self._gen_sparse(4, 19, shape, dtype, device, coalesced)
2336*da0073e9SAndroid Build Coastguard Worker        for narrow_args in self._all_narrow_combs(shape):
2337*da0073e9SAndroid Build Coastguard Worker            self._test_narrow(input, narrow_args)
2338*da0073e9SAndroid Build Coastguard Worker
2339*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: input.narrow_copy(-1, 0, 3))  # dim < 0
2340*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: input.narrow_copy(10, 0, 3))  # dim > input.dim()
2341*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: input.narrow_copy(0, shape[0] + 1, 3))  # start > size of dim
2342*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: input.narrow_copy(0, 2, shape[0]))  # start+length > size of dim
2343*da0073e9SAndroid Build Coastguard Worker
2344*da0073e9SAndroid Build Coastguard Worker        with_dense, _, _ = self._gen_sparse(2, 7, shape, dtype, device, coalesced)
2345*da0073e9SAndroid Build Coastguard Worker        for narrow_args in self._all_narrow_combs(shape):
2346*da0073e9SAndroid Build Coastguard Worker            self._test_narrow(with_dense, narrow_args)
2347*da0073e9SAndroid Build Coastguard Worker
2348*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: with_dense.narrow_copy(10, 0, 3))  # dim > sparseDim + denseDim
2349*da0073e9SAndroid Build Coastguard Worker
2350*da0073e9SAndroid Build Coastguard Worker    def _test_log1p_tensor(self, sparse_tensor, coalesced):
2351*da0073e9SAndroid Build Coastguard Worker        def is_integral(dtype):
2352*da0073e9SAndroid Build Coastguard Worker            return dtype in integral_types()
2353*da0073e9SAndroid Build Coastguard Worker
2354*da0073e9SAndroid Build Coastguard Worker        dense_tensor = sparse_tensor.to_dense()
2355*da0073e9SAndroid Build Coastguard Worker        expected_output = dense_tensor.log1p()
2356*da0073e9SAndroid Build Coastguard Worker        is_integral_dtype = is_integral(sparse_tensor.dtype)
2357*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_output, sparse_tensor.log1p().to_dense())
2358*da0073e9SAndroid Build Coastguard Worker        if is_integral_dtype:
2359*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "result type .* can't be cast to"):
2360*da0073e9SAndroid Build Coastguard Worker                sparse_tensor.coalesce().log1p_()
2361*da0073e9SAndroid Build Coastguard Worker        else:
2362*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_output, sparse_tensor.coalesce().log1p_().to_dense())
2363*da0073e9SAndroid Build Coastguard Worker
2364*da0073e9SAndroid Build Coastguard Worker        if not coalesced:
2365*da0073e9SAndroid Build Coastguard Worker            # test in-place op on uncoalesced input
2366*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "log1p_ requires coalesced input"):
2367*da0073e9SAndroid Build Coastguard Worker                sparse_tensor.log1p_()
2368*da0073e9SAndroid Build Coastguard Worker
2369*da0073e9SAndroid Build Coastguard Worker        if is_integral_dtype:
2370*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "only Tensors of floating point dtype can require gradients"):
2371*da0073e9SAndroid Build Coastguard Worker                sparse_tensor.requires_grad_()
2372*da0073e9SAndroid Build Coastguard Worker
2373*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
2374*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types())
2375*da0073e9SAndroid Build Coastguard Worker    def test_log1p(self, device, dtype, coalesced):
2376*da0073e9SAndroid Build Coastguard Worker        if coalesced:
2377*da0073e9SAndroid Build Coastguard Worker            input_coalesced = torch.sparse_coo_tensor(
2378*da0073e9SAndroid Build Coastguard Worker                indices=torch.tensor([[0], [1], [2]]).transpose(1, 0),
2379*da0073e9SAndroid Build Coastguard Worker                values=torch.tensor([3.0, 4.0, 5.0]),
2380*da0073e9SAndroid Build Coastguard Worker                size=[3, ],
2381*da0073e9SAndroid Build Coastguard Worker                device=device,
2382*da0073e9SAndroid Build Coastguard Worker                dtype=dtype
2383*da0073e9SAndroid Build Coastguard Worker            ).coalesce()
2384*da0073e9SAndroid Build Coastguard Worker            self._test_log1p_tensor(input_coalesced, coalesced)
2385*da0073e9SAndroid Build Coastguard Worker
2386*da0073e9SAndroid Build Coastguard Worker            # hybrid sparse input
2387*da0073e9SAndroid Build Coastguard Worker            input_coalesced = torch.sparse_coo_tensor(
2388*da0073e9SAndroid Build Coastguard Worker                indices=torch.tensor([[1, 3], [2, 4]]),
2389*da0073e9SAndroid Build Coastguard Worker                values=torch.tensor([[1.0, 3.0], [5.0, 7.0]]),
2390*da0073e9SAndroid Build Coastguard Worker                size=[4, 5, 2],
2391*da0073e9SAndroid Build Coastguard Worker                device=device,
2392*da0073e9SAndroid Build Coastguard Worker                dtype=dtype
2393*da0073e9SAndroid Build Coastguard Worker            ).coalesce()
2394*da0073e9SAndroid Build Coastguard Worker            self._test_log1p_tensor(input_coalesced, coalesced)
2395*da0073e9SAndroid Build Coastguard Worker
2396*da0073e9SAndroid Build Coastguard Worker        if not coalesced:
2397*da0073e9SAndroid Build Coastguard Worker            # test uncoalesced input
2398*da0073e9SAndroid Build Coastguard Worker            input_uncoalesced = torch.sparse_coo_tensor(
2399*da0073e9SAndroid Build Coastguard Worker                indices=torch.tensor([[0], [1], [2], [0], [1], [2]]).transpose(1, 0),
2400*da0073e9SAndroid Build Coastguard Worker                values=torch.tensor([2.0, 3.0, 4.0, 1.0, 1.0, 1.0]),
2401*da0073e9SAndroid Build Coastguard Worker                size=[3, ],
2402*da0073e9SAndroid Build Coastguard Worker                device=device,
2403*da0073e9SAndroid Build Coastguard Worker                dtype=dtype
2404*da0073e9SAndroid Build Coastguard Worker            )
2405*da0073e9SAndroid Build Coastguard Worker            self._test_log1p_tensor(input_uncoalesced, coalesced)
2406*da0073e9SAndroid Build Coastguard Worker
2407*da0073e9SAndroid Build Coastguard Worker            # test on empty sparse tensor
2408*da0073e9SAndroid Build Coastguard Worker            input_uncoalesced = torch.sparse_coo_tensor(
2409*da0073e9SAndroid Build Coastguard Worker                indices=torch.zeros([2, 0]),
2410*da0073e9SAndroid Build Coastguard Worker                values=torch.zeros([0, 5, 5, 5, 5, 5, 5, 0]),
2411*da0073e9SAndroid Build Coastguard Worker                size=[0, 0, 5, 5, 5, 5, 5, 5, 0],
2412*da0073e9SAndroid Build Coastguard Worker                device=device,
2413*da0073e9SAndroid Build Coastguard Worker                dtype=dtype
2414*da0073e9SAndroid Build Coastguard Worker            )
2415*da0073e9SAndroid Build Coastguard Worker            # empty tensors are coalesced at creation (nnz < 2), so we must force the uncoalesced state
2416*da0073e9SAndroid Build Coastguard Worker            input_uncoalesced._coalesced_(False)
2417*da0073e9SAndroid Build Coastguard Worker            self._test_log1p_tensor(input_uncoalesced, coalesced)
2418*da0073e9SAndroid Build Coastguard Worker
2419*da0073e9SAndroid Build Coastguard Worker    def _test_neg_negative(self, sparse_tensor):
2420*da0073e9SAndroid Build Coastguard Worker        dense_tensor = sparse_tensor.to_dense()
2421*da0073e9SAndroid Build Coastguard Worker        expected_output = dense_tensor.neg()
2422*da0073e9SAndroid Build Coastguard Worker
2423*da0073e9SAndroid Build Coastguard Worker        ops = (
2424*da0073e9SAndroid Build Coastguard Worker            torch.neg, torch.Tensor.neg, torch.Tensor.neg_,
2425*da0073e9SAndroid Build Coastguard Worker            torch.negative, torch.Tensor.negative, torch.Tensor.negative_,
2426*da0073e9SAndroid Build Coastguard Worker            operator.neg
2427*da0073e9SAndroid Build Coastguard Worker        )
2428*da0073e9SAndroid Build Coastguard Worker        for op in ops:
2429*da0073e9SAndroid Build Coastguard Worker            sparse_tensor_copy = sparse_tensor.clone()
2430*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_output, op(sparse_tensor_copy).to_dense())
2431*da0073e9SAndroid Build Coastguard Worker
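            # only the functional variants accept an out= argument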
2432*da0073e9SAndroid Build Coastguard Worker            if op in (torch.neg, torch.negative):
2433*da0073e9SAndroid Build Coastguard Worker                sparse_tensor_out = torch.zeros_like(sparse_tensor)
2434*da0073e9SAndroid Build Coastguard Worker                op(sparse_tensor, out=sparse_tensor_out)
2435*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected_output, sparse_tensor_out.to_dense())
2436*da0073e9SAndroid Build Coastguard Worker
2437*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
2438*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
2439*da0073e9SAndroid Build Coastguard Worker    def test_neg_negative(self, device, dtype, coalesced):
2441*da0073e9SAndroid Build Coastguard Worker        if coalesced:
2442*da0073e9SAndroid Build Coastguard Worker            input_coalesced = torch.sparse_coo_tensor(
2443*da0073e9SAndroid Build Coastguard Worker                indices=torch.tensor([[0, 1, 2]]),
2444*da0073e9SAndroid Build Coastguard Worker                values=torch.tensor([3.0, -4.0, 5.0]),
2445*da0073e9SAndroid Build Coastguard Worker                size=[3, ],
2446*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
2447*da0073e9SAndroid Build Coastguard Worker                device=device
2448*da0073e9SAndroid Build Coastguard Worker            ).coalesce()
2449*da0073e9SAndroid Build Coastguard Worker            self._test_neg_negative(input_coalesced)
2450*da0073e9SAndroid Build Coastguard Worker
2451*da0073e9SAndroid Build Coastguard Worker            # hybrid sparse input
2452*da0073e9SAndroid Build Coastguard Worker            input_coalesced = torch.sparse_coo_tensor(
2453*da0073e9SAndroid Build Coastguard Worker                indices=torch.tensor([[1, 3], [2, 4]]),
2454*da0073e9SAndroid Build Coastguard Worker                values=torch.tensor([[-1.0, 3.0], [-5.0, 7.0]]),
2455*da0073e9SAndroid Build Coastguard Worker                size=[4, 5, 2],
2456*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
2457*da0073e9SAndroid Build Coastguard Worker                device=device
2458*da0073e9SAndroid Build Coastguard Worker            ).coalesce()
2459*da0073e9SAndroid Build Coastguard Worker            self._test_neg_negative(input_coalesced)
2460*da0073e9SAndroid Build Coastguard Worker
2461*da0073e9SAndroid Build Coastguard Worker        if not coalesced:
2462*da0073e9SAndroid Build Coastguard Worker            # test uncoalesced input
2463*da0073e9SAndroid Build Coastguard Worker            input_uncoalesced = torch.sparse_coo_tensor(
2464*da0073e9SAndroid Build Coastguard Worker                indices=torch.tensor([[0], [1], [2], [0], [1], [2]]).transpose(1, 0),
2465*da0073e9SAndroid Build Coastguard Worker                values=torch.tensor([2.0, -3.0, -4.0, 1.0, -1.0, 1.5]),
2466*da0073e9SAndroid Build Coastguard Worker                size=[3, ],
2467*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
2468*da0073e9SAndroid Build Coastguard Worker                device=device
2469*da0073e9SAndroid Build Coastguard Worker            )
2470*da0073e9SAndroid Build Coastguard Worker            self._test_neg_negative(input_uncoalesced)
2471*da0073e9SAndroid Build Coastguard Worker
2472*da0073e9SAndroid Build Coastguard Worker            # test on empty sparse tensor
2473*da0073e9SAndroid Build Coastguard Worker            input_uncoalesced = torch.sparse_coo_tensor(
2474*da0073e9SAndroid Build Coastguard Worker                indices=torch.zeros([2, 0]),
2475*da0073e9SAndroid Build Coastguard Worker                values=torch.zeros([0, 5, 5, 5, 5, 5, 5, 0]),
2476*da0073e9SAndroid Build Coastguard Worker                size=[0, 0, 5, 5, 5, 5, 5, 5, 0],
2477*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
2478*da0073e9SAndroid Build Coastguard Worker                device=device
2479*da0073e9SAndroid Build Coastguard Worker            )
2480*da0073e9SAndroid Build Coastguard Worker            self._test_neg_negative(input_uncoalesced)
2481*da0073e9SAndroid Build Coastguard Worker
2482*da0073e9SAndroid Build Coastguard Worker    def _test_asin_arcsin(self, sparse_tensor, coalesced):
2483*da0073e9SAndroid Build Coastguard Worker        def is_integral(dtype):
2484*da0073e9SAndroid Build Coastguard Worker            return dtype in integral_types()
2485*da0073e9SAndroid Build Coastguard Worker        is_integral_dtype = is_integral(sparse_tensor.dtype)
2486*da0073e9SAndroid Build Coastguard Worker
2487*da0073e9SAndroid Build Coastguard Worker        dense_tensor = sparse_tensor.to_dense()
2488*da0073e9SAndroid Build Coastguard Worker        expected_output = dense_tensor.asin()
2489*da0073e9SAndroid Build Coastguard Worker
2490*da0073e9SAndroid Build Coastguard Worker        ops = (
2491*da0073e9SAndroid Build Coastguard Worker            torch.asin, torch.Tensor.asin,
2492*da0073e9SAndroid Build Coastguard Worker            torch.arcsin, torch.Tensor.arcsin,
2493*da0073e9SAndroid Build Coastguard Worker        )
2494*da0073e9SAndroid Build Coastguard Worker        for op in ops:
2495*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_output, op(sparse_tensor).to_dense())
2496*da0073e9SAndroid Build Coastguard Worker            if op in (torch.asin, torch.arcsin):
2497*da0073e9SAndroid Build Coastguard Worker                sparse_tensor_out = torch.zeros_like(sparse_tensor)
2498*da0073e9SAndroid Build Coastguard Worker                if not is_integral_dtype:
2499*da0073e9SAndroid Build Coastguard Worker                    op(sparse_tensor, out=sparse_tensor_out)
2500*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(expected_output, sparse_tensor_out.to_dense())
2501*da0073e9SAndroid Build Coastguard Worker                else:
2502*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, "result type .* can't be cast to"):
2503*da0073e9SAndroid Build Coastguard Worker                        op(sparse_tensor, out=sparse_tensor_out)
2504*da0073e9SAndroid Build Coastguard Worker
2505*da0073e9SAndroid Build Coastguard Worker        for op in (torch.Tensor.asin_, torch.Tensor.arcsin_):
2506*da0073e9SAndroid Build Coastguard Worker            if is_integral_dtype:
2507*da0073e9SAndroid Build Coastguard Worker                # in-place op on a coalesced integral tensor must raise, since the result is floating point
2508*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "result type .* can't be cast to"):
2509*da0073e9SAndroid Build Coastguard Worker                    op(sparse_tensor.clone().coalesce()).to_dense()
2510*da0073e9SAndroid Build Coastguard Worker            else:
2511*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected_output, op(sparse_tensor.clone().coalesce()).to_dense())
2512*da0073e9SAndroid Build Coastguard Worker
2513*da0073e9SAndroid Build Coastguard Worker            if not coalesced:
2514*da0073e9SAndroid Build Coastguard Worker                # test in-place op on uncoalesced input
2515*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "asin_ requires coalesced input"):
2516*da0073e9SAndroid Build Coastguard Worker                    op(sparse_tensor)
2517*da0073e9SAndroid Build Coastguard Worker
2518*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
2519*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types())
2520*da0073e9SAndroid Build Coastguard Worker    def test_asin_arcsin(self, device, dtype, coalesced):
2521*da0073e9SAndroid Build Coastguard Worker        if coalesced:
2522*da0073e9SAndroid Build Coastguard Worker            input_coalesced = torch.sparse_coo_tensor(
2523*da0073e9SAndroid Build Coastguard Worker                indices=torch.tensor([[0, 1, 2, 3]]),
2524*da0073e9SAndroid Build Coastguard Worker                values=torch.tensor([0.5, -0.5, 0.7, -0.7]),
2525*da0073e9SAndroid Build Coastguard Worker                size=[4, ],
2526*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
2527*da0073e9SAndroid Build Coastguard Worker                device=device
2528*da0073e9SAndroid Build Coastguard Worker            ).coalesce()
2529*da0073e9SAndroid Build Coastguard Worker            self._test_asin_arcsin(input_coalesced, coalesced)
2530*da0073e9SAndroid Build Coastguard Worker
2531*da0073e9SAndroid Build Coastguard Worker            # hybrid sparse input
2532*da0073e9SAndroid Build Coastguard Worker            input_coalesced = torch.sparse_coo_tensor(
2533*da0073e9SAndroid Build Coastguard Worker                indices=torch.tensor([[1, 3], [2, 4]]),
2534*da0073e9SAndroid Build Coastguard Worker                values=torch.tensor([[-0.1, 0.24], [-0.44, 0.1]]),
2535*da0073e9SAndroid Build Coastguard Worker                size=[4, 5, 2],
2536*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
2537*da0073e9SAndroid Build Coastguard Worker                device=device
2538*da0073e9SAndroid Build Coastguard Worker            ).coalesce()
2539*da0073e9SAndroid Build Coastguard Worker            self._test_asin_arcsin(input_coalesced, coalesced)
2540*da0073e9SAndroid Build Coastguard Worker
2541*da0073e9SAndroid Build Coastguard Worker        if not coalesced:
2542*da0073e9SAndroid Build Coastguard Worker            # test uncoalesced input
2543*da0073e9SAndroid Build Coastguard Worker            input_uncoalesced = torch.sparse_coo_tensor(
2544*da0073e9SAndroid Build Coastguard Worker                indices=torch.tensor([[0], [1], [2], [0], [1], [2]]).transpose(1, 0),
2545*da0073e9SAndroid Build Coastguard Worker                values=torch.tensor([0.3, -0.3, -0.4, 0.3, -0.5, 0.15]),
2546*da0073e9SAndroid Build Coastguard Worker                size=[3, ],
2547*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
2548*da0073e9SAndroid Build Coastguard Worker                device=device
2549*da0073e9SAndroid Build Coastguard Worker            )
2550*da0073e9SAndroid Build Coastguard Worker            self._test_asin_arcsin(input_uncoalesced, coalesced)
2551*da0073e9SAndroid Build Coastguard Worker
2552*da0073e9SAndroid Build Coastguard Worker            # test on empty sparse tensor
2553*da0073e9SAndroid Build Coastguard Worker            input_uncoalesced = torch.sparse_coo_tensor(
2554*da0073e9SAndroid Build Coastguard Worker                indices=torch.zeros([2, 0]),
2555*da0073e9SAndroid Build Coastguard Worker                values=torch.zeros([0, 5, 5, 5, 5, 5, 5, 0]),
2556*da0073e9SAndroid Build Coastguard Worker                size=[0, 0, 5, 5, 5, 5, 5, 5, 0],
2557*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
2558*da0073e9SAndroid Build Coastguard Worker                device=device
2559*da0073e9SAndroid Build Coastguard Worker            )
2560*da0073e9SAndroid Build Coastguard Worker            # empty tensors are coalesced at creation (nnz < 2), so we must force the uncoalesced state
2561*da0073e9SAndroid Build Coastguard Worker            input_uncoalesced._coalesced_(False)
2562*da0073e9SAndroid Build Coastguard Worker            self._test_asin_arcsin(input_uncoalesced, coalesced)
2563*da0073e9SAndroid Build Coastguard Worker
2564*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
2565*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
2566*da0073e9SAndroid Build Coastguard Worker    def test_mv(self, device, dtype, coalesced):
2567*da0073e9SAndroid Build Coastguard Worker        def test_shape(di, dj, dk, nnz):
2568*da0073e9SAndroid Build Coastguard Worker            x, _, _ = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced)
2569*da0073e9SAndroid Build Coastguard Worker            t = torch.randn(dk, dtype=dtype, device=device)
2570*da0073e9SAndroid Build Coastguard Worker
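            # 2-D sparse @ 1-D dense is a matrix-vector product (mv)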
2571*da0073e9SAndroid Build Coastguard Worker            res = x.matmul(t)
2572*da0073e9SAndroid Build Coastguard Worker            expected = self.safeToDense(x).matmul(t)
2573*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, expected)
2574*da0073e9SAndroid Build Coastguard Worker
2575*da0073e9SAndroid Build Coastguard Worker        test_shape(10, 100, 100, 20)
2576*da0073e9SAndroid Build Coastguard Worker        test_shape(100, 1000, 1000, 20)
2577*da0073e9SAndroid Build Coastguard Worker        test_shape(64, 10000, 10000, 20)
2578*da0073e9SAndroid Build Coastguard Worker        test_shape(0, 100, 100, 0)
2579*da0073e9SAndroid Build Coastguard Worker        test_shape(10, 0, 0, 0)
2580*da0073e9SAndroid Build Coastguard Worker        test_shape(10, 100, 100, 0)
2581*da0073e9SAndroid Build Coastguard Worker        test_shape(10, 100, 100, 20)
2582*da0073e9SAndroid Build Coastguard Worker
2583*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"mv: expected self\.size\(-1\) == vec\.size\(-1\)"):
2584*da0073e9SAndroid Build Coastguard Worker            test_shape(10, 100, 10, 20)
2585*da0073e9SAndroid Build Coastguard Worker
        with self.assertRaisesRegex(RuntimeError, "mv: two tensor dim should be 2 and 1"):
            x, _, _ = self._gen_sparse(2, 20, [10, 100], dtype, device, coalesced)
            y, _, _ = self._gen_sparse(2, 20, [10, 100], dtype, device, coalesced)
            x.mv(y)

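    # Adding two sparse tensors concatenates their entries, so the result
    # may hold duplicate indices; unless it has collapsed to the two unique
    # entries, it must not be flagged as coalesced.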
    @dtypes(*floating_and_complex_types())
    def test_sparse_add_coalesce(self, device, dtype):
        i = self.index_tensor([[1, 2, 1]], device=device)
        v = torch.tensor([3, 4, 5], dtype=dtype, device=device)
        x = self.sparse_tensor(i, v, torch.Size([3]))
        y = self.sparse_tensor(i, v, torch.Size([3]))
        z = x + y

        self.assertFalse(z._indices().numel() != 2 and z.is_coalesced())

        i = self.index_tensor([[1, 2, 1]], device=device)
        v = torch.empty([3, 0], dtype=dtype, device=device)
        x = self.sparse_tensor(i, v, torch.Size([3, 0]))
        y = self.sparse_tensor(i, v, torch.Size([3, 0]))
        z = x + y

        self.assertFalse(z._indices().numel() != 2 and z.is_coalesced())

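    # Even an empty CUDA sparse tensor must carry real storage: the test
    # expects a genuine device ordinal rather than the CPU sentinel -1.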
    @onlyCUDA
    def test_storage_not_null(self, device):
        x = torch.sparse_coo_tensor((2,), dtype=torch.float32, device=device)
        self.assertNotEqual(x.get_device(), -1)

        x = torch.sparse_coo_tensor((2, 0), dtype=torch.float32, device=device)
        self.assertNotEqual(x.get_device(), -1)

    @onlyCUDA
    @deviceCountAtLeast(2)
    def test_same_gpu(self, devices):
        def check_device(x, device_id):
            self.assertEqual(x.get_device(), device_id)
            self.assertEqual(x._values().get_device(), device_id)
            self.assertEqual(x._indices().get_device(), device_id)

        dev1, dev2 = devices[0], devices[1]

        i = self.index_tensor([[2]], device=dev2)
        v = torch.tensor([5], device=dev2)
        x = self.sparse_tensor(i, v, torch.Size([3]), device=1)
        check_device(x, 1)

        i = self.index_tensor([[2]], device=dev2)
        v = torch.empty(1, 0, device=dev2)
        x = self.sparse_tensor(i, v, torch.Size([3, 0]), device=1)
        check_device(x, 1)

        x = self.sparse_empty(3, device=1)
        check_device(x, 1)

        x = self.sparse_empty(3, 0, device=1)
        check_device(x, 1)

    def _test_new_device(self, size, device):
        with torch.cuda.device(device):
            x = torch.sparse_coo_tensor(size, device='cuda', dtype=torch.float64)
        self.assertEqual(x.get_device(), device)
        x1 = x.new()
        x2 = x.new(2, 3)
        self.assertEqual(x1.get_device(), device)
        self.assertEqual(x2.get_device(), device)

    @onlyCUDA
    def test_new_device_single_gpu(self):
        self._test_new_device((), 0)
        self._test_new_device((30, 20), 0)
        self._test_new_device((30, 20, 10), 0)
        self._test_new_device((30, 20, 10, 0), 0)

    @onlyCUDA
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_new_device_multi_gpu(self):
        self._test_new_device((), 1)
        self._test_new_device((30, 20), 1)
        self._test_new_device((30, 20, 10), 1)
        self._test_new_device((30, 20, 10, 0), 1)

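    # Tensor.new(indices, values) is the legacy constructor path; it should
    # reproduce the indices/values of an existing sparse tensor. On CUDA the
    # size must be given when there are dense dimensions (nDimV > 0), hence
    # the CPU-only branch below.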
    @coalescedonoff
    @dtypes(torch.double, torch.cdouble)
    def test_new(self, device, dtype, coalesced):
        def test_shape(sparse_dims, nnz, with_size):
            x, indices, values = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)
            if not x.is_cuda:
                # CUDA sparse tensors currently require the size to be
                # specified if nDimV > 0
                out = x.new(indices, values).coalesce()
                x_c = x.coalesce()
                self.assertEqual((out.indices(), out.values()), (x_c.indices(), x_c.values()))
            self.assertEqual(x.new(indices, values, x.size()), x)

        test_shape(3, 10, 100)
        test_shape(3, 0, [100, 100, 0])

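    # test_factory sweeps every combination of {empty/non-empty values,
    # explicit/inferred size, tensor vs. list indices, tensor vs. scalar
    # values, cpu/cuda} through torch.sparse_coo_tensor and checks the
    # resulting metadata.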
    @onlyCPU  # not really CPU-only, but we only want to run this once
    @dtypes(torch.float64, torch.float32, torch.float16, torch.cfloat, torch.cdouble)
    def test_factory(self, device, dtype):
        for test_empty_tensor in [True, False]:
            if test_empty_tensor:
                default_size = torch.Size([1, 3, 0])
                size = torch.Size([3, 3, 0])
            else:
                default_size = torch.Size([1, 3])
                size = torch.Size([3, 3])
            for include_size in [True, False]:
                for use_tensor_idx in [True, False]:
                    for use_tensor_val in [True, False]:
                        for use_cuda in ([False] if not torch.cuda.is_available() else [True, False]):
                            # have to include size with cuda sparse tensors;
                            # use a local so the loop variable is not
                            # clobbered for later iterations
                            pass_size = include_size or use_cuda
                            long_dtype = torch.int64
                            device = torch.device('cpu') if not use_cuda else \
                                torch.device(torch.cuda.device_count() - 1)
                            indices = torch.tensor(([0], [2]), dtype=long_dtype) if use_tensor_idx else ([0], [2])
                            if test_empty_tensor:
                                values = torch.empty(1, 0).to(dtype)
                            else:
                                if use_tensor_val:
                                    values = torch.tensor([1.], dtype=dtype)
                                else:
                                    values = 1.
                            if pass_size:
                                sparse_tensor = torch.sparse_coo_tensor(indices, values, size, dtype=dtype,
                                                                        device=device, requires_grad=True)
                            else:
                                sparse_tensor = torch.sparse_coo_tensor(indices, values, dtype=dtype,
                                                                        device=device, requires_grad=True)
                            self.assertEqual(indices, sparse_tensor._indices())
                            self.assertEqual(values, sparse_tensor._values())
                            self.assertEqual(size if pass_size else default_size, sparse_tensor.size())
                            self.assertEqual(dtype, sparse_tensor.dtype)
                            if use_cuda:
                                self.assertEqual(device, sparse_tensor._values().device)
                            self.assertEqual(True, sparse_tensor.requires_grad)

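    # The factory validates three invariants, matched below by their error
    # messages: sizes must cover the given indices, indices must be
    # non-negative, and the trailing values shape must match the dense
    # sizes.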
    @dtypes(torch.double, torch.cdouble)
    def test_factory_size_check(self, device, dtype):
        indices = self.index_tensor([[1, 2],
                                    [0, 2]], device=device)
        values = torch.tensor([.5, .5], dtype=dtype, device=device)
        sizes = torch.Size([2, 3])
        with self.assertRaisesRegex(RuntimeError, "size is inconsistent with indices"):
            torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device)

        indices.fill_(-1)
        with self.assertRaisesRegex(RuntimeError, "found negative index"):
            torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device)

        indices = self.index_tensor([[1, 2],
                                    [0, 2]], device=device)
        values = torch.empty([2, 1, 0], dtype=dtype, device=device)
        sizes = torch.Size([2, 3, 1, 0])
        with self.assertRaisesRegex(RuntimeError, "size is inconsistent with indices"):
            torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device)

        indices = self.index_tensor([[1, 2],
                                    [0, 2]], device=device)
        values = torch.empty([2, 2, 2], dtype=dtype, device=device)
        sizes = torch.Size([0, 0, 2, 2])
        with self.assertRaisesRegex(RuntimeError, "size is inconsistent with indices"):
            torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device)

        indices = self.index_tensor([[1, 2],
                                    [0, 2]], device=device)
        values = torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=dtype, device=device)
        sizes = torch.Size([3, 3, 2])
        with self.assertRaisesRegex(RuntimeError, "values has incorrect size"):
            torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device)

        indices = self.index_tensor([[1, 2],
                                    [0, 2]], device=device)
        values = torch.empty([2, 1, 0], dtype=dtype, device=device)
        sizes = torch.Size([3, 3, 2, 0])
        with self.assertRaisesRegex(RuntimeError, "values has incorrect size"):
            torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device)

    def test_factory_empty_indices(self, device):
        tensor = torch.sparse_coo_tensor(torch.Size([2, 0]), device=device)
        expected_indices = torch.empty((2, 0), dtype=torch.long, device=device)
        self.assertEqual(tensor._indices(), expected_indices)

        tensor = torch.sparse_coo_tensor(torch.Size([2, 2, 0]), device=device)
        expected_indices = torch.empty((3, 0), dtype=torch.long, device=device)
        self.assertEqual(tensor._indices(), expected_indices)

        tensor = torch.sparse_coo_tensor(torch.Size([2, 2, 0, 0]), device=device)
        expected_indices = torch.empty((4, 0), dtype=torch.long, device=device)
        self.assertEqual(tensor._indices(), expected_indices)

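    # nnz is dimension 1 of indices and dimension 0 of values; the two must
    # agree even when the values tensor is otherwise empty.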
    @dtypes(torch.double, torch.cdouble)
    def test_factory_nnz(self, device, dtype):
        indices = self.index_tensor([[0]], device=device)  # (sparse_dim, nnz): (1, 1)
        values = torch.tensor([[1, 1], [1, 1]], dtype=dtype, device=device)  # (nnz, ...): (2, 2)
        sizes = torch.Size([2, 2])
        with self.assertRaisesRegex(RuntimeError, "indices and values must have same nnz"):
            torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device)

        indices = self.index_tensor([[0]], device=device)  # (sparse_dim, nnz): (1, 1)
        values = torch.empty([2, 0], dtype=dtype, device=device)  # (nnz, ...): (2, 0)
        sizes = torch.Size([2, 0])
        with self.assertRaisesRegex(RuntimeError, "indices and values must have same nnz"):
            torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device)

    @dtypes(torch.double, torch.cdouble)
    def test_factory_nnz_zero(self, device, dtype):
        def test_shape(i_shape, v_shape, size, expected_size):
            if size:
                t = torch.sparse_coo_tensor(torch.empty(i_shape), torch.empty(v_shape), torch.Size(size),
                                            dtype=dtype, device=device)
            else:
                t = torch.sparse_coo_tensor(torch.empty(i_shape), torch.empty(v_shape), dtype=dtype, device=device)
            expected_indices = torch.empty(i_shape, device=device, dtype=torch.int64)
            expected_values = torch.empty(v_shape, device=device, dtype=dtype)
            expected_size = torch.Size(expected_size)
            self.assertEqual(t._indices(), expected_indices)
            self.assertEqual(t._values(), expected_values)
            self.assertEqual(t.size(), expected_size)

        test_shape([1, 0], [0, 2, 4, 0], None, [0, 2, 4, 0])
        test_shape([3, 0], [0, 2, 4, 0], None, [0, 0, 0, 2, 4, 0])
        test_shape([1, 0], [0, 2, 4, 0], [0, 2, 4, 0], [0, 2, 4, 0])
        test_shape([3, 0], [0, 2, 4, 0], [0, 0, 0, 2, 4, 0], [0, 0, 0, 2, 4, 0])
        test_shape([3, 0], [0, 2, 4, 0], [1, 2, 3, 2, 4, 0], [1, 2, 3, 2, 4, 0])

    @dtypes(torch.double, torch.cdouble)
    def test_factory_dense_dim(self, device, dtype):
        indices = self.index_tensor([[0]], device=device)
        values = torch.tensor([[[1, 1, 1], [1, 1, 1]]], dtype=dtype, device=device)
        sizes = torch.Size([1, 3, 4])
        with self.assertRaisesRegex(RuntimeError, "values has incorrect size"):
            torch.sparse_coo_tensor(indices, values, sizes)

        indices = self.index_tensor([[0]], device=device)
        values = torch.empty([1, 2, 3, 0], dtype=dtype, device=device)
        sizes = torch.Size([1, 3, 4, 0])
        with self.assertRaisesRegex(RuntimeError, "values has incorrect size"):
            torch.sparse_coo_tensor(indices, values, sizes)

    @onlyCPU
    @dtypes(torch.float16, torch.float32, torch.float64, torch.cfloat, torch.cdouble, torch.int64)
    def test_factory_type_inference(self, device, dtype):
        t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.tensor([1.], dtype=dtype))
        self.assertEqual(dtype, t.dtype)
        t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.tensor([1]))
        self.assertEqual(torch.int64, t.dtype)

        t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.HalfTensor(1, 0))
        self.assertEqual(torch.float16, t.dtype)
        t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.FloatTensor(1, 0))
        self.assertEqual(torch.float32, t.dtype)
        t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.DoubleTensor(1, 0))
        self.assertEqual(torch.float64, t.dtype)
        t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.LongTensor(1, 0))
        self.assertEqual(torch.int64, t.dtype)

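    # Device inference mirrors dtype inference: with device=None the result
    # follows the input tensors, and mixing cpu indices with cuda values (or
    # vice versa) without an explicit device is an error.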
    @onlyCUDA
    def test_factory_device_type_inference(self, device):
        # indices/values on every combination of cpu/cuda, with and without
        # an explicit device argument
        cpu_cuda = ('cpu', 'cuda')
        cpu_cuda_none = cpu_cuda + (None,)
        for indices_device, values_device, device in itertools.product(cpu_cuda,
                                                                       cpu_cuda,
                                                                       cpu_cuda_none):
            indices = torch.tensor(([0], [2]), device=indices_device)
            values = torch.tensor([1.], device=values_device)
            empty_values = torch.empty(1, 0).to(values_device)
            shape = (1, 3)
            empty_shape = (1, 3, 0)
            if device is None and indices_device != values_device:
                with self.assertRaises(RuntimeError):
                    torch.sparse_coo_tensor(indices, values, shape, device=device)
                with self.assertRaises(RuntimeError):
                    torch.sparse_coo_tensor(indices, empty_values, empty_shape, device=device)
            else:
                t = torch.sparse_coo_tensor(indices, values, shape, device=device)
                t_empty = torch.sparse_coo_tensor(indices, empty_values, empty_shape, device=device)
                should_be_cuda = (device == 'cuda' or (device is None and values_device == 'cuda'))
                self.assertEqual(should_be_cuda, t.is_cuda)
                self.assertEqual(t.is_cuda, t_empty.is_cuda)

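    # When indices/values already have the exact dtypes the factory needs
    # (int64 indices, values matching the requested dtype), construction can
    # alias their storage; any conversion forces a copy. Illustrative sketch
    # of the first case below:
    #
    #   v = torch.tensor([1.], dtype=torch.float64)
    #   t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), v, dtype=torch.float64)
    #   assert v.data_ptr() == t._values().data_ptr()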
    @onlyCPU
    def test_factory_copy(self, device):
        def test_tensor(indices, values, indices_equal, values_equal):
            sparse_tensor = torch.sparse_coo_tensor(indices, values, dtype=torch.float64, device=device)
            if indices_equal:
                self.assertEqual(indices.data_ptr(), sparse_tensor._indices().data_ptr())
            else:
                self.assertNotEqual(indices.data_ptr(), sparse_tensor._indices().data_ptr())
            if values_equal:
                self.assertEqual(values.data_ptr(), sparse_tensor._values().data_ptr())
            else:
                self.assertNotEqual(values.data_ptr(), sparse_tensor._values().data_ptr())

        # both correct
        indices = torch.tensor(([0], [2]), dtype=torch.int64)
        values = torch.tensor([1.], dtype=torch.float64)
        test_tensor(indices, values, True, True)

        indices = torch.tensor(([0], [2]), dtype=torch.int64)
        values = torch.DoubleTensor(1, 0)
        test_tensor(indices, values, True, True)

        # only indices correct
        indices = torch.tensor(([0], [2]), dtype=torch.int64)
        values = torch.tensor([1.], dtype=torch.float32)
        test_tensor(indices, values, True, False)

        indices = torch.tensor(([0], [2]), dtype=torch.int64)
        values = torch.tensor([1.], dtype=torch.float16)
        test_tensor(indices, values, True, False)

        indices = torch.tensor(([0], [2]), dtype=torch.int64)
        values = torch.FloatTensor(1, 0)
        test_tensor(indices, values, True, True)  # An empty tensor's data_ptr is always equal to 0

        # only values correct
        indices = torch.tensor(([0], [2]), dtype=torch.int32)
        values = torch.tensor([1.], dtype=torch.float64)
        test_tensor(indices, values, False, True)

        indices = torch.tensor(([0], [2]), dtype=torch.int32)
        values = torch.DoubleTensor(1, 0)
        test_tensor(indices, values, False, True)

        # neither correct
        indices = torch.tensor(([0], [2]), dtype=torch.int32)
        values = torch.tensor([1.], dtype=torch.float32)
        test_tensor(indices, values, False, False)

        indices = torch.tensor(([0], [2]), dtype=torch.int32)
        values = torch.FloatTensor(1, 0)
        test_tensor(indices, values, False, True)  # An empty tensor's data_ptr is always equal to 0

        # complex support
        indices = torch.tensor(([0], [2]), dtype=torch.int64)
        values = make_tensor([1, ], dtype=torch.cdouble, device=device)
        test_tensor(indices, values, True, False)

        indices = torch.tensor(([0], [2]), dtype=torch.int32)
        values = make_tensor([1, 1], dtype=torch.cdouble, device=device)
        test_tensor(indices, values, False, False)

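    # Legacy Tensor.new(...) must not silently migrate a sparse tensor to a
    # different device; every cross-device variant below is expected to
    # raise.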
    @onlyCPU  # runs once; the body exercises both cpu and cuda
    def test_legacy_new_device(self, device):
        i = torch.tensor([[0, 1, 1], [2, 0, 2]])
        v = torch.tensor([3., 4., 5.])
        size = torch.Size([2, 3])

        x = torch.sparse_coo_tensor(i, v, size, device='cpu')
        self.assertRaises(RuntimeError, lambda: x.new(device='cuda'))
        self.assertRaises(RuntimeError, lambda: x.new(i, v, device='cuda'))
        self.assertRaises(RuntimeError, lambda: x.new(i, v, size, device='cuda'))
        self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cuda'))

        if torch.cuda.is_available():
            x = torch.sparse_coo_tensor(i, v, size, device='cuda')
            self.assertRaises(RuntimeError, lambda: x.new(device='cpu'))
            self.assertRaises(RuntimeError, lambda: x.new(i, v, device='cpu'))
            self.assertRaises(RuntimeError, lambda: x.new(i, v, size, device='cpu'))
            self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cpu'))

    def test_legacy_new(self, device):
        i = torch.tensor([[0, 1, 1], [2, 0, 2]])
        v = torch.tensor([3., 4., 5.])
        size = torch.Size([2, 3])
        s = torch.sparse_coo_tensor(i, v, size)

        self.assertEqual(torch.sparse_coo, s.new(device='cpu').layout)
        self.assertRaises(TypeError, lambda: s.new(v.untyped_storage()))
        self.assertRaises(TypeError, lambda: s.new(v))
        self.assertEqual(torch.sparse_coo, s.new(torch.Size([2, 3])).layout)
        self.assertRaises(TypeError, lambda: s.new([6]))

    @onlyCPU  # not really CPU-only, but we only want to run this once
    def test_dtypes(self, device):
        all_sparse_dtypes = all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)
        do_test_dtypes(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cpu'))
        if torch.cuda.is_available():
            do_test_dtypes(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cuda:0'))

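    # Helper for test_empty_full below: each constructor variant
    # (sparse_coo_tensor, new()/zeros with out=, new_empty, empty_like)
    # should preserve shape, dtype, layout, and requires_grad.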
    def _test_empty_full(self, device, dtype, requires_grad):
        shape = (2, 3)
        layout = torch.sparse_coo

        def check_value(tensor, value=None, dtype=dtype, requires_grad=requires_grad):
            self.assertEqual(shape, tensor.shape)
            self.assertIs(dtype, tensor.dtype)
            self.assertIs(layout, tensor.layout)
            self.assertEqual(tensor.requires_grad, requires_grad)
            if tensor.is_cuda and device is not None:
                self.assertEqual(device, tensor.device)
            if value is not None:
                # Tensor has no .empty() method; build the reference fill
                # tensor via new_empty instead.
                fill = tensor.new_empty(shape, dtype=dtype).fill_(value)
                self.assertEqual(tensor, fill)

        v = torch.sparse_coo_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad)
        check_value(v)

        out = v.new()
        check_value(torch.zeros(shape, out=out, device=device, requires_grad=requires_grad))

        int64_dtype = torch.int64
        check_value(v.new_empty(shape), requires_grad=False)
        check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=False),
                    dtype=int64_dtype, requires_grad=False)
        check_value(torch.empty_like(v), requires_grad=False)
        check_value(torch.empty_like(v, dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
                    dtype=int64_dtype, requires_grad=False)

    @onlyCPU  # not really CPU-only, but we only want to run this once
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    @parametrize('requires_grad', (True, False))
    def test_empty_full(self, device, dtype, requires_grad):
        if requires_grad and not (dtype.is_floating_point or dtype.is_complex):
            self.skipTest(f'requires_grad==True requires a float or complex dtype, got {dtype}')

        self._test_empty_full(device, dtype, requires_grad)
        if torch.cuda.is_available():
            self._test_empty_full(None, dtype, requires_grad)
            self._test_empty_full(torch.device('cuda:0'), dtype, requires_grad)

    def test_is_sparse(self, device):
        x = torch.randn(3, 3)
        self.assertFalse(x.is_sparse)

        x = torch.randn(3, 3, 0)
        self.assertFalse(x.is_sparse)

        x = self.sparse_empty(1, 0, device=device)
        self.assertTrue(x.is_sparse)

    def test_resize_as(self, device):
        def do_test(t):
            y = t.new().resize_as_(t).zero_()
            self.assertEqual(y.shape, t.shape)
            # Check that y can be added to t. Currently, this requires that
            # sparse_dim and dense_dim match.
            self.assertEqual(t, t + y)

        do_test(self.sparse_empty([3, 0], device=device))
        do_test(self.sparse_empty([3, 3], device=device))

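    # Helper for test_resize below: resizes sparse x to match y (and a dense
    # mirror of x for comparison), then checks that the shape metadata
    # changed while the original nonzero data survived the resize.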
    def _test_resize_shape(self, x_i, x_v, x_size, y_i, y_v, y_size, dtype, device):
        x_v_numel = torch.zeros(x_v).numel()
        x = torch.sparse_coo_tensor(torch.zeros(x_i),
                                    torch.arange(x_v_numel).resize_(x_v).to(torch.float),
                                    torch.Size(x_size), dtype=dtype, device=device)
        x_dense = x.to_dense()
        y = torch.sparse_coo_tensor(torch.zeros(y_i),
                                    torch.ones(y_v).to(torch.float),
                                    torch.Size(y_size), dtype=dtype, device=device)
        y_dense = y.to_dense()
        x.resize_as_(y)
        x_dense.resize_as_(y_dense)
        self.assertEqual(x.shape, y.shape)
        self.assertEqual(x.sparse_dim(), y.sparse_dim())
        self.assertEqual(x.dense_dim(), y.dense_dim())
        self.assertEqual(x.shape, x_dense.shape)
        self.assertEqual(y.shape, y_dense.shape)
        # Here we make sure that the original data are preserved after resizing
        self.assertEqual(x.to_dense().view(-1)[0:x_v_numel].view(x_v),
                         x_dense.view(-1)[0:x_v_numel].view(x_v))

    @dtypes(torch.double, torch.cdouble)
    def test_resize(self, device, dtype):
        # 1. Expand the size of some dense dimensions [Supported]
        self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                                [1, 1], [1, 2, 4], [2, 2, 4],
                                dtype=dtype, device=device)

        self._test_resize_shape([1, 1], [1, 2, 0], [2, 2, 0],
                                [1, 1], [1, 2, 4], [2, 2, 4],
                                dtype=dtype, device=device)

        # 2. Expand the size of some sparse dimensions [Supported]
        self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                                [1, 1], [1, 2, 3], [4, 2, 3],
                                dtype=dtype, device=device)

        # 3. Change the shapes of both sparse and dense dimensions when nnz is zero [Supported]
        self._test_resize_shape([1, 0], [0, 2, 3], [2, 2, 3],
                                [2, 0], [0, 2, 4, 5], [1, 1, 2, 4, 5],
                                dtype=dtype, device=device)

        self._test_resize_shape([1, 0], [0, 2, 3], [2, 2, 3],
                                [2, 0], [0, 2, 4, 0], [1, 1, 2, 4, 0],
                                dtype=dtype, device=device)

        # 4. Add dims to dense dimensions [Not Supported]
        with self.assertRaisesRegex(RuntimeError, "changing the number of dense dimensions"):
            self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                                    [1, 1], [1, 2, 3, 4], [2, 2, 3, 4],
                                    dtype=dtype, device=device)

        with self.assertRaisesRegex(RuntimeError, "changing the number of dense dimensions"):
            self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                                    [1, 1], [1, 2, 3, 0], [2, 2, 3, 0],
                                    dtype=dtype, device=device)

        # 5. Remove dims from dense dimensions [Not Supported]
        with self.assertRaisesRegex(RuntimeError, "changing the number of dense dimensions"):
            self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                                    [1, 1], [1, 2], [2, 2],
                                    dtype=dtype, device=device)

        # 6. Change the number of sparse dimensions on a non-empty sparse tensor [Not Supported]
        with self.assertRaisesRegex(RuntimeError, "changing the number of sparse dimensions"):
            self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                                    [2, 1], [1, 2, 3], [1, 2, 2, 3],
                                    dtype=dtype, device=device)

        # 7. Shrink the size of some sparse dimensions on a non-empty sparse tensor [Not Supported]
        with self.assertRaisesRegex(RuntimeError, "shrinking the size of sparse dimensions"):
            self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                                    [1, 1], [1, 2, 3], [1, 2, 3],
                                    dtype=dtype, device=device)

        # 8. Shrink the size of some dense dimensions on a non-empty sparse tensor [Not Supported]
        with self.assertRaisesRegex(RuntimeError, "shrinking the size of dense dimensions"):
            self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                                    [1, 1], [1, 2, 2], [2, 2, 2],
                                    dtype=dtype, device=device)

        with self.assertRaisesRegex(RuntimeError, "shrinking the size of dense dimensions"):
            self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                                    [1, 1], [1, 2, 0], [2, 2, 0],
                                    dtype=dtype, device=device)

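    # is_nonzero follows single-element dense semantics after coalescing:
    # duplicate entries are summed first, so e.g. (illustrative)
    #
    #   torch.sparse_coo_tensor(([0, 0],), (-1., 1.), (1,)).is_nonzero()  # False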
    def test_is_nonzero(self, device):
        self.assertTrue(torch.sparse_coo_tensor(([0],), 1., (1,), device=device).is_nonzero())
        self.assertFalse(torch.sparse_coo_tensor(([0],), 0., (1,), device=device).is_nonzero())
        self.assertFalse(torch.sparse_coo_tensor(([0], [0]), 0., (1, 1), device=device).is_nonzero())
        self.assertFalse(torch.sparse_coo_tensor(([0, 0],), (0., 0.), (1,), device=device).is_nonzero())
        self.assertFalse(torch.sparse_coo_tensor(([0, 0],), (-1., 1.), (1,), device=device).is_nonzero())

        # scalar sparse tensor
        self.assertTrue(torch.sparse_coo_tensor(torch.zeros(0, 1), 12.3, [], device=device).is_nonzero())
        with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"):
            torch.sparse_coo_tensor(([0, 1],), torch.empty(2, 0), (4, 0), device=device).is_nonzero()
        self.assertTrue(torch.sparse_coo_tensor(([0],), 2.3 - 4.5j, (1,), dtype=torch.cfloat, device=device)
                        .is_nonzero())
        self.assertTrue(torch.sparse_coo_tensor(([0],), 2.3 - 4.5j, (1,), dtype=torch.cdouble, device=device)
                        .is_nonzero())
        self.assertFalse(torch.sparse_coo_tensor(([0],), 0. + 0j, (1,), dtype=torch.cfloat, device=device)
                         .is_nonzero())
        self.assertFalse(torch.sparse_coo_tensor(([0],), 0. + 0j, (1,), dtype=torch.cdouble, device=device)
                         .is_nonzero())

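    # The constructed sparse tensor keeps its own view of indices/values, so
    # mutating the source tensors' metadata afterwards (resize_, set_,
    # as_strided_, transpose_) must not leak into the already-built tensor.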
    @dtypes(torch.double, torch.cdouble)
    def test_change_tensor_metadata(self, device, dtype):
        i = self.index_tensor([[0], [1]], device=device)
        v = torch.tensor([[3, 4, 5]], dtype=dtype, device=device)
        t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]), dtype=dtype, device=device)
        i.resize_(2, 3)
        v.resize_(4, 5)
        self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
        self.assertEqual(list(t.coalesce().values().size()), [1, 3])

        i = self.index_tensor([[0], [1]], device=device)
        v = torch.tensor([[3, 4, 5]], dtype=dtype, device=device)
        t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
        i.resize_as_(self.index_tensor([0, 1], device=device))
        v.resize_as_(torch.tensor([3, 4, 5], dtype=dtype, device=device))
        self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
        self.assertEqual(list(t.coalesce().values().size()), [1, 3])

        i = self.index_tensor([[0], [1]], device=device)
        v = torch.tensor([[3, 4, 5]], dtype=dtype, device=device)
        t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
        i.as_strided_((2, 1), (1, 1))
        v.as_strided_((1, 3), (1, 1))
        self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
        self.assertEqual(list(t.coalesce().values().size()), [1, 3])

        i = self.index_tensor([[0], [1]], device=device)
        v = torch.tensor([[3, 4, 5]], dtype=dtype, device=device)
        t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
        i.set_(self.index_tensor([0, 1], device=device))
        v.set_(torch.tensor([3, 4, 5], dtype=dtype, device=device))
        self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
        self.assertEqual(list(t.coalesce().values().size()), [1, 3])

        i = self.index_tensor([[0], [1]], device=device)
        v = torch.tensor([[3, 4, 5]], dtype=dtype, device=device)
        t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
        i.transpose_(0, 1)
        v.transpose_(0, 1)
        self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
        self.assertEqual(list(t.coalesce().values().size()), [1, 3])

3180*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
3181*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
3182*da0073e9SAndroid Build Coastguard Worker    def test_pickle(self, device, dtype, coalesced):
3183*da0073e9SAndroid Build Coastguard Worker        import pickle
3184*da0073e9SAndroid Build Coastguard Worker
3185*da0073e9SAndroid Build Coastguard Worker        shape_sparse_dim_nnz = [
3186*da0073e9SAndroid Build Coastguard Worker            ((), 0, 2),
3187*da0073e9SAndroid Build Coastguard Worker            ((0,), 0, 10),
3188*da0073e9SAndroid Build Coastguard Worker            ((2,), 0, 3),
3189*da0073e9SAndroid Build Coastguard Worker            ((100, 3), 1, 3),
3190*da0073e9SAndroid Build Coastguard Worker            ((100, 20, 3), 2, 0),
3191*da0073e9SAndroid Build Coastguard Worker            ((10, 0, 3), 0, 3),
3192*da0073e9SAndroid Build Coastguard Worker            ((10, 0, 3), 0, 0),
3193*da0073e9SAndroid Build Coastguard Worker        ]
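        # Each entry above is (shape, sparse_dim, nnz); e.g. ((100, 3), 1, 3)
        # yields indices of shape (1, 3) and values of shape (3, 3), i.e. a
        # hybrid tensor with one sparse and one dense dimension.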
3194*da0073e9SAndroid Build Coastguard Worker
3195*da0073e9SAndroid Build Coastguard Worker        for shape, sparse_dim, nnz in shape_sparse_dim_nnz:
3196*da0073e9SAndroid Build Coastguard Worker            indices_shape = torch.Size((sparse_dim, nnz))
3197*da0073e9SAndroid Build Coastguard Worker            values_shape = torch.Size((nnz,) + shape[sparse_dim:])
3198*da0073e9SAndroid Build Coastguard Worker            indices = torch.arange(indices_shape.numel(), dtype=self.index_tensor(0).dtype,
3199*da0073e9SAndroid Build Coastguard Worker                                   device=device).view(indices_shape)
3200*da0073e9SAndroid Build Coastguard Worker            for d in range(sparse_dim):
3201*da0073e9SAndroid Build Coastguard Worker                indices[d].clamp_(max=(shape[d] - 1))  # make it a valid index
3202*da0073e9SAndroid Build Coastguard Worker            if not coalesced and indices.numel() > 0:
3203*da0073e9SAndroid Build Coastguard Worker                indices[:, -1] = indices[:, 0]  # make it uncoalesced
3204*da0073e9SAndroid Build Coastguard Worker            values_numel = values_shape.numel()
3205*da0073e9SAndroid Build Coastguard Worker            values = torch.arange(values_numel, dtype=dtype,
3206*da0073e9SAndroid Build Coastguard Worker                                  device=device).view(values_shape).div_(values_numel / 2.)
3207*da0073e9SAndroid Build Coastguard Worker            sp_tensor = self.sparse_tensor(indices, values, shape)
3208*da0073e9SAndroid Build Coastguard Worker            serialized = pickle.dumps(sp_tensor)
3209*da0073e9SAndroid Build Coastguard Worker            sp_tensor_loaded = pickle.loads(serialized)
3210*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(sp_tensor, sp_tensor_loaded)
3211*da0073e9SAndroid Build Coastguard Worker
3212*da0073e9SAndroid Build Coastguard Worker    def test_any(self, device):
3213*da0073e9SAndroid Build Coastguard Worker        t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [2, 0])), torch.tensor([False, False]), device=device)
3214*da0073e9SAndroid Build Coastguard Worker        t_any = torch.tensor(False)
3215*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.any(t), t_any)
3216*da0073e9SAndroid Build Coastguard Worker        t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [2, 0])), torch.tensor([True, False]), device=device)
3217*da0073e9SAndroid Build Coastguard Worker        t_any = torch.tensor(True)
3218*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.any(t), t_any)
3219*da0073e9SAndroid Build Coastguard Worker
3220*da0073e9SAndroid Build Coastguard Worker    def test_isnan(self, device):
3221*da0073e9SAndroid Build Coastguard Worker        t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [0, 2])), torch.tensor([1, 4]), device=device)
3222*da0073e9SAndroid Build Coastguard Worker        t_nan = torch.sparse_coo_tensor(torch.tensor(([0, 0], [0, 2])), torch.tensor([False, False]), device=device)
3223*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.isnan(t).int(), t_nan.int())
3224*da0073e9SAndroid Build Coastguard Worker        t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [0, 2])), torch.tensor([1, float("nan")]), device=device)
3225*da0073e9SAndroid Build Coastguard Worker        t_nan = torch.sparse_coo_tensor(torch.tensor(([0, 0], [0, 2])), torch.tensor([False, True]), device=device)
3226*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.isnan(t).int(), t_nan.int())
3227*da0073e9SAndroid Build Coastguard Worker
3228*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
3229*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.float64)
3230*da0073e9SAndroid Build Coastguard Worker    def test_div_rounding_mode(self, device, dtype, coalesced):
3231*da0073e9SAndroid Build Coastguard Worker        sparse, _, _ = self._gen_sparse(2, 10, (10, 10), dtype,
3232*da0073e9SAndroid Build Coastguard Worker                                        device, coalesced)
3233*da0073e9SAndroid Build Coastguard Worker        dense = self.safeToDense(sparse)
3234*da0073e9SAndroid Build Coastguard Worker
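        # Rounding-mode semantics under test, e.g. for 5 / -2 = -2.5:
        # rounding_mode=None keeps -2.5, 'floor' rounds to -3, 'trunc' to -2.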
3235*da0073e9SAndroid Build Coastguard Worker        for mode in (None, 'floor', 'trunc'):
3236*da0073e9SAndroid Build Coastguard Worker            actual = sparse.div(-2, rounding_mode=mode)
3237*da0073e9SAndroid Build Coastguard Worker            expect = dense.div(-2, rounding_mode=mode)
3238*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(self.safeToDense(actual), expect)
3239*da0073e9SAndroid Build Coastguard Worker
3240*da0073e9SAndroid Build Coastguard Worker            # Test inplace
3241*da0073e9SAndroid Build Coastguard Worker            actual = sparse.clone().div_(-2, rounding_mode=mode)
3242*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(self.safeToDense(actual), expect)
3243*da0073e9SAndroid Build Coastguard Worker
3244*da0073e9SAndroid Build Coastguard Worker            # Test out argument
3245*da0073e9SAndroid Build Coastguard Worker            actual.zero_()
3246*da0073e9SAndroid Build Coastguard Worker            torch.div(sparse, -2, rounding_mode=mode, out=actual)
3247*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(self.safeToDense(actual), expect)
3248*da0073e9SAndroid Build Coastguard Worker
3249*da0073e9SAndroid Build Coastguard Worker    def test_div_by_sparse_error(self, device):
3250*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError, 'Sparse division requires',
3251*da0073e9SAndroid Build Coastguard Worker                               lambda: torch.tensor(1., device=device).to_sparse()
3252*da0073e9SAndroid Build Coastguard Worker                               / torch.tensor(1., device=device).to_sparse())
3253*da0073e9SAndroid Build Coastguard Worker
3254*da0073e9SAndroid Build Coastguard Worker    def test_floor_divide_by_sparse_error(self, device):
3255*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError, 'Sparse floor division requires',
3256*da0073e9SAndroid Build Coastguard Worker                               lambda: torch.tensor(1., device=device).to_sparse()
3257*da0073e9SAndroid Build Coastguard Worker                               // torch.tensor(1., device=device).to_sparse())
3258*da0073e9SAndroid Build Coastguard Worker
3259*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
3260*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
3261*da0073e9SAndroid Build Coastguard Worker    def test_sparse_to_numpy(self, device):
3262*da0073e9SAndroid Build Coastguard Worker        t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [2, 0])), torch.tensor([1, 4]))
3263*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, lambda: t.numpy())
3264*da0073e9SAndroid Build Coastguard Worker
3265*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
3266*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
3267*da0073e9SAndroid Build Coastguard Worker    def test_softmax(self, device, dtype, coalesced):
3268*da0073e9SAndroid Build Coastguard Worker        import torch.nn.functional as F
3269*da0073e9SAndroid Build Coastguard Worker
3270*da0073e9SAndroid Build Coastguard Worker        def to_dense(sparse, fill_value=None):
3271*da0073e9SAndroid Build Coastguard Worker            """
3272*da0073e9SAndroid Build Coastguard Worker            Return a dense tensor from a sparse tensor using the given fill value.
3273*da0073e9SAndroid Build Coastguard Worker            """
3274*da0073e9SAndroid Build Coastguard Worker            if fill_value is None or fill_value == 0:
3275*da0073e9SAndroid Build Coastguard Worker                return sparse.to_dense()
3276*da0073e9SAndroid Build Coastguard Worker            sparse = sparse.coalesce()
3277*da0073e9SAndroid Build Coastguard Worker            dense = torch.full(sparse.shape, fill_value, dtype=sparse.dtype, device=sparse.device)
3278*da0073e9SAndroid Build Coastguard Worker            for idx, value in zip(sparse._indices().t(), sparse._values()):
3279*da0073e9SAndroid Build Coastguard Worker                dense[tuple(idx)] = value
3280*da0073e9SAndroid Build Coastguard Worker            return dense
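        # Illustration of the helper above (hypothetical values, not executed):
        #   >>> s = torch.sparse_coo_tensor([[0], [1]], [2.], (2, 2))
        #   >>> to_dense(s, fill_value=-float('inf'))
        #   tensor([[-inf, 2.], [-inf, -inf]])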
3281*da0073e9SAndroid Build Coastguard Worker
3282*da0073e9SAndroid Build Coastguard Worker        def softmax_to_dense(sparse, dim):
3283*da0073e9SAndroid Build Coastguard Worker            """Dense softmax of a sparse tensor. Useful only for testing softmax
3284*da0073e9SAndroid Build Coastguard Worker            correctness.
3285*da0073e9SAndroid Build Coastguard Worker
3286*da0073e9SAndroid Build Coastguard Worker            When computing softmax of a sparse tensor, the value of
3287*da0073e9SAndroid Build Coastguard Worker            unspecified items is negative infinity rather than zero so
3288*da0073e9SAndroid Build Coastguard Worker            that
3289*da0073e9SAndroid Build Coastguard Worker
3290*da0073e9SAndroid Build Coastguard Worker              softmax(sparse.to_dense(fill_value=-inf), dim) == softmax(sparse, dim).to_dense()
3291*da0073e9SAndroid Build Coastguard Worker
3292*da0073e9SAndroid Build Coastguard Worker            holds for non-empty lines. On empty lines, the softmax
3293*da0073e9SAndroid Build Coastguard Worker            values are defined as 0 in order to preserve the sparsity
3294*da0073e9SAndroid Build Coastguard Worker            of the result.
3295*da0073e9SAndroid Build Coastguard Worker
3296*da0073e9SAndroid Build Coastguard Worker            Note that in PyTorch, the ``to_dense`` method does not
3297*da0073e9SAndroid Build Coastguard Worker            implement the ``fill_value`` keyword argument.
3298*da0073e9SAndroid Build Coastguard Worker            """
3299*da0073e9SAndroid Build Coastguard Worker            dtype = sparse.dtype
3300*da0073e9SAndroid Build Coastguard Worker            device = sparse.device
3301*da0073e9SAndroid Build Coastguard Worker            dense = to_dense(sparse, fill_value=-float('inf'))
3302*da0073e9SAndroid Build Coastguard Worker            r = F.softmax(dense, dim)
3303*da0073e9SAndroid Build Coastguard Worker            # softmax on empty lines results in nan; replace with zeros to match the definition
3304*da0073e9SAndroid Build Coastguard Worker            r[r != r] = 0
3305*da0073e9SAndroid Build Coastguard Worker            return r
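        # For example (illustrative): a length-3 line with specified values
        # [1., 2.] at positions 0 and 2 densifies to [1., -inf, 2.], whose
        # softmax is [e**1, 0., e**2] / (e**1 + e**2) ~= [0.269, 0., 0.731].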
3306*da0073e9SAndroid Build Coastguard Worker
3307*da0073e9SAndroid Build Coastguard Worker        def sparse_softmax(sparse, dim):
3308*da0073e9SAndroid Build Coastguard Worker            """Pure Python softmax of a sparse tensor, assuming -inf for
3309*da0073e9SAndroid Build Coastguard Worker            unspecified sparse tensor data. This is a prototype of the
3310*da0073e9SAndroid Build Coastguard Worker            sparse softmax algorithm in Python.
3311*da0073e9SAndroid Build Coastguard Worker            """
3312*da0073e9SAndroid Build Coastguard Worker            dtype = sparse.dtype
3313*da0073e9SAndroid Build Coastguard Worker            device = sparse.device
3314*da0073e9SAndroid Build Coastguard Worker
3315*da0073e9SAndroid Build Coastguard Worker            # softmax is a non-linear operation, so sparse tensors must
3316*da0073e9SAndroid Build Coastguard Worker            # be coalesced.
3317*da0073e9SAndroid Build Coastguard Worker            sparse = sparse.coalesce()
3318*da0073e9SAndroid Build Coastguard Worker            inf = float('inf')
3319*da0073e9SAndroid Build Coastguard Worker            indices = sparse._indices()
3320*da0073e9SAndroid Build Coastguard Worker            values = sparse._values()
3321*da0073e9SAndroid Build Coastguard Worker
3322*da0073e9SAndroid Build Coastguard Worker            if dim < sparse.sparse_dim():
3323*da0073e9SAndroid Build Coastguard Worker                nnz = sparse._nnz()
3324*da0073e9SAndroid Build Coastguard Worker
3325*da0073e9SAndroid Build Coastguard Worker                # compute pool indices
3326*da0073e9SAndroid Build Coastguard Worker                size = sparse.size()
3327*da0073e9SAndroid Build Coastguard Worker                strides = torch.ones((sparse.sparse_dim(), 1), dtype=indices.dtype, device=indices.device)
3328*da0073e9SAndroid Build Coastguard Worker                for i in reversed(range(sparse.sparse_dim() - 1)):
3329*da0073e9SAndroid Build Coastguard Worker                    strides[i, 0] = strides[i + 1, 0] * size[i + 1]
3330*da0073e9SAndroid Build Coastguard Worker                strides[dim, 0] = 0
3331*da0073e9SAndroid Build Coastguard Worker
3332*da0073e9SAndroid Build Coastguard Worker                pool = (indices * strides).sum(dim=0)
3333*da0073e9SAndroid Build Coastguard Worker                i2p = {}
3334*da0073e9SAndroid Build Coastguard Worker                for i in range(nnz):
3335*da0073e9SAndroid Build Coastguard Worker                    c = int(pool[i])
3336*da0073e9SAndroid Build Coastguard Worker                    if c not in i2p:
3337*da0073e9SAndroid Build Coastguard Worker                        i2p[c] = len(i2p)
3338*da0073e9SAndroid Build Coastguard Worker                    pool[i] = i2p[c]
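                # Worked example of the pooling above (hypothetical): for a
                # 2x3 sparse matrix with dim=1, strides ends up as [[3], [0]]
                # after strides[dim] is zeroed, so entries at (0, 0), (0, 2),
                # (1, 1) get raw keys 0, 0, 3, renumbered to contiguous pool
                # ids 0, 0, 1; entries sharing an id lie on the same softmax
                # line.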
3339*da0073e9SAndroid Build Coastguard Worker
3340*da0073e9SAndroid Build Coastguard Worker                # compute max
3341*da0073e9SAndroid Build Coastguard Worker                dense_size = tuple(size[sparse.sparse_dim():])
3342*da0073e9SAndroid Build Coastguard Worker                mx = torch.empty((pool.max() + 1,) + dense_size, dtype=dtype, device=device)
3343*da0073e9SAndroid Build Coastguard Worker                mx[:] = -inf
3344*da0073e9SAndroid Build Coastguard Worker                for n in range(nnz):
3345*da0073e9SAndroid Build Coastguard Worker                    p = pool[n]
3346*da0073e9SAndroid Build Coastguard Worker                    mx[p] = torch.max(mx[p], values[n])
3347*da0073e9SAndroid Build Coastguard Worker
3348*da0073e9SAndroid Build Coastguard Worker                # apply exp to (v - mx) and sum the results
3349*da0073e9SAndroid Build Coastguard Worker                exp_values = torch.empty_like(values)
3350*da0073e9SAndroid Build Coastguard Worker                exp_sums = torch.zeros_like(mx)
3351*da0073e9SAndroid Build Coastguard Worker                for n in range(nnz):
3352*da0073e9SAndroid Build Coastguard Worker                    p = pool[n]
3353*da0073e9SAndroid Build Coastguard Worker                    v = exp_values[n] = (values[n] - mx[p]).exp()
3354*da0073e9SAndroid Build Coastguard Worker                    exp_sums[p] = exp_sums[p] + v
3355*da0073e9SAndroid Build Coastguard Worker
3356*da0073e9SAndroid Build Coastguard Worker                # normalize with the sum of exponents
3357*da0073e9SAndroid Build Coastguard Worker                for n in range(nnz):
3358*da0073e9SAndroid Build Coastguard Worker                    p = pool[n]
3359*da0073e9SAndroid Build Coastguard Worker                    exp_values[n] = exp_values[n] / exp_sums[p]
3360*da0073e9SAndroid Build Coastguard Worker
3361*da0073e9SAndroid Build Coastguard Worker                return torch.sparse_coo_tensor(indices,
3362*da0073e9SAndroid Build Coastguard Worker                                               exp_values,
3363*da0073e9SAndroid Build Coastguard Worker                                               sparse.size(),
3364*da0073e9SAndroid Build Coastguard Worker                                               dtype=dtype, device=device)
3365*da0073e9SAndroid Build Coastguard Worker
3366*da0073e9SAndroid Build Coastguard Worker            elif dim < sparse.sparse_dim() + sparse.dense_dim():
3367*da0073e9SAndroid Build Coastguard Worker                return torch.sparse_coo_tensor(indices,
3368*da0073e9SAndroid Build Coastguard Worker                                               F.softmax(values, dim - sparse.sparse_dim() + 1),
3369*da0073e9SAndroid Build Coastguard Worker                                               sparse.size(),
3370*da0073e9SAndroid Build Coastguard Worker                                               dtype=dtype, device=device)
3371*da0073e9SAndroid Build Coastguard Worker            else:
3372*da0073e9SAndroid Build Coastguard Worker                raise ValueError(
3373*da0073e9SAndroid Build Coastguard Worker                    f'`dim(={dim})` must be smaller than `sparse_dim(={sparse.sparse_dim()}) + dense_dim(={sparse.dense_dim()})`')
3374*da0073e9SAndroid Build Coastguard Worker
3375*da0073e9SAndroid Build Coastguard Worker        def softmax_jacobian_analytic(x, dim):
3376*da0073e9SAndroid Build Coastguard Worker            """Return the Jacobian of softmax using the analytic formula
3377*da0073e9SAndroid Build Coastguard Worker
3378*da0073e9SAndroid Build Coastguard Worker               D_j S_i = S_i * (1[i==j] - S_j),
3379*da0073e9SAndroid Build Coastguard Worker
3380*da0073e9SAndroid Build Coastguard Worker            where S = softmax(x, dim), x is a dense tensor, and i, j are
3381*da0073e9SAndroid Build Coastguard Worker            in range(x.shape[dim]).
3382*da0073e9SAndroid Build Coastguard Worker            """
3383*da0073e9SAndroid Build Coastguard Worker            y = F.softmax(x, dim)
3384*da0073e9SAndroid Build Coastguard Worker            y[y != y] = 0  # replace nan-s with zeros
3385*da0073e9SAndroid Build Coastguard Worker            J = torch.zeros((x.shape[dim],) + tuple(x.shape), dtype=x.dtype, device=x.device)
3386*da0073e9SAndroid Build Coastguard Worker            si = [slice(None)] * len(y.shape)
3387*da0073e9SAndroid Build Coastguard Worker            sj = [slice(None)] * len(y.shape)
3388*da0073e9SAndroid Build Coastguard Worker            s = [slice(None)] * len(J.shape)
3389*da0073e9SAndroid Build Coastguard Worker            for i in range(y.shape[dim]):
3390*da0073e9SAndroid Build Coastguard Worker                si[dim] = i
3391*da0073e9SAndroid Build Coastguard Worker                s[dim + 1] = i
3392*da0073e9SAndroid Build Coastguard Worker                yi = y[tuple(si)]
3393*da0073e9SAndroid Build Coastguard Worker                for j in range(y.shape[dim]):
3394*da0073e9SAndroid Build Coastguard Worker                    sj[dim] = j
3395*da0073e9SAndroid Build Coastguard Worker                    s[0] = j
3396*da0073e9SAndroid Build Coastguard Worker                    if i == j:
3397*da0073e9SAndroid Build Coastguard Worker                        J[tuple(s)] = yi * (1 - yi)
3398*da0073e9SAndroid Build Coastguard Worker                    else:
3399*da0073e9SAndroid Build Coastguard Worker                        yj = y[tuple(sj)]
3400*da0073e9SAndroid Build Coastguard Worker                        J[tuple(s)] = - yi * yj
3401*da0073e9SAndroid Build Coastguard Worker                    sj[dim] = slice(None)
3402*da0073e9SAndroid Build Coastguard Worker                si[dim] = slice(None)
3403*da0073e9SAndroid Build Coastguard Worker                s[dim + 1] = slice(None)
3404*da0073e9SAndroid Build Coastguard Worker            return J
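        # Quick numeric check of the formula (illustrative): for x = [0., 0.]
        # and dim=0, S = [0.5, 0.5], and the Jacobian is
        #   [[ 0.5*(1 - 0.5), -0.5*0.5       ]     [[ 0.25, -0.25],
        #    [ -0.5*0.5,       0.5*(1 - 0.5) ]]  =  [-0.25,  0.25]]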
3405*da0073e9SAndroid Build Coastguard Worker
3406*da0073e9SAndroid Build Coastguard Worker        def softmax_jacobian_autograd(x, dim, log=False):
3407*da0073e9SAndroid Build Coastguard Worker            """Return the Jacobian of softmax using PyTorch's autograd.
3408*da0073e9SAndroid Build Coastguard Worker
3409*da0073e9SAndroid Build Coastguard Worker            x can be a dense or sparse tensor.
3410*da0073e9SAndroid Build Coastguard Worker            """
3411*da0073e9SAndroid Build Coastguard Worker            import itertools
3412*da0073e9SAndroid Build Coastguard Worker
3413*da0073e9SAndroid Build Coastguard Worker            if x.is_sparse:
3414*da0073e9SAndroid Build Coastguard Worker                x = x.coalesce()
3415*da0073e9SAndroid Build Coastguard Worker
3416*da0073e9SAndroid Build Coastguard Worker            dtype = x.dtype
3417*da0073e9SAndroid Build Coastguard Worker            device = x.device
3418*da0073e9SAndroid Build Coastguard Worker            shape = tuple(x.shape)
3419*da0073e9SAndroid Build Coastguard Worker            J = torch.zeros((shape[dim],) + shape, dtype=dtype, device=device)
3420*da0073e9SAndroid Build Coastguard Worker            for i in range(shape[dim]):
3421*da0073e9SAndroid Build Coastguard Worker                if x.is_sparse:
3422*da0073e9SAndroid Build Coastguard Worker                    sparse_dim = x.sparse_dim()
3423*da0073e9SAndroid Build Coastguard Worker                    dense_dim = x.dense_dim()
3424*da0073e9SAndroid Build Coastguard Worker                    if dim < sparse_dim:
3425*da0073e9SAndroid Build Coastguard Worker                        ranges = []
3426*da0073e9SAndroid Build Coastguard Worker                        for j, sz in enumerate(shape[:sparse_dim]):
3427*da0073e9SAndroid Build Coastguard Worker                            if dim == j:
3428*da0073e9SAndroid Build Coastguard Worker                                ranges.append([i])
3429*da0073e9SAndroid Build Coastguard Worker                            else:
3430*da0073e9SAndroid Build Coastguard Worker                                ranges.append(list(range(sz)))
3431*da0073e9SAndroid Build Coastguard Worker                        indices = torch.tensor(list(itertools.product(*ranges)), dtype=torch.long, device=device).t()
3432*da0073e9SAndroid Build Coastguard Worker                        values = torch.ones((indices.shape[1],) + shape[sparse_dim:], dtype=dtype, device=device)
3433*da0073e9SAndroid Build Coastguard Worker                    else:
3434*da0073e9SAndroid Build Coastguard Worker                        ranges = []
3435*da0073e9SAndroid Build Coastguard Worker                        for j, sz in enumerate(shape[:sparse_dim]):
3436*da0073e9SAndroid Build Coastguard Worker                            ranges.append(list(range(sz)))
3437*da0073e9SAndroid Build Coastguard Worker                        indices = torch.tensor(list(itertools.product(*ranges)), dtype=torch.long, device=device).t()
3438*da0073e9SAndroid Build Coastguard Worker                        values = torch.zeros((indices.shape[1],) + shape[sparse_dim:], dtype=dtype, device=device)
3439*da0073e9SAndroid Build Coastguard Worker                        sv = [slice(None)] * (dense_dim + 1)
3440*da0073e9SAndroid Build Coastguard Worker                        sv[dim - sparse_dim + 1] = i
3441*da0073e9SAndroid Build Coastguard Worker                        values[tuple(sv)] = 1
3442*da0073e9SAndroid Build Coastguard Worker                    v = torch.sparse_coo_tensor(indices, values, shape, dtype=dtype, device=device)
3443*da0073e9SAndroid Build Coastguard Worker                else:
3444*da0073e9SAndroid Build Coastguard Worker                    v = torch.zeros_like(x)
3445*da0073e9SAndroid Build Coastguard Worker                    sv = [slice(None)] * len(v.shape)
3446*da0073e9SAndroid Build Coastguard Worker                    sv[dim] = i
3447*da0073e9SAndroid Build Coastguard Worker                    v[tuple(sv)] = 1
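                # In both branches, v is a one-hot "slice selector": it is 1
                # exactly where the index along the softmax dimension equals
                # i, so y.backward(v) accumulates row i of the Jacobian into
                # x_.grad.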
3448*da0073e9SAndroid Build Coastguard Worker                x_ = x.clone()
3449*da0073e9SAndroid Build Coastguard Worker                x_.requires_grad_(True)
3450*da0073e9SAndroid Build Coastguard Worker
3451*da0073e9SAndroid Build Coastguard Worker                if log:
3452*da0073e9SAndroid Build Coastguard Worker                    if x_.is_sparse:
3453*da0073e9SAndroid Build Coastguard Worker                        y = torch.sparse.log_softmax(x_, dim)
3454*da0073e9SAndroid Build Coastguard Worker                    else:
3455*da0073e9SAndroid Build Coastguard Worker                        y = F.log_softmax(x_, dim)
3456*da0073e9SAndroid Build Coastguard Worker                else:
3457*da0073e9SAndroid Build Coastguard Worker                    if x_.is_sparse:
3458*da0073e9SAndroid Build Coastguard Worker                        y = torch.sparse.softmax(x_, dim)
3459*da0073e9SAndroid Build Coastguard Worker                    else:
3460*da0073e9SAndroid Build Coastguard Worker                        y = F.softmax(x_, dim)
3461*da0073e9SAndroid Build Coastguard Worker                        # replace nan-s with zeros
3462*da0073e9SAndroid Build Coastguard Worker                        y.data[y != y] = 0
3463*da0073e9SAndroid Build Coastguard Worker                y.backward(v)
3464*da0073e9SAndroid Build Coastguard Worker                g = x_.grad
3465*da0073e9SAndroid Build Coastguard Worker                if not g.is_sparse:
3466*da0073e9SAndroid Build Coastguard Worker                    # replace nan-s with zeros
3467*da0073e9SAndroid Build Coastguard Worker                    g.data[g != g] = 0
3468*da0073e9SAndroid Build Coastguard Worker                J[i] = g.to_dense() if g.is_sparse else g
3469*da0073e9SAndroid Build Coastguard Worker            return J
3470*da0073e9SAndroid Build Coastguard Worker
3471*da0073e9SAndroid Build Coastguard Worker        @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1166")
3472*da0073e9SAndroid Build Coastguard Worker        def test_op(sparse_dims, nnz, with_size, coalesced):
3473*da0073e9SAndroid Build Coastguard Worker            if isinstance(with_size, Number):
3474*da0073e9SAndroid Build Coastguard Worker                with_size = [with_size] * sparse_dims
3475*da0073e9SAndroid Build Coastguard Worker
3476*da0073e9SAndroid Build Coastguard Worker            x, i, v = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)
3477*da0073e9SAndroid Build Coastguard Worker
3478*da0073e9SAndroid Build Coastguard Worker            def sparse_log(x):
3479*da0073e9SAndroid Build Coastguard Worker                return torch.sparse_coo_tensor(x._indices(), x._values().log(),
3480*da0073e9SAndroid Build Coastguard Worker                                               x.size(), dtype=x.dtype, device=x.device)
3481*da0073e9SAndroid Build Coastguard Worker
3482*da0073e9SAndroid Build Coastguard Worker            # Check dim out of bounds
3483*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(IndexError, r"Dimension out of range"):
3484*da0073e9SAndroid Build Coastguard Worker                torch.sparse.softmax(x, x.dim())
3485*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(IndexError, r"Dimension out of range"):
3486*da0073e9SAndroid Build Coastguard Worker                torch.sparse.softmax(x, -x.dim() - 1)
3487*da0073e9SAndroid Build Coastguard Worker
3488*da0073e9SAndroid Build Coastguard Worker            for dim in range(x.dim()):
3489*da0073e9SAndroid Build Coastguard Worker                # Check sparse softmax definition
3490*da0073e9SAndroid Build Coastguard Worker
3491*da0073e9SAndroid Build Coastguard Worker                # check Python sparse softmax
3492*da0073e9SAndroid Build Coastguard Worker                y = sparse_softmax(x, dim)
3493*da0073e9SAndroid Build Coastguard Worker                r1 = softmax_to_dense(x, dim)
3494*da0073e9SAndroid Build Coastguard Worker                r2 = y.to_dense()
3495*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(r1, r2)
3496*da0073e9SAndroid Build Coastguard Worker
3497*da0073e9SAndroid Build Coastguard Worker                # check C++ sparse softmax
3498*da0073e9SAndroid Build Coastguard Worker                for d in (dim, dim - x.dim()):
3499*da0073e9SAndroid Build Coastguard Worker                    y1 = torch.sparse.softmax(x, d)
3500*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y, y1)
3501*da0073e9SAndroid Build Coastguard Worker
3502*da0073e9SAndroid Build Coastguard Worker                    # check C++ sparse log_softmax
3503*da0073e9SAndroid Build Coastguard Worker                    ly1 = torch.sparse.log_softmax(x, d)
3504*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(ly1, sparse_log(y1))
3505*da0073e9SAndroid Build Coastguard Worker
3506*da0073e9SAndroid Build Coastguard Worker                # Check autograd support on sparse softmax
3507*da0073e9SAndroid Build Coastguard Worker
3508*da0073e9SAndroid Build Coastguard Worker                # check softmax Jacobian definition for dense input
3509*da0073e9SAndroid Build Coastguard Worker                x1 = to_dense(x, fill_value=float('-inf'))
3510*da0073e9SAndroid Build Coastguard Worker                J = softmax_jacobian_analytic(x1, dim)
3511*da0073e9SAndroid Build Coastguard Worker                assert J.shape[0] == x.shape[dim]
3512*da0073e9SAndroid Build Coastguard Worker                assert J.shape[dim + 1] == x.shape[dim]
3513*da0073e9SAndroid Build Coastguard Worker
3514*da0073e9SAndroid Build Coastguard Worker                # check softmax Jacobian from autograd, dense input
3515*da0073e9SAndroid Build Coastguard Worker                J2 = softmax_jacobian_autograd(x1, dim)
3516*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(J, J2)
3517*da0073e9SAndroid Build Coastguard Worker
3518*da0073e9SAndroid Build Coastguard Worker                # check softmax Jacobian from autograd, sparse input
3519*da0073e9SAndroid Build Coastguard Worker                J3 = softmax_jacobian_autograd(x, dim)
3520*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(J, J3)
3521*da0073e9SAndroid Build Coastguard Worker
3522*da0073e9SAndroid Build Coastguard Worker                '''
3523*da0073e9SAndroid Build Coastguard Worker                y = softmax(x, dim)
3524*da0073e9SAndroid Build Coastguard Worker                z = log(y) = log_softmax(x, dim)
3525*da0073e9SAndroid Build Coastguard Worker                Dy/Dx = J
3526*da0073e9SAndroid Build Coastguard Worker                Dz/Dx = Dz/Dy Dy/Dx = 1/y * J
3527*da0073e9SAndroid Build Coastguard Worker                => J = J_log * y
3528*da0073e9SAndroid Build Coastguard Worker                '''
3529*da0073e9SAndroid Build Coastguard Worker                # log_softmax Jacobian from autograd, dense input
3530*da0073e9SAndroid Build Coastguard Worker                J2_log = softmax_jacobian_autograd(x1, dim, log=True)
3531*da0073e9SAndroid Build Coastguard Worker
3532*da0073e9SAndroid Build Coastguard Worker                # log_softmax Jacobian from autograd, sparse input
3533*da0073e9SAndroid Build Coastguard Worker                J3_log = softmax_jacobian_autograd(x, dim, log=True)
3534*da0073e9SAndroid Build Coastguard Worker
3535*da0073e9SAndroid Build Coastguard Worker                J = J.transpose(0, dim + 1)
3536*da0073e9SAndroid Build Coastguard Worker                J2_log = J2_log.transpose(0, dim + 1)
3537*da0073e9SAndroid Build Coastguard Worker                J3_log = J3_log.transpose(0, dim + 1)
3538*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(J, J2_log * r1)
3539*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(J, J3_log * r1)
3540*da0073e9SAndroid Build Coastguard Worker
3541*da0073e9SAndroid Build Coastguard Worker                if dim == 0:
3542*da0073e9SAndroid Build Coastguard Worker                    # check dtype argument
3543*da0073e9SAndroid Build Coastguard Worker                    other_dtype = torch.float32
3544*da0073e9SAndroid Build Coastguard Worker                    y2 = torch.sparse.softmax(x, dim, dtype=other_dtype)
3545*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y2.dtype, other_dtype)
3546*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y2, y1.type(other_dtype))
3547*da0073e9SAndroid Build Coastguard Worker
3548*da0073e9SAndroid Build Coastguard Worker                    ly2 = torch.sparse.log_softmax(x, dim, dtype=other_dtype)
3549*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(ly2.dtype, other_dtype)
3550*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(ly2, ly1.type(other_dtype))
3551*da0073e9SAndroid Build Coastguard Worker
3552*da0073e9SAndroid Build Coastguard Worker        test_op(1, 10, [3], coalesced)
3553*da0073e9SAndroid Build Coastguard Worker        test_op(1, 10, [2, 3], coalesced)
3554*da0073e9SAndroid Build Coastguard Worker        test_op(1, 10, [3, 2], coalesced)
3555*da0073e9SAndroid Build Coastguard Worker        test_op(2, 10, [2, 3, 4], coalesced)
3556*da0073e9SAndroid Build Coastguard Worker        test_op(2, 10, [3, 4], coalesced)
3557*da0073e9SAndroid Build Coastguard Worker        test_op(2, 5, [5, 4], coalesced)
3558*da0073e9SAndroid Build Coastguard Worker        test_op(2, 10, [3, 4, 2], coalesced)
3559*da0073e9SAndroid Build Coastguard Worker        test_op(3, 10, [3, 4, 2], coalesced)
3560*da0073e9SAndroid Build Coastguard Worker        test_op(3, 100, [3, 4, 2], coalesced)
3561*da0073e9SAndroid Build Coastguard Worker        test_op(3, 100, [3, 4, 2, 3], coalesced)
3562*da0073e9SAndroid Build Coastguard Worker        test_op(3, 100, [3, 4, 2, 3, 5, 2], coalesced)
3563*da0073e9SAndroid Build Coastguard Worker        test_op(4, 100, [3, 4, 2, 3, 5, 2], coalesced)
3564*da0073e9SAndroid Build Coastguard Worker
3565*da0073e9SAndroid Build Coastguard Worker
3566*da0073e9SAndroid Build Coastguard Worker    def _check_zero_nnz_softmax_op(self, func, ndim, device, dtype):
3567*da0073e9SAndroid Build Coastguard Worker        # create a sparse tensor with shape (0, ..., 3); it has no materialized values
3568*da0073e9SAndroid Build Coastguard Worker        t = torch.sparse_coo_tensor([[] for _ in range(ndim)], [], (0,) * (ndim - 1) + (3,), device=device, dtype=dtype)
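        # e.g. ndim=2 gives an empty sparse tensor of shape (0, 3) with
        # indices of shape (2, 0) and no values; softmax over dim 0 must
        # return an equally empty sparse tensor.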
3569*da0073e9SAndroid Build Coastguard Worker        out = func(t, 0)
3570*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, torch.zeros_like(t))
3571*da0073e9SAndroid Build Coastguard Worker
3572*da0073e9SAndroid Build Coastguard Worker        # gradient
3573*da0073e9SAndroid Build Coastguard Worker        t = t.requires_grad_()
3574*da0073e9SAndroid Build Coastguard Worker        gradcheck(lambda x: func(x, 0).to_dense(), (t,), masked=True)
3575*da0073e9SAndroid Build Coastguard Worker
3576*da0073e9SAndroid Build Coastguard Worker
3577*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.float)
3578*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
3579*da0073e9SAndroid Build Coastguard Worker    def test_softmax_zero_nnz(self, device, dtype):
3580*da0073e9SAndroid Build Coastguard Worker        self._check_zero_nnz_softmax_op(torch.sparse.softmax, 1, device, dtype)
3581*da0073e9SAndroid Build Coastguard Worker        self._check_zero_nnz_softmax_op(torch.sparse.softmax, 10, device, dtype)
3582*da0073e9SAndroid Build Coastguard Worker
3583*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.float)
3584*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
3585*da0073e9SAndroid Build Coastguard Worker    def test_log_softmax_zero_nnz(self, device, dtype):
3586*da0073e9SAndroid Build Coastguard Worker        self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 1, device, dtype)
3587*da0073e9SAndroid Build Coastguard Worker        self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 10, device, dtype)
3588*da0073e9SAndroid Build Coastguard Worker
3589*da0073e9SAndroid Build Coastguard Worker    # TODO: Check why ROCm's cusparseXcsrgemm2Nnz function doesn't return the same nnz value as CUDA does
3590*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
3591*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
3592*da0073e9SAndroid Build Coastguard Worker    @dtypes(*floating_and_complex_types())
3593*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(*floating_types_and(*[torch.half] if SM53OrLater else [],
3594*da0073e9SAndroid Build Coastguard Worker                                      *[torch.bfloat16] if SM80OrLater else [],
3595*da0073e9SAndroid Build Coastguard Worker                                      torch.complex64,
3596*da0073e9SAndroid Build Coastguard Worker                                      *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
3597*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_CROSSREF, "not working with fake tensor")
3598*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2, torch.complex64: 1e-2, torch.float32: 1e-2})
3599*da0073e9SAndroid Build Coastguard Worker    def test_sparse_matmul(self, device, dtype, coalesced):
3600*da0073e9SAndroid Build Coastguard Worker        """
3601*da0073e9SAndroid Build Coastguard Worker        This function tests `torch.sparse.mm` when both mat1 and mat2 are sparse tensors.
3602*da0073e9SAndroid Build Coastguard Worker        """
3603*da0073e9SAndroid Build Coastguard Worker
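        # For reference (illustrative only), torch.sparse.mm on two sparse COO
        # matrices returns a sparse result:
        #   >>> a = torch.eye(2).to_sparse()
        #   >>> torch.sparse.mm(a, a).to_dense()
        #   tensor([[1., 0.],
        #           [0., 1.]])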
3604*da0073e9SAndroid Build Coastguard Worker        def ref_sparse_mm(a, b):
3605*da0073e9SAndroid Build Coastguard Worker            return a.to_dense() @ b.to_dense()
3606*da0073e9SAndroid Build Coastguard Worker
3607*da0073e9SAndroid Build Coastguard Worker        def grad_with_custom_sparsity_pattern_test_helper(sparse_dims, nnz, shape_a, shape_b):
3608*da0073e9SAndroid Build Coastguard Worker            def test_grad_dense(a_s, b_s, g_s):
3609*da0073e9SAndroid Build Coastguard Worker                a = a_s.to_dense().detach()
3610*da0073e9SAndroid Build Coastguard Worker                b = b_s.to_dense().detach()
3611*da0073e9SAndroid Build Coastguard Worker                g = g_s.to_dense().detach()
3612*da0073e9SAndroid Build Coastguard Worker
3613*da0073e9SAndroid Build Coastguard Worker                a.requires_grad_(True)
3614*da0073e9SAndroid Build Coastguard Worker                b.requires_grad_(True)
3615*da0073e9SAndroid Build Coastguard Worker                c = a @ b
3616*da0073e9SAndroid Build Coastguard Worker                c.backward(g)
3617*da0073e9SAndroid Build Coastguard Worker                return a.grad.sparse_mask(a_s.coalesce()), b.grad.sparse_mask(b_s.coalesce())
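            # Note: sparse_mask projects the dense gradients onto the inputs'
            # sparsity patterns, matching the masked semantics in which
            # unspecified elements carry no gradient.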
3618*da0073e9SAndroid Build Coastguard Worker
3619*da0073e9SAndroid Build Coastguard Worker            a, _, _ = self._gen_sparse(sparse_dims, nnz, shape_a, dtype, device, coalesced)
3620*da0073e9SAndroid Build Coastguard Worker            b, _, _ = self._gen_sparse(sparse_dims, nnz, shape_b, dtype, device, coalesced)
3621*da0073e9SAndroid Build Coastguard Worker            a.requires_grad_(True)
3622*da0073e9SAndroid Build Coastguard Worker            b.requires_grad_(True)
3623*da0073e9SAndroid Build Coastguard Worker
3624*da0073e9SAndroid Build Coastguard Worker            c = torch.sparse.mm(a, b)
3625*da0073e9SAndroid Build Coastguard Worker            c2 = c.to_dense().detach()
3626*da0073e9SAndroid Build Coastguard Worker            c2 = torch.rand_like(c2)
3627*da0073e9SAndroid Build Coastguard Worker            g = c2.sparse_mask(c.coalesce())
3628*da0073e9SAndroid Build Coastguard Worker
3629*da0073e9SAndroid Build Coastguard Worker            c.backward(g)
3630*da0073e9SAndroid Build Coastguard Worker
3631*da0073e9SAndroid Build Coastguard Worker            a_grad, b_grad = test_grad_dense(a, b, g)
3632*da0073e9SAndroid Build Coastguard Worker
3633*da0073e9SAndroid Build Coastguard Worker            # We convert grad to dense since dense and sparse mm
3634*da0073e9SAndroid Build Coastguard Worker            # implementations handle materialized zeroes differently.
3635*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.grad.to_dense(), a_grad.to_dense())
3636*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(b.grad.to_dense(), b_grad.to_dense())
3637*da0073e9SAndroid Build Coastguard Worker
3638*da0073e9SAndroid Build Coastguard Worker        def test_sparse_matmul(sparse_dims, nnz, shape_a, shape_b):
3639*da0073e9SAndroid Build Coastguard Worker            a, i_a, v_a = self._gen_sparse(sparse_dims, nnz, shape_a, dtype, device, coalesced)
3640*da0073e9SAndroid Build Coastguard Worker            b, i_b, v_b = self._gen_sparse(sparse_dims, nnz, shape_b, dtype, device, coalesced)
3641*da0073e9SAndroid Build Coastguard Worker
3642*da0073e9SAndroid Build Coastguard Worker            # dense implementation
3643*da0073e9SAndroid Build Coastguard Worker            r1 = ref_sparse_mm(a, b)
3644*da0073e9SAndroid Build Coastguard Worker
3645*da0073e9SAndroid Build Coastguard Worker            # cpp implementation
3646*da0073e9SAndroid Build Coastguard Worker            r2 = torch.sparse.mm(a, b)
3647*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(r1, r2.to_dense())
3648*da0073e9SAndroid Build Coastguard Worker
3649*da0073e9SAndroid Build Coastguard Worker            # Check result is truly coalesced
3650*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(r2.is_coalesced() and is_coalesced_indices(r2))
3651*da0073e9SAndroid Build Coastguard Worker
3652*da0073e9SAndroid Build Coastguard Worker            if dtype in [torch.double, torch.cdouble]:
3653*da0073e9SAndroid Build Coastguard Worker                a.requires_grad_(True)
3654*da0073e9SAndroid Build Coastguard Worker                b.requires_grad_(True)
3655*da0073e9SAndroid Build Coastguard Worker
3656*da0073e9SAndroid Build Coastguard Worker                # check autograd support on sparse matmul
3657*da0073e9SAndroid Build Coastguard Worker                def fn(D1, D2):
3658*da0073e9SAndroid Build Coastguard Worker                    return torch.sparse.mm(D1, D2).to_dense()
3659*da0073e9SAndroid Build Coastguard Worker
3660*da0073e9SAndroid Build Coastguard Worker                if a.is_cuda:
3661*da0073e9SAndroid Build Coastguard Worker                    # For CUDA, `nondet_tol` is set to `1e-5` because
3662*da0073e9SAndroid Build Coastguard Worker                    # cuSparse sometimes returns approximately-zero values like `~1e-323`.
3663*da0073e9SAndroid Build Coastguard Worker                    # TODO: Check this cuSparse issue.
3664*da0073e9SAndroid Build Coastguard Worker                    # This happens when chaining `torch.sparse.mm` operations.
3665*da0073e9SAndroid Build Coastguard Worker                    gradcheck(fn, (a, b), nondet_tol=1e-5, masked=True)
3666*da0073e9SAndroid Build Coastguard Worker                else:
3667*da0073e9SAndroid Build Coastguard Worker                    gradcheck(fn, (a, b), masked=True)
3668*da0073e9SAndroid Build Coastguard Worker                grad_with_custom_sparsity_pattern_test_helper(sparse_dims, nnz, shape_a, shape_b)
3669*da0073e9SAndroid Build Coastguard Worker
3670*da0073e9SAndroid Build Coastguard Worker        def test_error_cases():
3671*da0073e9SAndroid Build Coastguard Worker            def fn(sparse_dims, nnz, shape_a, shape_b):
3672*da0073e9SAndroid Build Coastguard Worker                a, i_a, v_a = self._gen_sparse(sparse_dims, nnz, shape_a, dtype, device, coalesced)
3673*da0073e9SAndroid Build Coastguard Worker                b, i_b, v_b = self._gen_sparse(sparse_dims, nnz, shape_b, dtype, device, coalesced)
3674*da0073e9SAndroid Build Coastguard Worker                r2 = torch.sparse.mm(a, b)
3675*da0073e9SAndroid Build Coastguard Worker
3676*da0073e9SAndroid Build Coastguard Worker            # These inputs are not matrices (they are 3-dimensional)
3677*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: fn(3, 4, [2, 2, 2], [2, 2, 2]))
3678*da0073e9SAndroid Build Coastguard Worker
3679*da0073e9SAndroid Build Coastguard Worker            # Shapes do not match
3680*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(RuntimeError,
3681*da0073e9SAndroid Build Coastguard Worker                                   r"mat1 and mat2 shapes cannot be multiplied \(2x3 and 4x2\)",
3682*da0073e9SAndroid Build Coastguard Worker                                   lambda: fn(2, 10, [2, 3], [4, 2]))
3683*da0073e9SAndroid Build Coastguard Worker
3684*da0073e9SAndroid Build Coastguard Worker            def different_dtypes():
3685*da0073e9SAndroid Build Coastguard Worker                a, i_a, v_a = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced)
3686*da0073e9SAndroid Build Coastguard Worker                b, i_b, v_b = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced)
3687*da0073e9SAndroid Build Coastguard Worker                r2 = torch.sparse.mm(a.to(torch.float64), a.to(torch.float32))
3688*da0073e9SAndroid Build Coastguard Worker
3689*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(RuntimeError, 'mat1 dtype Double does not match mat2 dtype Float', different_dtypes)
3690*da0073e9SAndroid Build Coastguard Worker
3691*da0073e9SAndroid Build Coastguard Worker        def test_backward_noncontiguous():
3692*da0073e9SAndroid Build Coastguard Worker            # Sparse.mm backward used to produce wrong results with non-contiguous grads,
3693*da0073e9SAndroid Build Coastguard Worker            # see https://github.com/pytorch/pytorch/issues/102493.
3694*da0073e9SAndroid Build Coastguard Worker            n_reps = 7
3695*da0073e9SAndroid Build Coastguard Worker            for _ in range(n_reps):
3696*da0073e9SAndroid Build Coastguard Worker                A = torch.eye(5).to_sparse().requires_grad_(True)
3697*da0073e9SAndroid Build Coastguard Worker                B = torch.eye(5).to_sparse()
3698*da0073e9SAndroid Build Coastguard Worker                out = torch.sparse.mm(A, B)
3699*da0073e9SAndroid Build Coastguard Worker                out.coalesce().values().sum().backward()
3700*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(A.grad, A)
3701*da0073e9SAndroid Build Coastguard Worker
3702*da0073e9SAndroid Build Coastguard Worker        for n in range(2, 5):
3703*da0073e9SAndroid Build Coastguard Worker            for m in range(2, 8):
3704*da0073e9SAndroid Build Coastguard Worker                for p in range(2, 8):
3705*da0073e9SAndroid Build Coastguard Worker                    test_sparse_matmul(2, 10, [n, m], [m, p])
3706*da0073e9SAndroid Build Coastguard Worker
3707*da0073e9SAndroid Build Coastguard Worker        test_sparse_matmul(2, 0, [0, 0], [0, 0])
3708*da0073e9SAndroid Build Coastguard Worker        test_sparse_matmul(2, 0, [0, 10], [10, 0])
3709*da0073e9SAndroid Build Coastguard Worker        test_error_cases()
3710*da0073e9SAndroid Build Coastguard Worker        test_backward_noncontiguous()
3711*da0073e9SAndroid Build Coastguard Worker
3712*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
3713*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
3714*da0073e9SAndroid Build Coastguard Worker    def test_assign(self, device, dtype, coalesced):
3715*da0073e9SAndroid Build Coastguard Worker        def assign_to():
3716*da0073e9SAndroid Build Coastguard Worker            a, i_a, v_a = self._gen_sparse(2, 5, [2, 3], dtype, device, coalesced)
3717*da0073e9SAndroid Build Coastguard Worker            a[0] = 100
3718*da0073e9SAndroid Build Coastguard Worker
3719*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, assign_to)
3720*da0073e9SAndroid Build Coastguard Worker
3721*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
3722*da0073e9SAndroid Build Coastguard Worker    def test_full_broadcast_to(self, device, dtype):
3723*da0073e9SAndroid Build Coastguard Worker        def can_broadcast(s0, s1):
3724*da0073e9SAndroid Build Coastguard Worker            s0 = tuple(reversed(s0))
3725*da0073e9SAndroid Build Coastguard Worker            s1 = tuple(reversed(s1))
3726*da0073e9SAndroid Build Coastguard Worker            for i in range(len(s0)):
3727*da0073e9SAndroid Build Coastguard Worker                if s0[i] != 1 and s0[i] != s1[i]:
3728*da0073e9SAndroid Build Coastguard Worker                    return False
3729*da0073e9SAndroid Build Coastguard Worker            return True
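        # e.g. can_broadcast((3, 1), (3, 2)) is True (the size-1 dim expands)
        # while can_broadcast((3, 2), (3, 1)) is False; s0 may be shorter than
        # s1 since only s0's trailing dimensions are checked.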
3730*da0073e9SAndroid Build Coastguard Worker        sizes = (
3731*da0073e9SAndroid Build Coastguard Worker            (), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)
3732*da0073e9SAndroid Build Coastguard Worker        )
3733*da0073e9SAndroid Build Coastguard Worker        for s0, s1 in itertools.combinations(sizes, r=2):
3734*da0073e9SAndroid Build Coastguard Worker            t = make_tensor(s0, dtype=dtype, device=device, low=-9, high=9)
3735*da0073e9SAndroid Build Coastguard Worker            for sparse_dims in range(1, len(s0) + 1):
3736*da0073e9SAndroid Build Coastguard Worker                s = t.to_sparse(sparse_dims)
3737*da0073e9SAndroid Build Coastguard Worker                if can_broadcast(s0, s1):
3738*da0073e9SAndroid Build Coastguard Worker                    t_res = torch.broadcast_to(t, s1)
3739*da0073e9SAndroid Build Coastguard Worker                    s_res = torch._sparse_broadcast_to(s, s1)
3740*da0073e9SAndroid Build Coastguard Worker                    torch._validate_sparse_coo_tensor_args(s_res._indices(), s_res._values(), s_res.shape)
3741*da0073e9SAndroid Build Coastguard Worker                    if s_res.is_coalesced():
3742*da0073e9SAndroid Build Coastguard Worker                        # ensure that is_coalesced is estimated correctly
3743*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(s_res, torch.sparse_coo_tensor(s_res._indices(), s_res._values(), s_res.shape).coalesce())
3744*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(s_res.to_dense(), t_res)
3745*da0073e9SAndroid Build Coastguard Worker                else:
3746*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError,
3747*da0073e9SAndroid Build Coastguard Worker                                                r"The expanded size of the tensor \(\d\) "
3748*da0073e9SAndroid Build Coastguard Worker                                                r"must match the existing size \(\d\)"):
3749*da0073e9SAndroid Build Coastguard Worker                        torch._sparse_broadcast_to(s, s1)
3750*da0073e9SAndroid Build Coastguard Worker
3751*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
3752*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
3753*da0073e9SAndroid Build Coastguard Worker    def test_sparse_broadcast_to(self, device, dtype, coalesced):
3754*da0073e9SAndroid Build Coastguard Worker        def test(sparse_dims, nnz, with_size, new_size):
3755*da0073e9SAndroid Build Coastguard Worker            x = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
3756*da0073e9SAndroid Build Coastguard Worker            y = self.safeToDense(x)
3757*da0073e9SAndroid Build Coastguard Worker            x1 = torch._sparse_broadcast_to(x, new_size)
3758*da0073e9SAndroid Build Coastguard Worker            y1 = y.broadcast_to(new_size)
3759*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(self.safeToDense(x1), y1)
3760*da0073e9SAndroid Build Coastguard Worker
3761*da0073e9SAndroid Build Coastguard Worker        test(4, 6, [7, 3, 1, 3, 0], [7, 3, 4, 3, 0])
3762*da0073e9SAndroid Build Coastguard Worker        test(4, 6, [7, 3, 1, 3, 0], [2, 7, 3, 1, 3, 0])
3763*da0073e9SAndroid Build Coastguard Worker        test(4, 6, [7, 3, 1, 3, 1, 3], [7, 3, 1, 3, 2, 3])
3764*da0073e9SAndroid Build Coastguard Worker        test(4, 6, [7, 3, 1, 3, 2, 1], [7, 3, 1, 3, 2, 3])
3765*da0073e9SAndroid Build Coastguard Worker
3766*da0073e9SAndroid Build Coastguard Worker    def _test_mul_skips(self, device, dtype, coalesced):
3767*da0073e9SAndroid Build Coastguard Worker        skipTestIfUncoalesced = False
3768*da0073e9SAndroid Build Coastguard Worker        # This case always coalesces inputs, which could lead to a loss of precision;
3769*da0073e9SAndroid Build Coastguard Worker        # hence for float16/bfloat16 the test runs only with already-coalesced tensors.
3770*da0073e9SAndroid Build Coastguard Worker        if not coalesced and dtype in {torch.float16, torch.bfloat16}:
3771*da0073e9SAndroid Build Coastguard Worker            skipTestIfUncoalesced = True
3772*da0073e9SAndroid Build Coastguard Worker        # to_dense is problematic for boolean non-coalesced CUDA tensors
3773*da0073e9SAndroid Build Coastguard Worker        # see https://github.com/pytorch/pytorch/issues/81648
3774*da0073e9SAndroid Build Coastguard Worker        if not coalesced and dtype == torch.bool and torch.device(device).type == "cuda":
3775*da0073e9SAndroid Build Coastguard Worker            skipTestIfUncoalesced = True
3776*da0073e9SAndroid Build Coastguard Worker
3777*da0073e9SAndroid Build Coastguard Worker        if skipTestIfUncoalesced:
3778*da0073e9SAndroid Build Coastguard Worker            self.skipTest(f"Test with dtype={dtype}, device={device} runs only with coalesced inputs")
3779*da0073e9SAndroid Build Coastguard Worker
3780*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
3781*da0073e9SAndroid Build Coastguard Worker    # NOTE: addcmul_out is not implemented for bool.
3782*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.bfloat16, torch.float16))
3783*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
3784*da0073e9SAndroid Build Coastguard Worker    def test_sparse_sparse_mul(self, device, dtype, coalesced):
3785*da0073e9SAndroid Build Coastguard Worker        self._test_mul_skips(device, dtype, coalesced)
3786*da0073e9SAndroid Build Coastguard Worker
3787*da0073e9SAndroid Build Coastguard Worker        shape = (2, 3, 4, 10)
3788*da0073e9SAndroid Build Coastguard Worker        nnz = 10
3789*da0073e9SAndroid Build Coastguard Worker
3790*da0073e9SAndroid Build Coastguard Worker        def check(self, x, y):
3791*da0073e9SAndroid Build Coastguard Worker            res_sparse = x * y
3792*da0073e9SAndroid Build Coastguard Worker            res_dense = x.to_dense() * y.to_dense()
3793*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_sparse.to_dense(), res_dense)
3794*da0073e9SAndroid Build Coastguard Worker
3795*da0073e9SAndroid Build Coastguard Worker        def check_empty(sparse_shape, nnz, dense_shape, coalesce):
3796*da0073e9SAndroid Build Coastguard Worker            from itertools import product
3797*da0073e9SAndroid Build Coastguard Worker            for nnz_val, shape_suffix in product((nnz, 0), ((), (0,))):
3798*da0073e9SAndroid Build Coastguard Worker                empty_sparse_shape = sparse_shape + shape_suffix
3799*da0073e9SAndroid Build Coastguard Worker                empty_dense_shape = dense_shape + shape_suffix
3800*da0073e9SAndroid Build Coastguard Worker                x = self._gen_sparse(sparse_dim, nnz_val, empty_sparse_shape, dtype, device, coalesce)[0]
3801*da0073e9SAndroid Build Coastguard Worker                check(self, x, x)
3802*da0073e9SAndroid Build Coastguard Worker
3803*da0073e9SAndroid Build Coastguard Worker        # TODO: uncomment once backward is implemented for sparse tensors that broadcast in dense dims.
3804*da0073e9SAndroid Build Coastguard Worker        # def check_autograd(x, y):
3805*da0073e9SAndroid Build Coastguard Worker        #     if dtype in {torch.double, torch.cdouble}:
3806*da0073e9SAndroid Build Coastguard Worker        #         xa = x.detach().clone().requires_grad_(True)
3807*da0073e9SAndroid Build Coastguard Worker        #         ya = y.detach().clone().requires_grad_(True)
3808*da0073e9SAndroid Build Coastguard Worker        #         gradcheck(lambda a, b: (a * b).to_dense(), (xa, ya), masked=True)
3809*da0073e9SAndroid Build Coastguard Worker        #         gradcheck(lambda a, b: (a * b).to_dense(), (ya, xa), masked=True)
3810*da0073e9SAndroid Build Coastguard Worker
3811*da0073e9SAndroid Build Coastguard Worker        for dim in range(len(shape) + 1):
3812*da0073e9SAndroid Build Coastguard Worker            sub_shape = shape[dim:]
3813*da0073e9SAndroid Build Coastguard Worker            sparse_dim = len(sub_shape) // 2
3814*da0073e9SAndroid Build Coastguard Worker
3815*da0073e9SAndroid Build Coastguard Worker            check_empty(sub_shape, nnz, shape, coalesced)
3816*da0073e9SAndroid Build Coastguard Worker
3817*da0073e9SAndroid Build Coastguard Worker            x = self._gen_sparse(sparse_dim, nnz, sub_shape, dtype, device, coalesced)[0]
3818*da0073e9SAndroid Build Coastguard Worker            y = self._gen_sparse(sparse_dim, nnz, sub_shape, dtype, device, coalesced)[0]
3819*da0073e9SAndroid Build Coastguard Worker            check(self, x, y)
3820*da0073e9SAndroid Build Coastguard Worker            # TODO: uncomment once supported
3821*da0073e9SAndroid Build Coastguard Worker            # check_autograd(x, y)
3822*da0073e9SAndroid Build Coastguard Worker
3823*da0073e9SAndroid Build Coastguard Worker            # check broadcasting in dense dims
3824*da0073e9SAndroid Build Coastguard Worker            for d in range(sparse_dim, len(sub_shape)):
3825*da0073e9SAndroid Build Coastguard Worker                new_shape = sub_shape[:d] + (1,) + sub_shape[d + 1:]
3826*da0073e9SAndroid Build Coastguard Worker                y = self._gen_sparse(sparse_dim, nnz, new_shape, dtype, device, coalesced)[0]
3827*da0073e9SAndroid Build Coastguard Worker                check(self, x, y)
3828*da0073e9SAndroid Build Coastguard Worker                # TODO: uncomment once supported
3829*da0073e9SAndroid Build Coastguard Worker                # check_autograd(x, y)
3830*da0073e9SAndroid Build Coastguard Worker
3831*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
3832*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
3833*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
3834*da0073e9SAndroid Build Coastguard Worker    def test_sparse_dense_mul(self, device, dtype, coalesced):
3835*da0073e9SAndroid Build Coastguard Worker        self._test_mul_skips(device, dtype, coalesced)
3836*da0073e9SAndroid Build Coastguard Worker
3837*da0073e9SAndroid Build Coastguard Worker        shape = (2, 3, 4, 10)
3838*da0073e9SAndroid Build Coastguard Worker        nnz = 10
3839*da0073e9SAndroid Build Coastguard Worker
3840*da0073e9SAndroid Build Coastguard Worker        def check(self, s, d):
3841*da0073e9SAndroid Build Coastguard Worker            res = d * s
3842*da0073e9SAndroid Build Coastguard Worker
3843*da0073e9SAndroid Build Coastguard Worker            # check commutativity
3844*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, s * d)
3845*da0073e9SAndroid Build Coastguard Worker
3846*da0073e9SAndroid Build Coastguard Worker            # check correctness
3847*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.to_dense(), s.to_dense() * d)
3848*da0073e9SAndroid Build Coastguard Worker
3849*da0073e9SAndroid Build Coastguard Worker            # check in-place mul where the dense tensor is modified
3850*da0073e9SAndroid Build Coastguard Worker            if d.dim() >= s.dim():
3851*da0073e9SAndroid Build Coastguard Worker                dc = d.clone()
3852*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(d.mul_(s), dc.mul_(s.to_dense()))
3853*da0073e9SAndroid Build Coastguard Worker
3854*da0073e9SAndroid Build Coastguard Worker            # check in-place mul where the sparse tensor is modified
3855*da0073e9SAndroid Build Coastguard Worker            if s.dim() >= d.dim():
3857*da0073e9SAndroid Build Coastguard Worker                sc = s.clone()
3858*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(s.mul_(d).to_dense(), sc.to_dense().mul_(d))
3859*da0073e9SAndroid Build Coastguard Worker
3860*da0073e9SAndroid Build Coastguard Worker        for dim in range(len(shape) + 1):
3861*da0073e9SAndroid Build Coastguard Worker            sub_shape = shape[dim:]
3862*da0073e9SAndroid Build Coastguard Worker            sparse_dim = len(sub_shape) // 2
3863*da0073e9SAndroid Build Coastguard Worker
3864*da0073e9SAndroid Build Coastguard Worker            def check_empty(sparse_shape, nnz, dense_shape, coalesce):
3866*da0073e9SAndroid Build Coastguard Worker                for nnz_val, shape_suffix in itertools.product((nnz, 0), ((), (0,))):
3867*da0073e9SAndroid Build Coastguard Worker                    empty_sparse_shape = sparse_shape + shape_suffix
3868*da0073e9SAndroid Build Coastguard Worker                    empty_dense_shape = dense_shape + shape_suffix
3869*da0073e9SAndroid Build Coastguard Worker                    s = self._gen_sparse(sparse_dim, nnz_val, empty_sparse_shape, dtype, device, coalesce)[0]
3870*da0073e9SAndroid Build Coastguard Worker                    d = make_tensor(empty_dense_shape, dtype=dtype, device=device)
3871*da0073e9SAndroid Build Coastguard Worker                    check(self, s, d)
3872*da0073e9SAndroid Build Coastguard Worker
3873*da0073e9SAndroid Build Coastguard Worker            # check scalar multiplication
3874*da0073e9SAndroid Build Coastguard Worker            s = self._gen_sparse(sparse_dim, nnz, sub_shape, dtype, device, coalesced)[0]
3875*da0073e9SAndroid Build Coastguard Worker            for scalar in (True, 1, 1.0):
3876*da0073e9SAndroid Build Coastguard Worker                res_sparse_right = s * scalar
3877*da0073e9SAndroid Build Coastguard Worker                res_sparse_left = scalar * s
3878*da0073e9SAndroid Build Coastguard Worker                res_dense = s.to_dense() * scalar
3879*da0073e9SAndroid Build Coastguard Worker                # check correctness and dtype
3880*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(s.to(res_sparse_right.dtype), res_sparse_right)
3881*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res_sparse_right, res_sparse_left)
3882*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res_sparse_right.dtype, res_dense.dtype)
3883*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res_sparse_left.dtype, res_dense.dtype)
3884*da0073e9SAndroid Build Coastguard Worker                # check scalar as 0-dim sparse tensor
3885*da0073e9SAndroid Build Coastguard Worker                tscalar = torch.tensor(scalar, device=device)
3886*da0073e9SAndroid Build Coastguard Worker                sscalar = tscalar.to_sparse()
3887*da0073e9SAndroid Build Coastguard Worker                res_sparse_right = s * sscalar
3888*da0073e9SAndroid Build Coastguard Worker                res_sparse_left = sscalar * s
3889*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res_sparse_right, res_sparse_left)
3890*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(s.to(res_sparse_right.dtype), res_sparse_right)
3891*da0073e9SAndroid Build Coastguard Worker
3892*da0073e9SAndroid Build Coastguard Worker            # check non-coalesced 0-dim scalar
3893*da0073e9SAndroid Build Coastguard Worker            # we skip torch.bool because for such tensors
3894*da0073e9SAndroid Build Coastguard Worker            # coalesce().to_dense() != to_dense()
3895*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.bool:
3896*da0073e9SAndroid Build Coastguard Worker                return
3897*da0073e9SAndroid Build Coastguard Worker
3898*da0073e9SAndroid Build Coastguard Worker            for scalar_dtype in (int, float):
3899*da0073e9SAndroid Build Coastguard Worker                scalar = scalar_dtype(1)
3900*da0073e9SAndroid Build Coastguard Worker                idx = torch.tensor([], device=device).reshape(0, 2)
3901*da0073e9SAndroid Build Coastguard Worker                val = torch.tensor([scalar, scalar], device=device)
3902*da0073e9SAndroid Build Coastguard Worker                sscalar = torch.sparse_coo_tensor(idx, val, ())
3903*da0073e9SAndroid Build Coastguard Worker                res_dense = s.to_dense() * sscalar.to_dense()
3904*da0073e9SAndroid Build Coastguard Worker                self.assertEqual((s * sscalar).to_dense(), res_dense)
3905*da0073e9SAndroid Build Coastguard Worker                self.assertEqual((sscalar * s).to_dense(), res_dense)
3906*da0073e9SAndroid Build Coastguard Worker
3907*da0073e9SAndroid Build Coastguard Worker            # Case 1: sparse broadcasts over dense
3908*da0073e9SAndroid Build Coastguard Worker            s = self._gen_sparse(sparse_dim, nnz, sub_shape, dtype, device, coalesced)[0]
3909*da0073e9SAndroid Build Coastguard Worker            d = make_tensor(shape, dtype=dtype, device=device)
3910*da0073e9SAndroid Build Coastguard Worker            check(self, s, d)
3911*da0073e9SAndroid Build Coastguard Worker            check_empty(sub_shape, nnz, shape, coalesced)
3912*da0073e9SAndroid Build Coastguard Worker
3913*da0073e9SAndroid Build Coastguard Worker            # Case 2: dense broadcasts over sparse
3914*da0073e9SAndroid Build Coastguard Worker            s = self._gen_sparse(3, nnz, shape, dtype, device, coalesced)[0]
3915*da0073e9SAndroid Build Coastguard Worker            d = make_tensor(sub_shape, dtype=dtype, device=device)
3916*da0073e9SAndroid Build Coastguard Worker            check(self, s, d)
3917*da0073e9SAndroid Build Coastguard Worker            check_empty(shape, nnz, sub_shape, coalesced)
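
    # Shape recap for the cases above (illustrative): sparse * dense
    # broadcasts like dense mul and returns a sparse result, e.g. with
    # dim == 2:
    #   sparse (4, 10) * dense (2, 3, 4, 10) -> sparse (2, 3, 4, 10)
    #   sparse (2, 3, 4, 10) * dense (4, 10) -> sparse (2, 3, 4, 10)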
3918*da0073e9SAndroid Build Coastguard Worker
3919*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "NumPy is not available")
3920*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
3921*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.bool))
3922*da0073e9SAndroid Build Coastguard Worker    def test_sparse_spdiags(self, device, dtype):
3923*da0073e9SAndroid Build Coastguard Worker
3924*da0073e9SAndroid Build Coastguard Worker        make_diags = functools.partial(make_tensor, dtype=dtype, device=device)
3925*da0073e9SAndroid Build Coastguard Worker        make_offsets = functools.partial(torch.tensor, dtype=torch.long, device=device)
3926*da0073e9SAndroid Build Coastguard Worker
3927*da0073e9SAndroid Build Coastguard Worker        if TEST_SCIPY:
3928*da0073e9SAndroid Build Coastguard Worker            def reference(diags, offsets, shape):
3929*da0073e9SAndroid Build Coastguard Worker                return scipy.sparse.spdiags(diags, offsets, *shape).toarray()
3930*da0073e9SAndroid Build Coastguard Worker
3931*da0073e9SAndroid Build Coastguard Worker        else:
3932*da0073e9SAndroid Build Coastguard Worker            def reference(diags, offsets, shape):
3933*da0073e9SAndroid Build Coastguard Worker                result = torch.zeros(shape, dtype=dtype, device=device)
3934*da0073e9SAndroid Build Coastguard Worker                for i, off in enumerate(offsets):
3935*da0073e9SAndroid Build Coastguard Worker                    res_view = result.diagonal(off)
3936*da0073e9SAndroid Build Coastguard Worker                    data = diags[i]
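                    # match the scipy convention: for positive offsets the
                    # first `off` entries of the diagonal data are skipped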
3937*da0073e9SAndroid Build Coastguard Worker                    if off > 0:
3938*da0073e9SAndroid Build Coastguard Worker                        data = data[off:]
3939*da0073e9SAndroid Build Coastguard Worker
3940*da0073e9SAndroid Build Coastguard Worker                    m = min(res_view.shape[0], data.shape[0])
3941*da0073e9SAndroid Build Coastguard Worker                    res_view[:m] = data[:m]
3942*da0073e9SAndroid Build Coastguard Worker                return result
3943*da0073e9SAndroid Build Coastguard Worker
3944*da0073e9SAndroid Build Coastguard Worker        def check_valid(diags, offsets, shape, layout=None):
3945*da0073e9SAndroid Build Coastguard Worker            ref_out = reference(diags, offsets, shape)
3946*da0073e9SAndroid Build Coastguard Worker            out = torch.sparse.spdiags(diags, offsets, shape, layout=layout)
3947*da0073e9SAndroid Build Coastguard Worker            if layout is None:
3948*da0073e9SAndroid Build Coastguard Worker                ex_layout = torch.sparse_coo
3949*da0073e9SAndroid Build Coastguard Worker            else:
3950*da0073e9SAndroid Build Coastguard Worker                ex_layout = layout
3951*da0073e9SAndroid Build Coastguard Worker            out_dense = out.to_dense()
3952*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(out.layout == ex_layout, f"Output layout {out.layout} expected {ex_layout}")
3953*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out_dense, ref_out, f"Result:\n{out_dense} does not match reference:\n{ref_out}")
3954*da0073e9SAndroid Build Coastguard Worker
3955*da0073e9SAndroid Build Coastguard Worker        def check_invalid(args, error):
3956*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, error):
3957*da0073e9SAndroid Build Coastguard Worker                torch.sparse.spdiags(*args)
3958*da0073e9SAndroid Build Coastguard Worker
3959*da0073e9SAndroid Build Coastguard Worker        def valid_cases():
3960*da0073e9SAndroid Build Coastguard Worker            # some normal cases
3961*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((1, 5)), make_offsets([0]), (5, 5))
3962*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4))
3963*da0073e9SAndroid Build Coastguard Worker            # noncontiguous diags
3964*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((5, 4), noncontiguous=True), make_offsets([-1, 1, 0, 2, -2]), (5, 5))
3965*da0073e9SAndroid Build Coastguard Worker            # noncontiguous offsets
3966*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((3, 4)), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5))
3967*da0073e9SAndroid Build Coastguard Worker            # noncontiguous diags + offsets
3968*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((3, 4), noncontiguous=True), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5))
3969*da0073e9SAndroid Build Coastguard Worker            # correct dimensionality (2d diags, 2d shape) and matching sizes, but zero diagonals
3970*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((0, 3)), make_offsets([]), (3, 3))
3971*da0073e9SAndroid Build Coastguard Worker            # forward rotation of upper diagonals
3972*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((3, 8)), make_offsets([1, 2, 3]), (4, 4))
3973*da0073e9SAndroid Build Coastguard Worker            # rotation exhausts the input space to read from
3974*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((2, 3)), make_offsets([2, 1]), (3, 3))
3975*da0073e9SAndroid Build Coastguard Worker            # Simple cases repeated with special output format
3976*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((1, 5)), make_offsets([0]), (5, 5), torch.sparse_csc)
3977*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4), torch.sparse_csr)
3978*da0073e9SAndroid Build Coastguard Worker            # vector diags
3979*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((3, )), make_offsets([1]), (4, 4))
3980*da0073e9SAndroid Build Coastguard Worker            # Scalar offset
3981*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((1, 3)), make_offsets(2), (4, 4))
3982*da0073e9SAndroid Build Coastguard Worker            # offsets out of range
3983*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((1, 3)), make_offsets([3]), (3, 3))
3984*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((1, 3)), make_offsets([-3]), (3, 3))
3985*da0073e9SAndroid Build Coastguard Worker
3986*da0073e9SAndroid Build Coastguard Worker        for case in valid_cases():
3987*da0073e9SAndroid Build Coastguard Worker            check_valid(*case)
3988*da0073e9SAndroid Build Coastguard Worker
3989*da0073e9SAndroid Build Coastguard Worker        def invalid_cases():
3990*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((1, 3)), make_offsets([0]), (3, 2, 3)), "Output shape must be 2d"
3991*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((2, 3)), make_offsets([[1, 2], [0, 3]]), (3, 3)), "Offsets must be scalar or vector"
3992*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((3, 2, 3)), make_offsets([0, 1, 2]), (4, 4)), "Diagonals must be vector or matrix"
3993*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((3, 3)), make_offsets([-1, 0]), (3, 3)), \
3994*da0073e9SAndroid Build Coastguard Worker                r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
3995*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((5,)), make_offsets([0, 1, 2, 3, 4]), (3, 3)), \
3996*da0073e9SAndroid Build Coastguard Worker                r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
3997*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((2, 2)), make_offsets([-1, 0]), (2, 3), torch.strided), \
3998*da0073e9SAndroid Build Coastguard Worker                r"Only output layouts \(\w+, \w+, \w+\) are supported, got \w+"
3999*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((2, 5)), make_offsets([0, 0]), (5, 5)), "Offset tensor contains duplicate values"
4000*da0073e9SAndroid Build Coastguard Worker            yield (make_diags((1, 5)), make_offsets([0]).to(torch.int32), (5, 5)), r"Offset Tensor must have dtype Long but got \w+"
4001*da0073e9SAndroid Build Coastguard Worker
4003*da0073e9SAndroid Build Coastguard Worker        for case, error_regex in invalid_cases():
4004*da0073e9SAndroid Build Coastguard Worker            check_invalid(case, error_regex)
4005*da0073e9SAndroid Build Coastguard Worker
4006*da0073e9SAndroid Build Coastguard Worker    def test_small_nnz_coalesced(self):
4007*da0073e9SAndroid Build Coastguard Worker        # a COO tensor created with nnz == 0 is always coalesced
4008*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.sparse_coo_tensor([[], []], [], (2, 2)).is_coalesced())
4009*da0073e9SAndroid Build Coastguard Worker        # likewise for a COO tensor with only 1 nnz
4010*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.sparse_coo_tensor([[0], [0]], [1], (2, 2)).is_coalesced())
4011*da0073e9SAndroid Build Coastguard Worker        # with two or more nnz, is_coalesced() is False since verifying it would require an expensive check
4012*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.sparse_coo_tensor([[0, 0], [0, 0]], [1, 2], (2, 2)).is_coalesced())
4013*da0073e9SAndroid Build Coastguard Worker        # even if there are no duplicates
4014*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.sparse_coo_tensor([[0, 1], [0, 1]], [1, 2], (2, 2)).is_coalesced())
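
    # For reference (an illustrative sketch): an explicit .coalesce() performs
    # the actual de-duplication and sets the flag:
    #   t = torch.sparse_coo_tensor([[0, 0], [0, 0]], [1, 2], (2, 2))
    #   t.coalesce().values()        # tensor([3])
    #   t.coalesce().is_coalesced()  # True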
4015*da0073e9SAndroid Build Coastguard Worker
4016*da0073e9SAndroid Build Coastguard Worker    @coalescedonoff
4017*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.bool))
4018*da0073e9SAndroid Build Coastguard Worker    def test_sum(self, device, dtype, coalesced):
4019*da0073e9SAndroid Build Coastguard Worker        def run_test(shape, nnz):
4020*da0073e9SAndroid Build Coastguard Worker            a = self._gen_sparse(2, nnz, shape, dtype, device, coalesced)[0]
4021*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.sum(), a._values().sum())
4022*da0073e9SAndroid Build Coastguard Worker            if dtype.is_floating_point or dtype.is_complex:
4023*da0073e9SAndroid Build Coastguard Worker                a.requires_grad_(True)
4024*da0073e9SAndroid Build Coastguard Worker                a_inter = a.sum()
4025*da0073e9SAndroid Build Coastguard Worker                a_inter.abs().backward()
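                # loss = |sum(a)| has gradient sgn(sum(a)) spread uniformly over
                # every element of a, which is what the check below asserts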
4026*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
4027*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(a.grad, torch.ones(shape, dtype=dtype, device=device) * torch.sgn(a_inter))
4028*da0073e9SAndroid Build Coastguard Worker        for shape in [(10, 5), (10, 10)]:
4029*da0073e9SAndroid Build Coastguard Worker            run_test(shape, 0)
4030*da0073e9SAndroid Build Coastguard Worker            run_test(shape, max(shape))
4031*da0073e9SAndroid Build Coastguard Worker            run_test(shape, shape[0] * shape[1])
4032*da0073e9SAndroid Build Coastguard Worker
4033*da0073e9SAndroid Build Coastguard Worker
4034*da0073e9SAndroid Build Coastguard Workerclass TestSparseOneOff(TestCase):
4035*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
4036*da0073e9SAndroid Build Coastguard Worker    def test_cuda_from_cpu(self):
4037*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
4038*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
4039*da0073e9SAndroid Build Coastguard Worker                "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"):
4040*da0073e9SAndroid Build Coastguard Worker            torch.sparse_coo_tensor(torch.zeros(1, 4).long().cuda(),
4041*da0073e9SAndroid Build Coastguard Worker                                    torch.randn(4, 4, 4),
4042*da0073e9SAndroid Build Coastguard Worker                                    [3, 4, 4])
4043*da0073e9SAndroid Build Coastguard Worker
4044*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
4045*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
4046*da0073e9SAndroid Build Coastguard Worker                "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"):
4047*da0073e9SAndroid Build Coastguard Worker            torch.sparse_coo_tensor(torch.zeros(1, 4).long().cuda(),
4048*da0073e9SAndroid Build Coastguard Worker                                    torch.randn(4, 4, 4, 0),
4049*da0073e9SAndroid Build Coastguard Worker                                    [3, 4, 4, 0])
4050*da0073e9SAndroid Build Coastguard Worker
4051*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
4052*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
4053*da0073e9SAndroid Build Coastguard Worker                "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"):
4054*da0073e9SAndroid Build Coastguard Worker            torch.sparse_coo_tensor(torch.empty(1, 0).long().cuda(),
4055*da0073e9SAndroid Build Coastguard Worker                                    torch.randn(0, 4, 4, 0),
4056*da0073e9SAndroid Build Coastguard Worker                                    [0, 4, 4, 0])
4057*da0073e9SAndroid Build Coastguard Worker
4058*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
4059*da0073e9SAndroid Build Coastguard Worker    def test_cuda_sparse_cpu_dense_add(self):
4060*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(3, 4, 4)
4061*da0073e9SAndroid Build Coastguard Worker        sparse_y = torch.sparse_coo_tensor(torch.zeros(1, 4).long().cuda(),
4062*da0073e9SAndroid Build Coastguard Worker                                           torch.randn(4, 4, 4).cuda(),
4063*da0073e9SAndroid Build Coastguard Worker                                           [3, 4, 4])
4064*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"):
4065*da0073e9SAndroid Build Coastguard Worker            x + sparse_y
4066*da0073e9SAndroid Build Coastguard Worker
4067*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(3, 4, 4, 0)
4068*da0073e9SAndroid Build Coastguard Worker        sparse_y = torch.sparse_coo_tensor(torch.zeros(1, 4).long().cuda(),
4069*da0073e9SAndroid Build Coastguard Worker                                           torch.randn(4, 4, 4, 0).cuda(),
4070*da0073e9SAndroid Build Coastguard Worker                                           [3, 4, 4, 0])
4071*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"):
4072*da0073e9SAndroid Build Coastguard Worker            x + sparse_y
4073*da0073e9SAndroid Build Coastguard Worker
4074*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(0, 4, 4, 0)
4075*da0073e9SAndroid Build Coastguard Worker        sparse_y = torch.sparse_coo_tensor(torch.empty(1, 0).long().cuda(),
4076*da0073e9SAndroid Build Coastguard Worker                                           torch.randn(0, 4, 4, 0).cuda(),
4077*da0073e9SAndroid Build Coastguard Worker                                           [0, 4, 4, 0])
4078*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"):
4079*da0073e9SAndroid Build Coastguard Worker            x + sparse_y
4080*da0073e9SAndroid Build Coastguard Worker
4081*da0073e9SAndroid Build Coastguard Worker
4082*da0073e9SAndroid Build Coastguard Workerdef _sparse_to_dense(tensor):
4083*da0073e9SAndroid Build Coastguard Worker    if tensor.dtype != torch.bool:
4084*da0073e9SAndroid Build Coastguard Worker        return tensor.to_dense(masked_grad=True)
4085*da0073e9SAndroid Build Coastguard Worker
4086*da0073e9SAndroid Build Coastguard Worker    # to_dense uses coalesce which isn't implemented for bool
4087*da0073e9SAndroid Build Coastguard Worker    return tensor.to(torch.int8).to_dense().to(torch.bool)
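

# Usage sketch for the helper above (illustrative only; this demo function is
# a standalone example and is not used by the test suite):
def _demo_sparse_to_dense_bool():
    # bool sparse tensors cannot be coalesced, so _sparse_to_dense routes
    # them through int8 before densifying
    b = torch.eye(2, dtype=torch.bool).to_sparse()
    return _sparse_to_dense(b)  # equals torch.eye(2, dtype=torch.bool)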
4088*da0073e9SAndroid Build Coastguard Worker
4089*da0073e9SAndroid Build Coastguard Worker
4090*da0073e9SAndroid Build Coastguard Worker_sparse_unary_ops = ops(sparse_unary_ufuncs, dtypes=OpDTypes.supported,
4091*da0073e9SAndroid Build Coastguard Worker                        allowed_dtypes=all_types_and_complex())
4092*da0073e9SAndroid Build Coastguard Workerclass TestSparseUnaryUfuncs(TestCase):
4093*da0073e9SAndroid Build Coastguard Worker    exact_dtype = True
4094*da0073e9SAndroid Build Coastguard Worker
4095*da0073e9SAndroid Build Coastguard Worker
4096*da0073e9SAndroid Build Coastguard Worker    @_sparse_unary_ops
4097*da0073e9SAndroid Build Coastguard Worker    def test_sparse_consistency(self, device, dtype, op):
4098*da0073e9SAndroid Build Coastguard Worker        sample = first_sample(self, op.sample_inputs(device, dtype))
4099*da0073e9SAndroid Build Coastguard Worker        assert isinstance(sample.input, torch.Tensor)
4100*da0073e9SAndroid Build Coastguard Worker
4101*da0073e9SAndroid Build Coastguard Worker        expected = op(sample.input, *sample.args, **sample.kwargs)
4102*da0073e9SAndroid Build Coastguard Worker        assert torch.is_tensor(expected)
4103*da0073e9SAndroid Build Coastguard Worker        output = op(sample.input.to_sparse(), *sample.args, **sample.kwargs)
4104*da0073e9SAndroid Build Coastguard Worker        assert torch.is_tensor(output)
4105*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(_sparse_to_dense(output), expected)
4106*da0073e9SAndroid Build Coastguard Worker
4107*da0073e9SAndroid Build Coastguard Worker    @_sparse_unary_ops
4108*da0073e9SAndroid Build Coastguard Worker    def test_out(self, device, dtype, op):
4109*da0073e9SAndroid Build Coastguard Worker        if not op.supports_out:
4110*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Skipped! Out not supported")
4111*da0073e9SAndroid Build Coastguard Worker
4112*da0073e9SAndroid Build Coastguard Worker        sample = first_sample(self, op.sample_inputs(device, dtype))
4113*da0073e9SAndroid Build Coastguard Worker        sample.input = sample.input.to_sparse()
4114*da0073e9SAndroid Build Coastguard Worker        expect = op(sample.input, *sample.args, **sample.kwargs)
4115*da0073e9SAndroid Build Coastguard Worker
4116*da0073e9SAndroid Build Coastguard Worker        out = torch.sparse_coo_tensor(sample.input.shape, device=device,
4117*da0073e9SAndroid Build Coastguard Worker                                      dtype=expect.dtype)
4118*da0073e9SAndroid Build Coastguard Worker        op(sample.input, *sample.args, **sample.kwargs, out=out)
4119*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, expect)
4120*da0073e9SAndroid Build Coastguard Worker
4121*da0073e9SAndroid Build Coastguard Worker    @_sparse_unary_ops
4122*da0073e9SAndroid Build Coastguard Worker    def test_inplace(self, device, dtype, op):
4123*da0073e9SAndroid Build Coastguard Worker        if op.inplace_variant is None:
4124*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Skipped! Inplace variant not supported")
4125*da0073e9SAndroid Build Coastguard Worker
4126*da0073e9SAndroid Build Coastguard Worker        sample = first_sample(self, op.sample_inputs(device, dtype))
4127*da0073e9SAndroid Build Coastguard Worker        sample.input = sample.input.to_sparse().coalesce()
4128*da0073e9SAndroid Build Coastguard Worker        expect = op(sample.input, *sample.args, **sample.kwargs)
4129*da0073e9SAndroid Build Coastguard Worker
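        # in-place variants must keep self's dtype; when the op's result type
        # cannot be cast back to it, a RuntimeError is expected instead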
4130*da0073e9SAndroid Build Coastguard Worker        if not torch.can_cast(expect.dtype, dtype):
4131*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "result type .* can't be cast to"):
4132*da0073e9SAndroid Build Coastguard Worker                op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
4133*da0073e9SAndroid Build Coastguard Worker            return
4134*da0073e9SAndroid Build Coastguard Worker
4135*da0073e9SAndroid Build Coastguard Worker        actual = op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
4136*da0073e9SAndroid Build Coastguard Worker        self.assertIs(actual, sample.input)
4137*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, expect)
4138*da0073e9SAndroid Build Coastguard Worker
4139*da0073e9SAndroid Build Coastguard Worker    @_sparse_unary_ops
4140*da0073e9SAndroid Build Coastguard Worker    def test_sparse_zero_dims(self, device, dtype, op):
4141*da0073e9SAndroid Build Coastguard Worker        # test 0x0 sparse_coo_tensor
4142*da0073e9SAndroid Build Coastguard Worker        indices = torch.empty(2, 0, dtype=torch.int64)
4143*da0073e9SAndroid Build Coastguard Worker        values = torch.empty(0, dtype=dtype)
4144*da0073e9SAndroid Build Coastguard Worker        sparse_0x0 = torch.sparse_coo_tensor(indices, values, (0, 0))
4145*da0073e9SAndroid Build Coastguard Worker        expected = torch.sparse_coo_tensor(indices, op(values), (0, 0))
4146*da0073e9SAndroid Build Coastguard Worker        actual = op(sparse_0x0)
4147*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
4148*da0073e9SAndroid Build Coastguard Worker
4149*da0073e9SAndroid Build Coastguard Worker    @_sparse_unary_ops
4150*da0073e9SAndroid Build Coastguard Worker    def test_sparse_zeros(self, device, dtype, op):
4151*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype)
4152*da0073e9SAndroid Build Coastguard Worker
4153*da0073e9SAndroid Build Coastguard Worker        zero_input = torch.zeros((), device=device, dtype=dtype)
4154*da0073e9SAndroid Build Coastguard Worker        sparse_input = torch.sparse_coo_tensor((), dtype=dtype, device=device)
4155*da0073e9SAndroid Build Coastguard Worker
4156*da0073e9SAndroid Build Coastguard Worker        expect = op(zero_input)
4157*da0073e9SAndroid Build Coastguard Worker        actual = op(sparse_input)
4158*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expect, _sparse_to_dense(actual))
4159*da0073e9SAndroid Build Coastguard Worker
4160*da0073e9SAndroid Build Coastguard Worker    @ops(sparse_unary_ufuncs, dtypes=OpDTypes.supported,
4161*da0073e9SAndroid Build Coastguard Worker         allowed_dtypes=[torch.double, torch.cdouble])
4162*da0073e9SAndroid Build Coastguard Worker    def test_sparse_fn_grad(self, device, dtype, op):
4163*da0073e9SAndroid Build Coastguard Worker        if not op.supports_autograd:
4164*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Skipped! Op doesn't support autograd")
4165*da0073e9SAndroid Build Coastguard Worker
4166*da0073e9SAndroid Build Coastguard Worker        for sample in op.sample_inputs(device, dtype):
4167*da0073e9SAndroid Build Coastguard Worker            sparse_input = sample.input.to_sparse().detach().requires_grad_(True)
4168*da0073e9SAndroid Build Coastguard Worker
4169*da0073e9SAndroid Build Coastguard Worker            def fn(x):
4170*da0073e9SAndroid Build Coastguard Worker                return _sparse_to_dense(
4171*da0073e9SAndroid Build Coastguard Worker                    op(x, *sample.args, **sample.kwargs))
4172*da0073e9SAndroid Build Coastguard Worker
4173*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(gradcheck(
4174*da0073e9SAndroid Build Coastguard Worker                fn,
4175*da0073e9SAndroid Build Coastguard Worker                (sparse_input,),
4176*da0073e9SAndroid Build Coastguard Worker                check_batched_grad=False,
4177*da0073e9SAndroid Build Coastguard Worker                check_grad_dtypes=True,
4178*da0073e9SAndroid Build Coastguard Worker                nondet_tol=op.gradcheck_nondet_tol,
4179*da0073e9SAndroid Build Coastguard Worker                fast_mode=op.gradcheck_fast_mode,
4180*da0073e9SAndroid Build Coastguard Worker                masked=True))
4181*da0073e9SAndroid Build Coastguard Worker
4182*da0073e9SAndroid Build Coastguard Worker
4183*da0073e9SAndroid Build Coastguard Workerclass TestSparseMaskedReductions(TestCase):
4184*da0073e9SAndroid Build Coastguard Worker    exact_dtype = True
4185*da0073e9SAndroid Build Coastguard Worker
4186*da0073e9SAndroid Build Coastguard Worker    fp16_low_precision_list = {
4187*da0073e9SAndroid Build Coastguard Worker        'masked.prod',
4188*da0073e9SAndroid Build Coastguard Worker    }
4189*da0073e9SAndroid Build Coastguard Worker
4190*da0073e9SAndroid Build Coastguard Worker    @ops(sparse_masked_reduction_ops)
4191*da0073e9SAndroid Build Coastguard Worker    def test_future_empty_dim(self, device, dtype, op):
4192*da0073e9SAndroid Build Coastguard Worker        """Currently, `dim=()` in reduction operations means "reduce over
4193*da0073e9SAndroid Build Coastguard Worker        all dimensions", while in the future it will mean "no reduction". See
4194*da0073e9SAndroid Build Coastguard Worker        https://github.com/pytorch/pytorch/issues/29137
4195*da0073e9SAndroid Build Coastguard Worker
4196*da0073e9SAndroid Build Coastguard Worker        For sparse masked reductions, we'll implement the current behavior.
4197*da0073e9SAndroid Build Coastguard Worker
4198*da0073e9SAndroid Build Coastguard Worker        For testing, we'll use samples with `dim=0` and map it to
4199*da0073e9SAndroid Build Coastguard Worker        `dim=()` until
4200*da0073e9SAndroid Build Coastguard Worker        torch.testing._internal.common_methods_invocations._generate_reduction_kwargs
4201*da0073e9SAndroid Build Coastguard Worker        is made to generate samples with `dim=()` for non-scalar
4202*da0073e9SAndroid Build Coastguard Worker        inputs. Once that is done and gh-29137 is resolved, this test
4203*da0073e9SAndroid Build Coastguard Worker        can be deleted. See also the `torch.masked._canonical_dim`
4204*da0073e9SAndroid Build Coastguard Worker        implementation regarding the change of the `dim=()` behavior.
4205*da0073e9SAndroid Build Coastguard Worker        """
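
        # Concretely (a sketch of the current semantics): dim=() today behaves
        # like dim=None, e.g. torch.sum(torch.ones(2, 3), dim=()) returns
        # tensor(6.); the loop below reruns each dim=0 sample with dim=() to
        # pin that behavior down for sparse inputs.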
4206*da0073e9SAndroid Build Coastguard Worker
4207*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs_func(op, device, dtype, requires_grad=False)
4208*da0073e9SAndroid Build Coastguard Worker        op_name = op.name.replace('masked.', '')
4209*da0073e9SAndroid Build Coastguard Worker        for sample_input in samples:
4210*da0073e9SAndroid Build Coastguard Worker            if sample_input.kwargs.get('dim') != 0:
4211*da0073e9SAndroid Build Coastguard Worker                continue
4212*da0073e9SAndroid Build Coastguard Worker            sample_input_kwargs = dict(sample_input.kwargs)
4213*da0073e9SAndroid Build Coastguard Worker            sample_input_kwargs['dim'] = ()    # reduce over all dimensions
4214*da0073e9SAndroid Build Coastguard Worker
4215*da0073e9SAndroid Build Coastguard Worker            t = sample_input.input
4216*da0073e9SAndroid Build Coastguard Worker            mask = sample_input_kwargs.get('mask')
4217*da0073e9SAndroid Build Coastguard Worker            if mask is None and op_name in {'prod', 'amax', 'amin'}:
4218*da0073e9SAndroid Build Coastguard Worker                # FIXME: for now reductions with non-zero reduction identity and
4219*da0073e9SAndroid Build Coastguard Worker                # unspecified mask are not supported for sparse COO
4220*da0073e9SAndroid Build Coastguard Worker                # tensors, see torch.masked.prod implementation
4221*da0073e9SAndroid Build Coastguard Worker                # for details.
4222*da0073e9SAndroid Build Coastguard Worker                continue
4224*da0073e9SAndroid Build Coastguard Worker            actual = op(t.to_sparse(), *sample_input.args, **sample_input_kwargs)
4225*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual.layout, torch.sparse_coo)
4226*da0073e9SAndroid Build Coastguard Worker
4227*da0073e9SAndroid Build Coastguard Worker            expected = op(t, *sample_input.args, **sample_input_kwargs).to_sparse()
4228*da0073e9SAndroid Build Coastguard Worker            atol = None
4229*da0073e9SAndroid Build Coastguard Worker            rtol = None
4230*da0073e9SAndroid Build Coastguard Worker            if op.name in self.fp16_low_precision_list and dtype == torch.half:
4231*da0073e9SAndroid Build Coastguard Worker                atol = 1e-5
4232*da0073e9SAndroid Build Coastguard Worker                rtol = 2e-3
4233*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected, atol=atol, rtol=rtol)
4234*da0073e9SAndroid Build Coastguard Worker
4235*da0073e9SAndroid Build Coastguard Worker
4236*da0073e9SAndroid Build Coastguard Workerclass TestSparseMeta(TestCase):
4237*da0073e9SAndroid Build Coastguard Worker    exact_dtype = True
4238*da0073e9SAndroid Build Coastguard Worker
4239*da0073e9SAndroid Build Coastguard Worker    def _test_meta_sparse_coo(self, dtype):
4240*da0073e9SAndroid Build Coastguard Worker        r = torch.empty(4, 4, layout=torch.sparse_coo, device='meta', dtype=dtype)
4241*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(r.is_meta)
4242*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.device.type, "meta")
4243*da0073e9SAndroid Build Coastguard Worker        r2 = torch.empty_like(r)
4244*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(r2.is_meta)
4245*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, r2)
4246*da0073e9SAndroid Build Coastguard Worker        r3 = torch.sparse_coo_tensor(size=(4, 4), device='meta', dtype=dtype)
4247*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(r3.is_meta)
4248*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, r3)
4249*da0073e9SAndroid Build Coastguard Worker        r.sparse_resize_((4, 4), 1, 1)
4250*da0073e9SAndroid Build Coastguard Worker        r.sparse_resize_and_clear_((4, 4, 4), 2, 1)
4251*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.sparse_dim(), 2)
4252*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.dense_dim(), 1)
4253*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r._dimV(), 1)
4254*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r._nnz(), 0)
4255*da0073e9SAndroid Build Coastguard Worker        # sparse tensors with nnz == 0 should always be coalesced at creation
4256*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.is_coalesced(), True)
4257*da0073e9SAndroid Build Coastguard Worker        # but we can force them into the uncoalesced state
4258*da0073e9SAndroid Build Coastguard Worker        r._coalesced_(False)
4259*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.is_coalesced(), False)
4260*da0073e9SAndroid Build Coastguard Worker        # restore the coalesced state for indices/values access
4261*da0073e9SAndroid Build Coastguard Worker        r._coalesced_(True)
4262*da0073e9SAndroid Build Coastguard Worker        # TODO: this sort of aliasing will need to be handled by
4263*da0073e9SAndroid Build Coastguard Worker        # functionalization
4264*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r._indices(), torch.empty(2, 0, device='meta', dtype=torch.int64))
4265*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r._values(), torch.empty(0, 4, device='meta', dtype=dtype))
4266*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.indices(), torch.empty(2, 0, device='meta', dtype=torch.int64))
4267*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.values(), torch.empty(0, 4, device='meta', dtype=dtype))
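
    # The meta-device pattern exercised above, in brief (illustrative):
    # meta tensors carry only metadata, so shape, dtype and nnz are
    # available without any storage being allocated:
    #   r = torch.empty(4, 4, layout=torch.sparse_coo, device='meta')
    #   r.shape, r._nnz(), r.is_meta  # (torch.Size([4, 4]), 0, True)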
4268*da0073e9SAndroid Build Coastguard Worker
4269*da0073e9SAndroid Build Coastguard Worker    def _test_meta_sparse_compressed(self, dtype, layout, batchsize, densesize):
4270*da0073e9SAndroid Build Coastguard Worker        index_dtype = torch.int64
4271*da0073e9SAndroid Build Coastguard Worker        blocksize = (2, 3) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
4272*da0073e9SAndroid Build Coastguard Worker        sparsesize = (4, 6)
4273*da0073e9SAndroid Build Coastguard Worker        nnz = 0
4274*da0073e9SAndroid Build Coastguard Worker
4275*da0073e9SAndroid Build Coastguard Worker        shape = (*batchsize, *sparsesize, *densesize)
4276*da0073e9SAndroid Build Coastguard Worker        compressed_dim = 0 if layout in {torch.sparse_csr, torch.sparse_bsr} else 1
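        # compressed_indices holds one entry per compressed slice plus one,
        # where slices are counted in blocks for the BSR/BSC layouts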
4277*da0073e9SAndroid Build Coastguard Worker        nof_compressed_indices = (sparsesize[compressed_dim] // blocksize[compressed_dim] + 1 if blocksize
4278*da0073e9SAndroid Build Coastguard Worker                                  else sparsesize[compressed_dim] + 1)
4279*da0073e9SAndroid Build Coastguard Worker        compressed_indices = torch.empty((*batchsize, nof_compressed_indices), device='meta', dtype=index_dtype)
4280*da0073e9SAndroid Build Coastguard Worker        plain_indices = torch.empty((*batchsize, nnz), device='meta', dtype=index_dtype)
4281*da0073e9SAndroid Build Coastguard Worker
4282*da0073e9SAndroid Build Coastguard Worker        values = torch.empty((*batchsize, nnz, *blocksize, *densesize), device='meta', dtype=dtype)
4283*da0073e9SAndroid Build Coastguard Worker        r = torch.sparse_compressed_tensor(
4284*da0073e9SAndroid Build Coastguard Worker            compressed_indices,
4285*da0073e9SAndroid Build Coastguard Worker            plain_indices,
4286*da0073e9SAndroid Build Coastguard Worker            values,
4287*da0073e9SAndroid Build Coastguard Worker            shape,
4288*da0073e9SAndroid Build Coastguard Worker            layout=layout
4289*da0073e9SAndroid Build Coastguard Worker        )
4290*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(r.is_meta)
4291*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.device.type, "meta")
4292*da0073e9SAndroid Build Coastguard Worker
4293*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.sparse_dim(), 2)
4294*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.dense_dim(), len(densesize))
4295*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r._nnz(), nnz)
4296*da0073e9SAndroid Build Coastguard Worker        batch_dims = r.ndim - r.sparse_dim() - r.dense_dim()
4297*da0073e9SAndroid Build Coastguard Worker        r_blocksize = r.values().shape[batch_dims + 1: batch_dims + 1 + len(blocksize)]
4298*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r_blocksize, blocksize)
4299*da0073e9SAndroid Build Coastguard Worker
4300*da0073e9SAndroid Build Coastguard Worker        r_compressed_indices = r.crow_indices() if layout in {torch.sparse_csr, torch.sparse_bsr} else r.ccol_indices()
4301*da0073e9SAndroid Build Coastguard Worker        r_plain_indices = r.col_indices() if layout in {torch.sparse_csr, torch.sparse_bsr} else r.row_indices()
4302*da0073e9SAndroid Build Coastguard Worker
4303*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r_compressed_indices,
4304*da0073e9SAndroid Build Coastguard Worker                         torch.empty((*batchsize, nof_compressed_indices), device='meta', dtype=index_dtype))
4305*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r_plain_indices, torch.empty((*batchsize, nnz), device='meta', dtype=index_dtype))
4306*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.values(), torch.empty((*batchsize, nnz, *blocksize, *densesize), device='meta', dtype=dtype))
4307*da0073e9SAndroid Build Coastguard Worker
4308*da0073e9SAndroid Build Coastguard Worker        r2 = torch.empty_like(r)
4309*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(r2.is_meta)
4310*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r2, r)
4311*da0073e9SAndroid Build Coastguard Worker
4312*da0073e9SAndroid Build Coastguard Worker        if layout in {torch.sparse_csr, torch.sparse_csc}:
4313*da0073e9SAndroid Build Coastguard Worker            r3 = torch.empty((*batchsize, *sparsesize), dtype=dtype, layout=layout, device="meta")
4314*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(r3.is_meta)
4315*da0073e9SAndroid Build Coastguard Worker            if not densesize:
4316*da0073e9SAndroid Build Coastguard Worker                # dense dimensions cannot be specified for torch.empty
4317*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(r3, r)
4318*da0073e9SAndroid Build Coastguard Worker
4319*da0073e9SAndroid Build Coastguard Worker    @all_sparse_layouts('layout', include_strided=False)
4320*da0073e9SAndroid Build Coastguard Worker    @parametrize("dtype", [torch.float64])
4321*da0073e9SAndroid Build Coastguard Worker    def test_meta(self, dtype, layout):
4322*da0073e9SAndroid Build Coastguard Worker        if layout is torch.sparse_coo:
4323*da0073e9SAndroid Build Coastguard Worker            self._test_meta_sparse_coo(dtype)
4324*da0073e9SAndroid Build Coastguard Worker        else:
4325*da0073e9SAndroid Build Coastguard Worker            for batchsize, densesize in itertools.product([(), (2,)], [(), (3,)]):
4326*da0073e9SAndroid Build Coastguard Worker                self._test_meta_sparse_compressed(dtype, layout, batchsize, densesize)
4327*da0073e9SAndroid Build Coastguard Worker
4328*da0073e9SAndroid Build Coastguard Worker    def _test_print_meta_data(self, dtype, layout, batchsize, sparsesize, densesize):
4329*da0073e9SAndroid Build Coastguard Worker        index_dtype = torch.int64
4330*da0073e9SAndroid Build Coastguard Worker        nnz = 0
4331*da0073e9SAndroid Build Coastguard Worker        blocksize = (2, 3) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
4332*da0073e9SAndroid Build Coastguard Worker        shape = (*batchsize, *sparsesize, *densesize)
4333*da0073e9SAndroid Build Coastguard Worker        values = torch.empty((*batchsize, nnz, *blocksize, *densesize), device='meta', dtype=dtype)
4334*da0073e9SAndroid Build Coastguard Worker        if layout is torch.sparse_coo:
4335*da0073e9SAndroid Build Coastguard Worker            indices = torch.empty((len(sparsesize), nnz), device='meta', dtype=index_dtype)
4336*da0073e9SAndroid Build Coastguard Worker            x = torch.sparse_coo_tensor(indices, values, shape)
4337*da0073e9SAndroid Build Coastguard Worker        else:
4338*da0073e9SAndroid Build Coastguard Worker            compressed_dim = 0 if layout in {torch.sparse_csr, torch.sparse_bsr} else 1
4339*da0073e9SAndroid Build Coastguard Worker            nof_compressed_indices = (sparsesize[compressed_dim] // blocksize[compressed_dim] + 1 if blocksize
4340*da0073e9SAndroid Build Coastguard Worker                                      else sparsesize[compressed_dim] + 1)
4341*da0073e9SAndroid Build Coastguard Worker            compressed_indices = torch.empty((*batchsize, nof_compressed_indices), device='meta', dtype=index_dtype)
4342*da0073e9SAndroid Build Coastguard Worker            plain_indices = torch.empty((*batchsize, nnz), device='meta', dtype=index_dtype)
4343*da0073e9SAndroid Build Coastguard Worker            x = torch.sparse_compressed_tensor(
4344*da0073e9SAndroid Build Coastguard Worker                compressed_indices,
4345*da0073e9SAndroid Build Coastguard Worker                plain_indices,
4346*da0073e9SAndroid Build Coastguard Worker                values,
4347*da0073e9SAndroid Build Coastguard Worker                shape,
4348*da0073e9SAndroid Build Coastguard Worker                layout=layout
4349*da0073e9SAndroid Build Coastguard Worker            )
4350*da0073e9SAndroid Build Coastguard Worker
4351*da0073e9SAndroid Build Coastguard Worker        printed = []
4352*da0073e9SAndroid Build Coastguard Worker        printed.append(f"########## {dtype}/{index_dtype}/size={batchsize}+{sparsesize}+{blocksize}+{densesize} ##########")
4353*da0073e9SAndroid Build Coastguard Worker        printed.append("# sparse meta tensor")
4354*da0073e9SAndroid Build Coastguard Worker        printed.append(str(x))
4355*da0073e9SAndroid Build Coastguard Worker
4356*da0073e9SAndroid Build Coastguard Worker        return printed
4357*da0073e9SAndroid Build Coastguard Worker
4358*da0073e9SAndroid Build Coastguard Worker    @all_sparse_layouts('layout', include_strided=False)
4359*da0073e9SAndroid Build Coastguard Worker    @parametrize("dtype", [torch.float64])
4360*da0073e9SAndroid Build Coastguard Worker    def test_print_meta(self, dtype, layout):
4361*da0073e9SAndroid Build Coastguard Worker        printed = []
4362*da0073e9SAndroid Build Coastguard Worker        for batchsize, sparsesize, densesize in itertools.product(
4363*da0073e9SAndroid Build Coastguard Worker                [(), (2,)], [(4, 6), (3, 5, 7)], [(), (3,)]
4364*da0073e9SAndroid Build Coastguard Worker        ):
4365*da0073e9SAndroid Build Coastguard Worker            if layout is torch.sparse_coo and batchsize:
4366*da0073e9SAndroid Build Coastguard Worker                # COO tensors don't have batch dimensions
4367*da0073e9SAndroid Build Coastguard Worker                continue
4368*da0073e9SAndroid Build Coastguard Worker            if layout is not torch.sparse_coo and len(sparsesize) != 2:
4369*da0073e9SAndroid Build Coastguard Worker                # CSR/CSC/BSR/BSC tensors must have 2 sparse dimensions
4370*da0073e9SAndroid Build Coastguard Worker                continue
4371*da0073e9SAndroid Build Coastguard Worker            printed += self._test_print_meta_data(dtype, layout, batchsize, sparsesize, densesize)
4372*da0073e9SAndroid Build Coastguard Worker
4373*da0073e9SAndroid Build Coastguard Worker        orig_maxDiff = self.maxDiff
4374*da0073e9SAndroid Build Coastguard Worker        self.maxDiff = None
4375*da0073e9SAndroid Build Coastguard Worker        try:
4376*da0073e9SAndroid Build Coastguard Worker            self.assertExpected('\n'.join(printed))
4377*da0073e9SAndroid Build Coastguard Worker        finally:
4378*da0073e9SAndroid Build Coastguard Worker            self.maxDiff = orig_maxDiff
4381*da0073e9SAndroid Build Coastguard Worker
4382*da0073e9SAndroid Build Coastguard Worker    def assertEqualMeta(self, x, y, expected_nnz):
4383*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.layout, y.layout)
4384*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.shape, y.shape)
4385*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.dtype, y.dtype)
4386*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.sparse_dim(), y.sparse_dim())
4387*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.dense_dim(), y.dense_dim())
4388*da0073e9SAndroid Build Coastguard Worker
4389*da0073e9SAndroid Build Coastguard Worker        def assertEqualAttrs(x, y, expected_shape):
4390*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.shape, expected_shape)
4391*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.dtype, y.dtype)
4392*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.layout, y.layout)
4393*da0073e9SAndroid Build Coastguard Worker            if not x.is_meta:
4394*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x.device, y.device)
4395*da0073e9SAndroid Build Coastguard Worker
4396*da0073e9SAndroid Build Coastguard Worker        if x.layout is torch.sparse_coo:
4397*da0073e9SAndroid Build Coastguard Worker            assertEqualAttrs(x._indices(), y._indices(), (*y._indices().shape[:-1], expected_nnz))
4398*da0073e9SAndroid Build Coastguard Worker            assertEqualAttrs(x._values(), y._values(), (expected_nnz, *y._values().shape[1:]))
4399*da0073e9SAndroid Build Coastguard Worker        elif x.layout in {torch.sparse_csr, torch.sparse_bsr}:
4400*da0073e9SAndroid Build Coastguard Worker            assertEqualAttrs(x.crow_indices(), y.crow_indices(), y.crow_indices().shape)
4401*da0073e9SAndroid Build Coastguard Worker            assertEqualAttrs(x.col_indices(), y.col_indices(), (*y.col_indices().shape[:-1], expected_nnz))
4402*da0073e9SAndroid Build Coastguard Worker            batch_dim = x.col_indices().ndim - 1
4403*da0073e9SAndroid Build Coastguard Worker            values_shape = (*y.values().shape[:batch_dim], expected_nnz, *y.values().shape[batch_dim + 1:])
4404*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.values().layout, y.values().layout)
4405*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.values().dtype, y.values().dtype)
4406*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.values().shape, values_shape)
4407*da0073e9SAndroid Build Coastguard Worker        elif x.layout in {torch.sparse_csc, torch.sparse_bsc}:
4408*da0073e9SAndroid Build Coastguard Worker            assertEqualAttrs(x.ccol_indices(), y.ccol_indices(), y.ccol_indices().shape)
4409*da0073e9SAndroid Build Coastguard Worker            assertEqualAttrs(x.row_indices(), y.row_indices(), (*y.row_indices().shape[:-1], expected_nnz))
4410*da0073e9SAndroid Build Coastguard Worker            batch_dim = x.row_indices().ndim - 1
4411*da0073e9SAndroid Build Coastguard Worker            values_shape = (*y.values().shape[:batch_dim], expected_nnz, *y.values().shape[batch_dim + 1:])
4412*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.values().layout, y.values().layout)
4413*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.values().dtype, y.values().dtype)
4414*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.values().shape, values_shape)
4415*da0073e9SAndroid Build Coastguard Worker
    @all_sparse_layouts('layout', include_strided=False)
    @parametrize("dtype", [torch.float64])
    def test_to_meta(self, dtype, layout):
        index_dtype = torch.int64
        device = 'cpu'
        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
            m = t.to(device="meta")
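            # A meta tensor carries only metadata (shape, dtype, layout) and
            # no storage, hence the expected nnz of 0 below.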
            self.assertEqual(m.device.type, "meta")
            self.assertEqualMeta(m, t, 0)

    @all_sparse_layouts('layout', include_strided=False)
    @parametrize("dtype", [torch.float64])
    def test_zeros_like_meta(self, dtype, layout):
        index_dtype = torch.int64
        device = 'cpu'
        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
            m = torch.zeros_like(t, device="meta")
            self.assertEqual(m.device.type, "meta")
            self.assertEqualMeta(m, t, 0)

    @all_sparse_layouts('layout', include_strided=False)
    @parametrize("dtype", [torch.float64])
    def test_fake(self, dtype, layout):
        from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
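        # A FakeTensor is backed by a meta tensor but reports a concrete
        # "fake" device (here CPU), so the meta-equality checks apply to it
        # as well.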
        fake_mode = FakeTensorMode()
        index_dtype = torch.int64
        device = 'cpu'
        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
            f = FakeTensor.from_tensor(t, fake_mode)
            self.assertIsInstance(f, FakeTensor)
            self.assertEqualMeta(f, t, 0)

            d = f.detach()
            self.assertIsInstance(d, FakeTensor)
            self.assertEqualMeta(d, t, 0)

    @all_sparse_layouts('layout', include_strided=False)
    @parametrize("dtype", [torch.float64])
    def test_zeros_like_fake(self, dtype, layout):
        from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
        from torch.utils._mode_utils import no_dispatch
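        # Note: no_dispatch() below suppresses tensor-subclass dispatch so
        # that zeros_like constructs a concrete tensor on the fake device,
        # comparable against the real-tensor reference.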
        fake_mode = FakeTensorMode()
        index_dtype = torch.int64
        device = 'cpu'
        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
            f = FakeTensor.from_tensor(t, fake_mode)
            expected = torch.zeros_like(t)
            with no_dispatch():
                result = torch.zeros_like(f, device=f.fake_device)
            self.assertEqual(result, expected)
            self.assertEqualMeta(result, expected, 0)

    @all_sparse_layouts('layout', include_strided=False)
    @parametrize("dtype", [torch.float64])
    def test_sum_meta(self, dtype, layout):
        device = 'cpu'
        index_dtype = torch.int64
        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
            m = t.to(device='meta')
            r = torch.sum(m)
            expected = torch.sum(t).to(device="meta")
            self.assertTrue(r.is_meta)
            self.assertEqualMeta(r, expected, 0)

    @all_sparse_layouts('layout', include_strided=False)
    @parametrize("dtype", [torch.float64])
    def test_add_meta(self, dtype, layout):
        device = 'cpu'
        index_dtype = torch.int64
        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
            expected = torch.add(t, t).to(device='meta')
            m = t.to(device='meta')
            r = torch.add(m, m)
            self.assertEqualMeta(r, expected, 0)


class _SparseDataset(torch.utils.data.Dataset):
    # A utility class used in the TestSparseAny.test_dataloader method.
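    #
    # A minimal usage sketch (assuming t1 and t2 are sparse tensors built
    # elsewhere; batch_size=None disables automatic batching so samples are
    # yielded one at a time):
    #
    #   ds = _SparseDataset([t1, t2])
    #   loader = torch.utils.data.DataLoader(ds, batch_size=None)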

    def __init__(self, sparse_tensors):
        self.sparse_tensors = sparse_tensors

    def __len__(self):
        return len(self.sparse_tensors)

    def __getitem__(self, index):
        return self.sparse_tensors[index]


class TestSparseAny(TestCase):

    @onlyCPU
    @all_sparse_layouts('layout', include_strided=False)
    @torch.sparse.check_sparse_tensor_invariants(enable=False)
    def test_check_sparse_tensor_invariants(self, layout):

        if layout is torch.sparse_coo:

            def create_invalid_tensor(check_invariants=None):
                shape = (2, 2)
                invalid_indices = torch.tensor([[0], [3]])  # column index is out of range
                values = torch.tensor([1])
                if check_invariants is None:
                    return torch.sparse_coo_tensor(invalid_indices, values, shape)
                else:
                    return torch.sparse_coo_tensor(invalid_indices, values, shape, check_invariants=check_invariants)

            expected_exception_message = 'size is inconsistent with indices: for dim 1, size is 2 but found index 3'

        elif layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:

            def create_invalid_tensor(check_invariants=None):
                shape = (2, 2)
                compressed_indices = torch.tensor([0, 0, 1])
                invalid_plain_indices = torch.tensor([3])  # index is out of range
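                # compressed_indices [0, 0, 1] declare a single specified
                # element whose plain index must lie in [0, 2); the value 3
                # above violates that bound.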
                if layout in {torch.sparse_bsr, torch.sparse_bsc}:
                    values = torch.tensor([[[1]]])
                else:
                    values = torch.tensor([1])
                if check_invariants is None:
                    return torch.sparse_compressed_tensor(compressed_indices, invalid_plain_indices, values, shape, layout=layout)
                else:
                    return torch.sparse_compressed_tensor(compressed_indices, invalid_plain_indices, values, shape, layout=layout,
                                                          check_invariants=check_invariants)

            if layout in {torch.sparse_csr, torch.sparse_bsr}:
                expected_exception_message = r'`0 <= col_indices < ncols` is not satisfied.'
            else:
                expected_exception_message = r'`0 <= row_indices < nrows` is not satisfied.'

        else:
            raise NotImplementedError(layout)

        # First, consider the case where invariant checks are disabled
        # "globally" (read: within the context of this test method's
        # caller) by the check_sparse_tensor_invariants(enable=False)
        # decorator:
        self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())

        # Enable the invariant checks in a local context:
        with torch.sparse.check_sparse_tensor_invariants():
            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())

        # Leaving the local context must restore the "global" state of
        # the invariant check feature:
        self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())

        # Since invariant checks are disabled by default, we can
        # create an invalid sparse tensor without raising an
        # exception:
        r = create_invalid_tensor()
        self.assertEqual(r.layout, layout)

        # Or when disabling the invariant checks explicitly:
        r = create_invalid_tensor(check_invariants=False)
        self.assertEqual(r.layout, layout)

        # Enabling the invariant checks via the constructor's optional
        # argument raises an exception when the sparse tensor
        # invariants are violated:
        with self.assertRaisesRegex(RuntimeError, expected_exception_message):
            create_invalid_tensor(check_invariants=True)

        # Check that the global invariant check flag has been restored
        # after raising the exception above:
        self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())

        # Next, consider the case where invariant checks are enabled
        # within a local context:
        with torch.sparse.check_sparse_tensor_invariants():
            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())

            # Since invariant checks are now enabled, an attempt to
            # create an invalid sparse tensor will lead to an
            # exception:
            with self.assertRaisesRegex(RuntimeError, expected_exception_message):
                create_invalid_tensor()

            # Similarly, when enabling the invariant checks
            # explicitly, invalid sparse tensor construction will lead
            # to an exception:
            with self.assertRaisesRegex(RuntimeError, expected_exception_message):
                create_invalid_tensor(check_invariants=True)

            # However, the invariant checks can be disabled via the
            # constructor's optional argument so that the invalid
            # tensor is successfully constructed:
            r = create_invalid_tensor(check_invariants=False)
            self.assertEqual(r.layout, layout)

            # Check that the invariant check flag has been restored
            # when leaving the constructor:
            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())

        # Double-check that the global state is restored when leaving
        # the local context:
        self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
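
        # In user code the same precedence looks like this (a sketch,
        # assuming indices, values and shape describe an invalid sparse
        # COO tensor):
        #
        #   with torch.sparse.check_sparse_tensor_invariants():
        #       torch.sparse_coo_tensor(indices, values, shape)  # raises RuntimeError
        #       torch.sparse_coo_tensor(indices, values, shape,
        #                               check_invariants=False)  # succeeds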

        # Test nesting of pre-defined context managers
        check_ctx = torch.sparse.check_sparse_tensor_invariants(True)
        no_check_ctx = torch.sparse.check_sparse_tensor_invariants(False)
        with check_ctx:
            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
            with no_check_ctx:
                self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
        self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())

        # Test an attempt to reuse an already-active context manager instance
        check_ctx2 = torch.sparse.check_sparse_tensor_invariants(True)
        with check_ctx:
            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
            with no_check_ctx:
                self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
                with self.assertRaisesRegex(RuntimeError, "This context manager instance is already activated."
                                            " Use a different context manager instance for context nesting"):
                    with check_ctx:
                        self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
                self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
                with check_ctx2:
                    self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
                self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
            self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
        self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())

    def test_generate_simple_inputs(self):
        layouts = [torch.strided, torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc]

        tested_combinations = set()
        for tensors in zip(*map(self.generate_simple_inputs, layouts)):
            for i, t in enumerate(tensors):
                self.assertEqual(t.layout, layouts[i])

                # all layouts must produce semantically the same tensors
                self.assertEqual(t, tensors[0])

                if t.layout is torch.strided:
                    is_hybrid = None
                else:
                    is_hybrid = t.dense_dim() > 0
                if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
                    is_batch = t.crow_indices().ndim > 1
                elif t.layout in {torch.sparse_csc, torch.sparse_bsc}:
                    is_batch = t.ccol_indices().ndim > 1
                else:
                    is_batch = None
                if t.layout in {torch.sparse_bsr, torch.sparse_bsc}:
                    blocksize = t.values().shape[1:3]
                    nontrivial_blocksize = 1 not in blocksize
                else:
                    nontrivial_blocksize = None
                if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
                    contiguous_indices = t.crow_indices().is_contiguous() and t.col_indices().is_contiguous()
                    contiguous_values = t.values().is_contiguous()
                elif t.layout in {torch.sparse_csc, torch.sparse_bsc}:
                    contiguous_indices = t.ccol_indices().is_contiguous() and t.row_indices().is_contiguous()
                    contiguous_values = t.values().is_contiguous()
                elif t.layout is torch.sparse_coo:
                    contiguous_indices = t._indices().is_contiguous()
                    contiguous_values = t._values().is_contiguous()
                else:
                    contiguous_indices = None
                    contiguous_values = t.is_contiguous()

                tested_combinations.add((t.layout, is_hybrid, is_batch, nontrivial_blocksize,
                                         contiguous_indices, contiguous_values))

        # Ensure that the inputs generation covers all layout,
        # non-hybrid/hybrid, non-batch/batch, and contiguity
        # combinations:
        untested_combinations = set()
        for layout in layouts:
            for is_hybrid in [False, True]:
                if layout is torch.strided:
                    is_hybrid = None
                for is_batch in [False, True]:
                    if layout in {torch.sparse_coo, torch.strided}:
                        is_batch = None
                    for nontrivial_blocksize in [False, True]:
                        if layout not in {torch.sparse_bsr, torch.sparse_bsc}:
                            nontrivial_blocksize = None
                        for contiguous_indices in [False, True]:
                            if layout is torch.strided:
                                contiguous_indices = None
                            elif not is_batch:
                                # a non-batch tensor's indices are always contiguous
                                contiguous_indices = True
                            for contiguous_values in [False, True]:
                                key = (layout, is_hybrid, is_batch, nontrivial_blocksize,
                                       contiguous_indices, contiguous_values)
                                if key not in tested_combinations:
                                    untested_combinations.add(
                                        f'layout={layout}, is_hybrid={is_hybrid}, is_batch={is_batch},'
                                        f' nontrivial_blocksize={nontrivial_blocksize},'
                                        f' contiguous_indices={contiguous_indices}, contiguous_values={contiguous_values}')
        assert not untested_combinations, untested_combinations

    @all_sparse_layouts('layout', include_strided=False)
    def test_constructor_autograd(self, device, layout):

        def specific_constructor(*args, **kwargs):
            if layout is torch.sparse_csr:
                return torch.sparse_csr_tensor(*args, **kwargs)
            elif layout is torch.sparse_csc:
                return torch.sparse_csc_tensor(*args, **kwargs)
            elif layout is torch.sparse_bsc:
                return torch.sparse_bsc_tensor(*args, **kwargs)
            elif layout is torch.sparse_bsr:
                return torch.sparse_bsr_tensor(*args, **kwargs)
            elif layout is torch.sparse_coo:
                return torch.sparse_coo_tensor(*args, **kwargs)
            else:
                raise NotImplementedError(layout)

        def generic_constructor(*args, **kwargs):
            if layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
                kwargs.update(layout=layout)
                return torch.sparse_compressed_tensor(*args, **kwargs)
            elif layout is torch.sparse_coo:
                return torch.sparse_coo_tensor(*args, **kwargs)
            else:
                raise NotImplementedError(layout)
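
        # For example (a sketch, assuming crow, col and val are valid CSR
        # components), both constructors build the same tensor for
        # layout=torch.sparse_csr:
        #
        #   torch.sparse_csr_tensor(crow, col, val)
        #   torch.sparse_compressed_tensor(crow, col, val, layout=torch.sparse_csr)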

        if layout is torch.sparse_coo:
            constructors = (specific_constructor,)
        else:
            constructors = (specific_constructor, generic_constructor)

        for args, kwargs in self.generate_simple_inputs(
                layout, device=device, dtype=torch.float64,
                enable_batch=False,  # TODO: remove after gh-104868 is resolved
                output_tensor=False):
            values_offset = 1 if layout is torch.sparse_coo else 2
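            # COO constructor args are (indices, values[, size]) while the
            # compressed constructors take (compressed_indices, plain_indices,
            # values[, size]), hence the values offsets of 1 and 2.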

            for cnstr in constructors:
                for requires_grad in (False, True):
                    values = args[values_offset].detach().requires_grad_(requires_grad)
                    args = (*args[:values_offset], values, *args[values_offset + 1:])
                    kwargs_ = dict(kwargs)
                    args_ = args + (kwargs_.pop('size'),)

                    sparse = cnstr(*args, **kwargs)

                    self.assertEqual(sparse.requires_grad, requires_grad)

                    if requires_grad:
                        for masked in (False, True):
                            if layout is torch.sparse_coo:
                                torch.autograd.gradcheck(
                                    lambda i, v: cnstr(i, v, **kwargs).to_dense(masked_grad=masked),
                                    args, masked=masked)
                                torch.autograd.gradcheck(
                                    lambda i, v, sz: cnstr(i, v, sz, **kwargs_).to_dense(masked_grad=masked),
                                    args_, masked=masked)
                            else:
                                if layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} and False:
                                    # TODO: remove this if-block after gh-107370 is resolved
                                    continue
                                torch.autograd.gradcheck(
                                    lambda ci, pi, v: cnstr(ci, pi, v, **kwargs).to_dense(masked_grad=masked),
                                    args, masked=masked)
                                torch.autograd.gradcheck(
                                    lambda ci, pi, v, sz: cnstr(ci, pi, v, sz, **kwargs_).to_dense(masked_grad=masked),
                                    args_, masked=masked)

    @all_sparse_layouts('from_layout', include_strided=False)
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    @parametrize("index_dtype", [torch.int32, torch.int64])
    def test_to_dense(self, from_layout, device, dtype, index_dtype):
        """
        Test conversion from any sparse layout to the strided layout.
        """
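        # For example (a sketch): a 2x2 COO tensor with a single specified
        # element at (0, 1) densifies as
        #   torch.sparse_coo_tensor([[0], [1]], [1.], (2, 2)).to_dense()
        #   # -> tensor([[0., 1.], [0., 0.]])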
        for t in self.generate_simple_inputs(
                from_layout, device=device, dtype=dtype, index_dtype=index_dtype):
            r = t.to_dense()
            self.assertEqual(r.layout, torch.strided)
            self.assertEqual(r, t)

    @all_sparse_layouts('from_layout', include_strided=False)
    @dtypes(torch.float64, torch.complex128)
    @parametrize("index_dtype", [torch.int64])
    @gradcheck_semantics()
    def test_gradcheck_to_dense(self, from_layout, device, dtype, index_dtype, gradcheck):
        for t in self.generate_simple_inputs(
                from_layout, device=device, dtype=dtype, index_dtype=index_dtype):
            batch_dim = t.dim() - t.dense_dim() - t.sparse_dim()
            if batch_dim > 0:
                # TODO: implement batch support in _convert_indices_from_csr_to_coo
                continue
            t = t.clone().detach().requires_grad_(True)
            r = gradcheck(lambda x: torch.Tensor.to_dense(x, masked_grad=gradcheck.masked), t)
            self.assertTrue(r)

    @all_sparse_layouts('from_layout', include_strided=True)
    @all_sparse_layouts('to_layout', include_strided=False)
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    @parametrize("index_dtype", [torch.int32, torch.int64])
    def test_to_sparse(self, from_layout, to_layout, device, dtype, index_dtype):
        """
        Test conversion from any layout to any sparse layout.
        """
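        # For example (a sketch): a strided tensor t converts to CSR via
        #   t.to_sparse(layout=torch.sparse_csr)
        # and to BSR with an explicit blocksize via
        #   t.to_sparse(layout=torch.sparse_bsr, blocksize=(1, 1))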
        for t in self.generate_simple_inputs(
                from_layout, device=device, dtype=dtype, index_dtype=index_dtype,
                enable_hybrid=(
                    # TODO: to support the strided->hybrid CSR/CSC/BSR/BSC
                    # conversion, to_sparse() requires an extra keyword
                    # argument, either nof_batch_dims or
                    # nof_dense_dims
                    not (from_layout is torch.strided and to_layout in
                         {torch.sparse_bsr, torch.sparse_bsc, torch.sparse_csr, torch.sparse_csc}))):

            if to_layout in {torch.sparse_bsr, torch.sparse_bsc}:
                if from_layout == torch.sparse_bsr:
                    batch_ndim = t.crow_indices().dim() - 1
                    blocksize = t.values().shape[batch_ndim + 1:batch_ndim + 3]
                elif from_layout == torch.sparse_bsc:
                    batch_ndim = t.ccol_indices().dim() - 1
                    blocksize = t.values().shape[batch_ndim + 1:batch_ndim + 3]
                else:
                    blocksize = (1, 1)
            else:
                blocksize = None

            if from_layout is torch.strided:
                is_batch = None
                is_hybrid = None
            else:
                is_batch = t.dim() > (t.sparse_dim() + t.dense_dim())
                is_hybrid = t.dense_dim() > 0

            def explicit_to_sparse(x):
                # Used to check that the explicit conversion methods
                # are consistent with the `to_sparse(*, layout,
                # blocksize)` method.
                if to_layout is torch.sparse_coo:
                    return x.to_sparse_coo()
                elif to_layout is torch.sparse_csr:
                    return x.to_sparse_csr()
                elif to_layout is torch.sparse_csc:
                    return x.to_sparse_csc()
                elif to_layout is torch.sparse_bsr:
                    return x.to_sparse_bsr(blocksize)
                elif to_layout is torch.sparse_bsc:
                    return x.to_sparse_bsc(blocksize)
                else:
                    assert False  # unreachable

            # TODO: The following exception cases all correspond to
            # conversions that are not implemented yet
            if from_layout in {
                    torch.sparse_csr, torch.sparse_csc} and to_layout in {torch.sparse_bsr, torch.sparse_bsc} and is_batch:
                with self.assertRaisesRegex(
                        RuntimeError,
                        r"conversion from Sparse(Csr|Csc) to Sparse(Bsr|Bsc) for batched inputs is not supported"):
                    t.to_sparse(layout=to_layout, blocksize=blocksize)
                with self.assertRaisesRegex(
                        RuntimeError,
                        r"conversion from Sparse(Csr|Csc) to Sparse(Bsr|Bsc) for batched inputs is not supported"):
                    explicit_to_sparse(t)
                continue
            elif from_layout is torch.sparse_coo and to_layout in {
                    torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} and t.sparse_dim() != 2:
                with self.assertRaisesRegex(
                        RuntimeError,
                        r"conversion from Sparse to .* for input tensors with sparse_dim\(\)!=2 is not supported"):
                    t.to_sparse(layout=to_layout, blocksize=blocksize)
                with self.assertRaisesRegex(
                        RuntimeError,
                        r"conversion from Sparse to .* for input tensors with sparse_dim\(\)!=2 is not supported"):
                    explicit_to_sparse(t)
                continue
            elif (from_layout, to_layout) in {(torch.sparse_bsc, torch.sparse_csr), (torch.sparse_bsc, torch.sparse_csc),
                                              (torch.sparse_bsr, torch.sparse_csr), (torch.sparse_bsr, torch.sparse_csc)}:
                with self.assertRaisesRegex(
                        RuntimeError,
                        r"sparse_compressed_to_sparse_(csr|csc|bsr|bsc): expected\s*(Sparse(Csc|Csr)[,]|)\s*Sparse(Csr|Bsr)"
                        " or Sparse(Csc|Bsc) layout but got Sparse(Csr|Csc|Bsr|Bsc)"):
                    t.to_sparse(layout=to_layout, blocksize=blocksize)
                with self.assertRaisesRegex(
                        RuntimeError,
                        r"sparse_compressed_to_sparse_(csr|csc|bsr|bsc): expected\s*(Sparse(Csc|Csr)[,]|)\s*Sparse(Csr|Bsr)"
                        " or Sparse(Csc|Bsc) layout but got Sparse(Csr|Csc|Bsr|Bsc)"):
                    explicit_to_sparse(t)
                self.skipTest('NOT IMPL')
            else:
                r = t.to_sparse(layout=to_layout, blocksize=blocksize)

                self.assertEqual(r.layout, to_layout)

                # The to_sparse method uses unsafe construction of sparse
                # tensors. Here we explicitly validate the results to
                # make sure that the sparse tensors are consistent
                # with the corresponding sparse tensor invariants.
                if r.layout in {torch.sparse_csr, torch.sparse_bsr, torch.sparse_csc, torch.sparse_bsc}:
                    if r.layout in {torch.sparse_csr, torch.sparse_bsr}:
                        compressed_indices, plain_indices = r.crow_indices(), r.col_indices()
                    else:
                        compressed_indices, plain_indices = r.ccol_indices(), r.row_indices()
                    torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, r.values(),
                                                                  r.shape, r.layout)
                    if from_layout in {torch.strided, torch.sparse_coo}:
                        self.assertEqual(compressed_indices.dtype, torch.int64)
                        self.assertEqual(plain_indices.dtype, torch.int64)
                    else:
                        self.assertEqual(compressed_indices.dtype, index_dtype)
                        self.assertEqual(plain_indices.dtype, index_dtype)
                    self.assertEqual(r.values().dtype, dtype)
                elif r.layout is torch.sparse_coo:
                    if t.layout is torch.sparse_coo:
                        self.assertEqual(t.is_coalesced(), r.is_coalesced())

                    # Check that r is truly coalesced when r.is_coalesced() returns True
                    if r.is_coalesced():
                        self.assertTrue(is_coalesced_indices(r))

                    torch._validate_sparse_coo_tensor_args(r._indices(), r._values(), r.shape)
                    self.assertEqual(r._indices().dtype, torch.int64)
                    self.assertEqual(r._values().dtype, dtype)
                else:
                    assert False  # unreachable

                # Finally, we'll test tensor equality:
                self.assertEqual(r, t)

                # Also, check consistency with explicit conversion methods:
                r2 = explicit_to_sparse(t)
                self.assertEqual(r2, r)

                # Check inverse conversion from sparse compressed block tensors
                if from_layout == torch.sparse_bsr:
                    batch_ndim = t.crow_indices().dim() - 1
                    from_blocksize = t.values().shape[batch_ndim + 1:batch_ndim + 3]
                elif from_layout == torch.sparse_bsc:
                    batch_ndim = t.ccol_indices().dim() - 1
                    from_blocksize = t.values().shape[batch_ndim + 1:batch_ndim + 3]
                else:
                    continue
                if r.ndim != 2:
                    continue

                t2 = r.to_sparse(layout=from_layout, blocksize=from_blocksize)
                self.assertEqual(t2, t)

        # extra tests
        if (from_layout, to_layout) == (torch.sparse_csr, torch.sparse_bsr):
            # See gh-90910
            t = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0]], dtype=dtype, device=device).to_sparse_csr()
            r = t.to_sparse_bsr((2, 2))
            torch._validate_sparse_compressed_tensor_args(r.crow_indices(), r.col_indices(), r.values(), r.shape, r.layout)
            self.assertEqual(r, t)

        if (from_layout, to_layout) in {(torch.sparse_csr, torch.sparse_csc),
                                        (torch.sparse_csc, torch.sparse_csr)}:
            # See gh-91007
            compressed_indices = torch.tensor([0, 4, 8, 8, 12, 16, 20], dtype=index_dtype, device=device)
            plain_indices = torch.tensor([0, 1, 2, 3] * 5, dtype=index_dtype, device=device)
            t = torch.sparse_compressed_tensor(compressed_indices, plain_indices, range(20),
                                               dtype=dtype, device=device, layout=from_layout)
            r = t.to_sparse(layout=to_layout)
            if r.layout in {torch.sparse_csr, torch.sparse_bsr}:
                compressed_indices, plain_indices = r.crow_indices(), r.col_indices()
            else:
                compressed_indices, plain_indices = r.ccol_indices(), r.row_indices()
            torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, r.values(), r.shape, r.layout)
            self.assertEqual(r, t)

    @onlyNativeDeviceTypes
    @suppress_warnings
    @ops(reduction_ops_with_sparse_support)
    @precisionOverride({torch.bfloat16: 5e-4, torch.float16: 5e-3})
    @all_sparse_layouts('layout', include_strided=False)
    def test_reductions(self, layout, device, dtype, op):
        count = 0
        for sample in op.sample_inputs_sparse(layout, device, dtype):
            count += 1

            t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
            result = op.op(t_inp, *t_args, **t_kwargs)

            # Check the invariant rop(inp, ...).to_dense() == rop(inp.to_dense(), ...)
            dense = op.op(t_inp.to_dense(), *t_args, **t_kwargs)
            self.assertEqual(result, dense)

        if count == 0:
            # we count samples to avoid false-positive test reports
            self.skipTest('no sample inputs')
5002*da0073e9SAndroid Build Coastguard Worker
5003*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
5004*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
5005*da0073e9SAndroid Build Coastguard Worker    @ops(reduction_ops_with_sparse_support, allowed_dtypes=(torch.float32, torch.float64, torch.complex64, torch.complex128))
5006*da0073e9SAndroid Build Coastguard Worker    @all_sparse_layouts('layout', include_strided=False)
5007*da0073e9SAndroid Build Coastguard Worker    def test_reductions_backward(self, layout, device, dtype, op):
5008*da0073e9SAndroid Build Coastguard Worker        count = 0
5009*da0073e9SAndroid Build Coastguard Worker        for sample in op.sample_inputs_sparse(layout, device, dtype, requires_grad=True):
5010*da0073e9SAndroid Build Coastguard Worker            t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
5011*da0073e9SAndroid Build Coastguard Worker            r = op.op(t_inp, *t_args, **t_kwargs)
5012*da0073e9SAndroid Build Coastguard Worker            if r.numel() != 0:
5013*da0073e9SAndroid Build Coastguard Worker                r = r.sum()
5014*da0073e9SAndroid Build Coastguard Worker
5015*da0073e9SAndroid Build Coastguard Worker            if op.name == 'sum':
5016*da0073e9SAndroid Build Coastguard Worker                count += 1
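                # d(sum(x))/dx_i == 1 for every element, and the chain rule
                # through abs() multiplies by sgn(r), so the expected
                # gradient is an all-ones tensor scaled by torch.sgn(r).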
                r.abs().backward()
                self.assertEqual(t_inp.grad, torch.ones(t_inp.shape, dtype=dtype, device=device) * torch.sgn(r))
            else:
                self.skipTest('gradient check is not implemented for this reduction op')

        if count == 0:
            # we count samples to avoid false-positive test reports
            self.skipTest('no sample inputs')

    @onlyNativeDeviceTypes
    @suppress_warnings
    @parametrize("mth", [subtest(mth, name=mth.__name__)
                         for mth in [torch.Tensor.is_coalesced,
                                     torch.Tensor.coalesce,
                                     torch.Tensor.indices,
                                     torch.Tensor.values,
                                     torch.Tensor.crow_indices,
                                     torch.Tensor.col_indices,
                                     torch.Tensor.ccol_indices,
                                     torch.Tensor.row_indices,
                                     ]])
    @all_sparse_layouts('layout', include_strided=True)
    def test_unsupported_backend_error_message(self, mth, layout, device):
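        # Blocked layouts require an explicit blocksize for conversion;
        # (1, 1) is the smallest choice that evenly divides the 2x2 example.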
        inp = torch.tensor([[1, 2], [3, 4]], device=device).to_sparse(
            layout=layout,
            blocksize=(1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None)
        assert inp.layout is layout

        expected_behaviour = dict(
            # <mth name> = (<supported layouts>, <exception message on other layouts>)
            is_coalesced=({torch.sparse_coo},
                          "is_coalesced expected sparse coordinate tensor layout but got (Sparse(Csr|Csc|Bsr|Bsc)|Strided)"),
            coalesce=({torch.sparse_coo},
                      "coalesce expected sparse coordinate tensor layout but got (Sparse(Csr|Csc|Bsr|Bsc)|Strided)"),
            indices=({torch.sparse_coo},
                     "indices expected sparse coordinate tensor layout but got (Sparse(Csr|Csc|Bsr|Bsc)|Strided)"),
            values=({torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc},
                    "values expected sparse tensor layout but got Strided"),
            crow_indices=({torch.sparse_csr, torch.sparse_bsr},
                          "crow_indices expected sparse row compressed tensor layout but got (Sparse(Csc|Bsc|)|Strided)"),
            col_indices=({torch.sparse_csr, torch.sparse_bsr},
                         "col_indices expected sparse row compressed tensor layout but got (Sparse(Csc|Bsc|)|Strided)"),
            ccol_indices=({torch.sparse_csc, torch.sparse_bsc},
                          "ccol_indices expected sparse column compressed tensor layout but got (Sparse(Csr|Bsr|)|Strided)"),
            row_indices=({torch.sparse_csc, torch.sparse_bsc},
                         "row_indices expected sparse column compressed tensor layout but got (Sparse(Csr|Bsr|)|Strided)"),
        )[mth.__name__]

        if layout in expected_behaviour[0]:
            mth(inp)
        else:
            with self.assertRaisesRegex(RuntimeError, expected_behaviour[1]):
                mth(inp)

    @onlyNativeDeviceTypes
    @all_sparse_layouts('layout', include_strided=False)
    @dtypes(torch.float64, torch.cdouble)
    @parametrize("masked", [subtest(False, name='sparse'), subtest(True, name='masked')])
    @parametrize("fast_mode", [subtest(False, name='slow'), subtest(True, name='fast')])
    def test_gradcheck_mm(self, layout, dtype, device, masked, fast_mode):
        # This function does not check the following cases:
        # - batch or hybrid tensors because addmm does not support
        #   such inputs yet
        # - check_forward_ad=True because of the lack of sparse tensor
        #   support in aten::view_as_real, torch._VF._make_dual, etc.

        ref_x = torch.tensor([[1, 2, 0, 0],
                              [0, 6, 0, 0],
                              [0, 0, 0, 0],
                              [13, 14, 0, 15]], dtype=dtype, device=device)
        ref_y = torch.tensor([[11, 12, 13, 14],
                              [21, 22, 23, 24],
                              [31, 32, 33, 34],
                              [41, 42, 43, 44]],
                             dtype=dtype, device=device)

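        # torch.sparse.mm follows masked semantics: its backward produces a
        # sparse gradient restricted to x's sparsity pattern, while torch.mm
        # treats unspecified elements as ordinary zeros. gradcheck's masked
        # flag must match the semantics of the function under test.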
        mm = torch.sparse.mm if masked else torch.mm

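        # BSR/BSC conversion needs a blocksize that evenly divides the
        # sparse dimensions; (2, 2) divides the 4x4 ref_x.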
        blocksize = (2, 2) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None
        x = ref_x.to_sparse(layout=layout, blocksize=blocksize).requires_grad_(True)
        y = ref_y.requires_grad_(True)

        if (layout is torch.sparse_bsr and not masked) or layout is torch.sparse_bsc:
            with self.assertRaisesRegex(
                    RuntimeError,
                    r"addmm: computation on (CPU|CUDA) is not implemented for Strided \+ Sparse(Bsr|Bsc) @ Strided"):
                torch.autograd.gradcheck(mm, (x, y), fast_mode=fast_mode, masked=masked)
            self.skipTest('not implemented')
        elif layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} and masked:
            with self.assertRaisesRegex(
                    RuntimeError,
                    r"(sparse_addmm_sparse_backward: unsupported combination of layouts,"
                    r" grad: Strided, mat1: Sparse(Csc|Bsr|Bsc), mat2: Strided"
                    r"|addmm: computation on (CPU|CUDA) is not implemented for "
                    r"Strided \+ Sparse(Csc|Bsr|Bsc) @ Strided without MKL)"):
                torch.autograd.gradcheck(mm, (x, y), fast_mode=fast_mode, masked=masked)
            self.skipTest('not implemented')
        else:
            torch.autograd.gradcheck(mm, (x, y), fast_mode=fast_mode, masked=masked)

    @onlyNativeDeviceTypes
    @suppress_warnings
    @ops(binary_ufuncs_with_sparse_support)
    @all_sparse_layouts('layout', include_strided=False)
    def test_binary_operation(self, layout, device, dtype, op):
        if not op.supports_sparse_layout(layout):
            self.skipTest(f'{layout} is not supported in `{op.name}` OpInfo definition. Skipping!')

        for sample in op.sample_inputs_sparse(layout, device, dtype):
            if validate_sample_input_sparse(op, sample, check_validate=False) is not sample:
                # validation returned the sparse sample wrapped in an
                # ErrorInput instance, i.e. the sample is not supported
                continue
            t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
            batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
            result = op.op(t_inp, *t_args, **t_kwargs)

            # Check rop(inp, ...).shape == inp.shape
            self.assertEqual(result.shape, t_inp.shape)

            # Check rop(inp, ...).sparse_dim() == inp.sparse_dim()
            self.assertEqual(result.sparse_dim(), t_inp.sparse_dim())

            # Check rop(inp, ...).dense_dim() == inp.dense_dim()
            self.assertEqual(result.dense_dim(), t_inp.dense_dim())

            # Check invariant rop(inp, ...).to_dense() == rop(inp.to_dense(), ...)
            try:
                dense = op.op(t_inp.to_dense(), *(t_args[0].to_dense(), *t_args[1:]), **t_kwargs)
            except Exception as e:
                # This is an issue with the strided op, so skip the sample silently.
                if "\"cpublas_axpy_impl\" not implemented for 'ComplexHalf'" in str(e):
                    continue
                raise
            self.assertEqual(result, dense)

    @onlyCPU
    @all_sparse_layouts('layout', include_strided=True)
    @dtypes(torch.double)
    def test_to_sparse_identity(self, device, layout, dtype):
        for dense_dim in range(4):
            x_dense = torch.eye(dense_dim, dtype=dtype, device=device)
            for sparse_dim_in in range(1, dense_dim):
                x_sparse = x_dense.to_sparse(sparse_dim_in)
                for sparse_dim_out in range(0, dense_dim):
                    if sparse_dim_out == sparse_dim_in:
                        self.assertEqual(x_sparse.to_sparse(sparse_dim_out).sparse_dim(), sparse_dim_out)
                    else:
                        with self.assertRaisesRegex(
                                RuntimeError,
                                r"to_sparse: conversion from Sparse to Sparse with sparse_dim argument !=self.sparse_dim\(\)"
                                " is not supported"):
                            x_sparse.to_sparse(sparse_dim_out)


    @onlyNativeDeviceTypes
    @suppress_warnings
    @ops(like_fns_with_sparse_support)
    @all_sparse_layouts('layout', include_strided=False)
    def test_like_fns(self, layout, device, dtype, op):

        for sample in op.sample_inputs_sparse(layout, device, dtype):
            t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
            batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
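            # For blocked layouts, values() has shape
            # (*batch_shape, nnz, *blocksize, *dense_shape), so the block
            # size occupies the two dims right after the batch and nnz
            # prefix.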
            if t_inp.layout in {torch.sparse_bsr, torch.sparse_bsc}:
                expected_blocksize = t_inp.values().shape[batch_dim + 1:batch_dim + 3]
            else:
                expected_blocksize = None
            expected_dtype = t_kwargs.get('dtype', dtype)
            expected_device = torch.device(t_kwargs.get('device', device))
            expected_layout = t_kwargs.get('layout', layout)

            result = op.op(t_inp, *t_args, **t_kwargs)

            self.assertEqual(result.dtype, expected_dtype)
            self.assertEqual(result.device.type, expected_device.type)
            self.assertEqual(result.layout, expected_layout)

            if result.layout in {torch.sparse_bsr, torch.sparse_bsc}:
                result_batch_dim = result.dim() - result.dense_dim() - result.sparse_dim()
                blocksize = result.values().shape[result_batch_dim + 1:result_batch_dim + 3]
                self.assertEqual(blocksize, expected_blocksize)

            # Check op(inp).shape == inp.shape
            self.assertEqual(result.shape, t_inp.shape)

            if expected_layout is torch.strided:
                self.assertEqual(result.sparse_dim(), 0)
                # Check op(inp, layout=torch.strided).dense_dim() == inp.dim()
                self.assertEqual(result.dense_dim(), t_inp.dim())
            elif expected_layout is torch.sparse_coo:
                # Check op(inp, layout=torch.sparse_coo).sparse_dim() == batch_dim + inp.sparse_dim()
                self.assertEqual(result.sparse_dim(), batch_dim + t_inp.sparse_dim())
                # Check op(inp, layout=torch.sparse_coo).dense_dim() == inp.dense_dim()
                self.assertEqual(result.dense_dim(), t_inp.dense_dim())

                torch._validate_sparse_coo_tensor_args(result._indices(), result._values(), result.shape)
            else:
                # Check op(inp).sparse_dim() == inp.sparse_dim()
                self.assertEqual(result.sparse_dim(), t_inp.sparse_dim())
                # Check op(inp).dense_dim() == inp.dense_dim()
                self.assertEqual(result.dense_dim(), t_inp.dense_dim())

                if result.layout in {torch.sparse_csr, torch.sparse_bsr}:
                    compressed_indices, plain_indices = result.crow_indices(), result.col_indices()
                else:
                    compressed_indices, plain_indices = result.ccol_indices(), result.row_indices()

                torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, result.values(),
                                                              result.shape, result.layout)

    @all_sparse_layouts('mask_layout', include_strided=False)
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_sparse_mask(self, mask_layout, device, dtype):
        input_layout = torch.strided
        mask_dtype = torch.bool
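        # sparse_mask selects the elements of x at the mask's specified
        # indices; only the mask's sparsity pattern matters, its boolean
        # values are ignored.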
        for mask in self.generate_simple_inputs(mask_layout, dtype=mask_dtype, device=device,
                                                enable_hybrid=False, enable_batch=False):

            x = make_tensor(mask.shape, dtype=dtype, device=device).to_sparse(layout=input_layout)

            result = x.sparse_mask(mask)

            # Check invariant `x.sparse_mask(mask).<indices> == mask.<indices>`
            if mask_layout is torch.sparse_coo:
                self.assertEqual(result._indices(), mask._indices())
                ones = torch.sparse_coo_tensor(mask._indices(),
                                               torch.ones_like(mask._values(), dtype=x.dtype),
                                               mask.shape,
                                               is_coalesced=mask.is_coalesced())
            elif mask_layout in {torch.sparse_csr, torch.sparse_bsr}:
                self.assertEqual(result.crow_indices(), mask.crow_indices())
                self.assertEqual(result.col_indices(), mask.col_indices())
                ones = torch.sparse_compressed_tensor(mask.crow_indices(), mask.col_indices(),
                                                      torch.ones_like(mask.values(), dtype=x.dtype),
                                                      mask.shape, layout=mask.layout)
            else:
                self.assertEqual(result.ccol_indices(), mask.ccol_indices())
                self.assertEqual(result.row_indices(), mask.row_indices())
                ones = torch.sparse_compressed_tensor(mask.ccol_indices(), mask.row_indices(),
                                                      torch.ones_like(mask.values(), dtype=x.dtype),
                                                      mask.shape, layout=mask.layout)

            # Check invariant:
            #  x.sparse_mask(mask).to_dense() == x.mul(sparse_xyz_tensor(<mask indices>,
            #                                          ones_like(<mask values>)).to_dense())
            expected = x.mul(ones.to_dense())

            self.assertEqual(result.to_dense(), expected)

            # Check invariant `mask.to_dense().sparse_mask(mask) == mask`
            result = mask.to_dense().sparse_mask(mask)
            self.assertEqual(result, mask)

    @all_sparse_layouts('layout', include_strided=False)
    @parametrize("masked", [subtest(False, name='nonmasked'), subtest(True, name='masked')])
    @parametrize("fast_mode", [subtest(False, name='slow'), subtest(True, name='fast')])
    def test_as_sparse_gradcheck(self, layout, device, masked, fast_mode):
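        # torch.sparse.as_sparse_gradcheck extends gradcheck to sparse
        # inputs, roughly by routing the sparse tensors' values through a
        # strided representation so that the numerical Jacobian can be
        # computed.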
        gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)
        sparse_compressed_layouts = {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}

        def identity(x):
            return x

        for func in (torch.Tensor.to_dense,
                     torch.Tensor.sum,
                     identity,
                     torch.Tensor.to_sparse,
                     torch.Tensor.values,
                     ):
            for x in self.generate_simple_inputs(
                    layout,
                    device=device,
                    dtype=torch.float64,
                    # TODO: fix gh-104868 to enable batched samples:
                    enable_batch=layout not in sparse_compressed_layouts,
                    enable_hybrid=not (
                        layout in sparse_compressed_layouts and (
                            # FIXME: RuntimeError: sparse_mask(): the
                            # number of sparse dimensions in `self`
                            # should match that of the `mask`. Got
                            # `self.sparse_dim() == 3` !=
                            # `mask.sparse_dim() == 2`
                            func.__name__ == 'sum'
                            # FIXME: RuntimeError: expected
                            # col_indices to be a contiguous tensor
                            # per batch
                            or func.__name__ == 'to_sparse'
                        ))):
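                # values() raises on an uncoalesced COO tensor, so coalesce
                # explicitly before gradchecking it.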
                if layout is torch.sparse_coo and func.__name__ == 'values':
                    x = x.coalesce()

                gradcheck(func, x.requires_grad_(True), masked=masked, fast_mode=fast_mode)

    @onlyCPU
    @all_sparse_layouts('layout', include_strided=False)
    @dtypes(torch.double)
    def test_dataloader(self, device, layout, dtype):

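        # batch_size=None disables automatic batching and num_workers=2
        # forces the samples through worker-process serialization, so this
        # checks that sparse tensors of any layout survive the DataLoader
        # round-trip unchanged.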
        data = list(self.generate_simple_inputs(layout, device=device, dtype=dtype))

        dataset = _SparseDataset(data)
        loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=2)

        loaded_data = list(loader)
        self.assertEqual(data, loaded_data)

    @onlyCPU
    def test_invalid_blocksize(self):
        # Blocksize should be a tuple/list/torch.Size containing two values
        with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 1"):
            torch.randn(1).to_sparse(blocksize=(1,))
        with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 1"):
            torch.randn(1).to_sparse(blocksize=[1])
        with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 1"):
            torch.randn(1).to_sparse(blocksize=torch.Size((1,)))
        with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 3"):
            torch.randn(1).to_sparse(blocksize=(1, 1, 1))
        with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 3"):
            torch.randn(1).to_sparse(blocksize=[1, 1, 1])
        with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 3"):
            torch.randn(1).to_sparse(blocksize=torch.Size((1, 1, 1)))

    @unittest.skipIf(not torch.cuda.is_available(), 'requires cuda')
    @onlyCPU
    @all_sparse_layouts('layout', include_strided=True)
    def test_constructor_pin_memory(self, device, layout):
        """Tests sparse_xyz_tensor(indices, values, pin_memory=True)
        """
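        # Pinning host memory requires an available CUDA runtime (hence
        # the skipIf guard above), even though the tensors stay on CPU.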
        self.assertEqual(device, "cpu")
        for t in self.generate_simple_inputs(
                layout, device=device, dtype=torch.float64,
                enable_zero_sized=False,  # pinning zero-sized tensors is a no-op
                pin_memory=True,
                enable_batch=False,  # TODO: remove after gh-104868 is resolved
        ):
            if layout is torch.sparse_coo:
                self.assertTrue(t._indices().is_pinned())
                self.assertTrue(t._values().is_pinned())
            elif layout in {torch.sparse_csr, torch.sparse_bsr}:
                self.assertTrue(t.crow_indices().is_pinned())
                self.assertTrue(t.col_indices().is_pinned())
                self.assertTrue(t.values().is_pinned())
            elif layout in {torch.sparse_csc, torch.sparse_bsc}:
                self.assertTrue(t.ccol_indices().is_pinned())
                self.assertTrue(t.row_indices().is_pinned())
                self.assertTrue(t.values().is_pinned())
            elif layout is torch.strided:
                pass
            else:
                assert 0  # unreachable
            self.assertTrue(t.is_pinned())

    @unittest.skipIf(not torch.cuda.is_available(), 'requires cuda')
    @onlyCPU
    @all_sparse_layouts('layout', include_strided=True)
    def test_method_pin_memory(self, device, layout):
        """Tests sparse_xyz_tensor(indices, values, pin_memory=False).pin_memory()
        """

        for t_ in self.generate_simple_inputs(
                layout, device=device, dtype=torch.float64,
                enable_zero_sized=False,  # pinning zero-sized tensors is a no-op
                pin_memory=False,         # no pinning
                enable_batch=False,  # TODO: remove after gh-104868 is resolved
        ):
            t = t_.pin_memory()
            self.assertTrue(t.is_pinned())

            # pin_memory() on a non-pinned tensor returns a pinned clone;
            # the source tensor stays non-pinned
            self.assertFalse(t_.is_pinned())

            # pin_memory() on an already-pinned tensor is an identity
            # operation:
            t2 = t.pin_memory()
            self.assertTrue(t2 is t)

            if layout is torch.sparse_coo:
                self.assertTrue(t._indices().is_pinned())
                self.assertTrue(t._values().is_pinned())
                self.assertFalse(t_._indices().is_pinned())
                self.assertFalse(t_._values().is_pinned())
            elif layout in {torch.sparse_csr, torch.sparse_bsr}:
                self.assertTrue(t.crow_indices().is_pinned())
                self.assertTrue(t.col_indices().is_pinned())
                self.assertTrue(t.values().is_pinned())
                self.assertFalse(t_.crow_indices().is_pinned())
                self.assertFalse(t_.col_indices().is_pinned())
                self.assertFalse(t_.values().is_pinned())
            elif layout in {torch.sparse_csc, torch.sparse_bsc}:
                self.assertTrue(t.ccol_indices().is_pinned())
                self.assertTrue(t.row_indices().is_pinned())
                self.assertTrue(t.values().is_pinned())
                self.assertFalse(t_.ccol_indices().is_pinned())
                self.assertFalse(t_.row_indices().is_pinned())
                self.assertFalse(t_.values().is_pinned())
            elif layout is torch.strided:
                pass
            else:
                assert 0  # unreachable


    @unittest.skipIf(not torch.cuda.is_available(), 'requires cuda')
    @onlyCPU
    @all_sparse_layouts('layout', include_strided=True)
    def test_constructor_pinned_memory(self, device, layout):
        """Tests sparse_xyz_tensor(indices.pin_memory(device), values.pin_memory(device))
        """
        pin_memory_device = "cuda"
        for t in self.generate_simple_inputs(
                layout, device=device, dtype=torch.float64,
                enable_zero_sized=False,     # pinning zero-sized tensors is a no-op
                pin_memory=None,             # constructor does not specify pin_memory=...
                members_pin_memory=True,     # indices and values are pinned
                enable_batch=False,          # TODO: remove after gh-104868 is resolved
        ):
            if layout is torch.sparse_coo:
                self.assertTrue(t._indices().is_pinned())
                self.assertTrue(t._values().is_pinned())
            elif layout in {torch.sparse_csr, torch.sparse_bsr}:
                self.assertTrue(t.crow_indices().is_pinned())
                self.assertTrue(t.col_indices().is_pinned())
                self.assertTrue(t.values().is_pinned())
            elif layout in {torch.sparse_csc, torch.sparse_bsc}:
                self.assertTrue(t.ccol_indices().is_pinned())
                self.assertTrue(t.row_indices().is_pinned())
                self.assertTrue(t.values().is_pinned())
            elif layout is torch.strided:
                pass
            else:
                assert 0  # unreachable
            self.assertTrue(t.is_pinned())

    @unittest.skipIf(not torch.cuda.is_available(), 'requires cuda')
    @onlyCPU
    @all_sparse_layouts('layout', include_strided=False)
    def test_constructor_mismatched_pinned_memory(self, device, layout):
        """Test the failure to construct a sparse tensor from indices and
        values that have different pinning states.
        """
        def generic_constructor(*args, **kwargs):
            if layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
                kwargs.update(layout=layout)
                return torch.sparse_compressed_tensor(*args, **kwargs)
            elif layout is torch.sparse_coo:
                return torch.sparse_coo_tensor(*args, **kwargs)
            else:
                raise NotImplementedError(layout)

        for args, kwargs in self.generate_simple_inputs(
                layout, device=device, dtype=torch.float64,
                enable_zero_sized=False,     # pinning zero-sized tensors is a no-op
                enable_batch=False,  # TODO: remove after gh-104868 is resolved
                output_tensor=False):

            # indices are pinned, values is a non-pinned tensor
            args1 = (args[0].pin_memory(), *args[1:])

            # indices are non-pinned, values is a pinned tensor
            args2 = (*args[:-1], args[-1].pin_memory())

            with self.assertRaisesRegex(
                    RuntimeError, r"memory pinning of \w*indices \(=1\) must match memory pinning of values \(=0\)"):
                generic_constructor(*args1, **kwargs)

            with self.assertRaisesRegex(
                    RuntimeError, r"memory pinning of \w*indices \(=0\) must match memory pinning of values \(=1\)"):
                generic_constructor(*args2, **kwargs)


# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')

instantiate_device_type_tests(TestSparseMaskedReductions, globals(), except_for='meta')

# e.g., TestSparseCPU and TestSparseCUDA
instantiate_device_type_tests(TestSparse, globals(), except_for='meta')

instantiate_device_type_tests(TestSparseAny, globals(), except_for='meta')

instantiate_parametrized_tests(TestSparseMeta)

instantiate_parametrized_tests(TestSparseLegacyAndDeprecation)

if __name__ == '__main__':
    run_tests()