1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: sparse"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport torch 4*da0073e9SAndroid Build Coastguard Workerimport itertools 5*da0073e9SAndroid Build Coastguard Workerimport functools 6*da0073e9SAndroid Build Coastguard Workerimport operator 7*da0073e9SAndroid Build Coastguard Workerimport random 8*da0073e9SAndroid Build Coastguard Workerimport unittest 9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor 10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \ 11*da0073e9SAndroid Build Coastguard Worker load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \ 12*da0073e9SAndroid Build Coastguard Worker DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM, skipIfTorchDynamo, \ 13*da0073e9SAndroid Build Coastguard Worker parametrize, subtest, is_coalesced_indices, suppress_warnings, instantiate_parametrized_tests, \ 14*da0073e9SAndroid Build Coastguard Worker skipIfCrossRef 15*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_CUDA 16*da0073e9SAndroid Build Coastguard Workerfrom numbers import Number 17*da0073e9SAndroid Build Coastguard Workerfrom typing import Dict, Any 18*da0073e9SAndroid Build Coastguard Workerfrom packaging import version 19*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import \ 20*da0073e9SAndroid Build Coastguard Worker (SM53OrLater, SM80OrLater, TEST_MULTIGPU) 21*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import \ 22*da0073e9SAndroid Build Coastguard Worker (instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride, 23*da0073e9SAndroid Build Coastguard Worker deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes) 24*da0073e9SAndroid Build 
Coastguard Workerfrom torch.testing._internal.common_methods_invocations import \ 25*da0073e9SAndroid Build Coastguard Worker (op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs) 26*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_dtype import ( 27*da0073e9SAndroid Build Coastguard Worker all_types, all_types_and_complex, all_types_and_complex_and, floating_and_complex_types, 28*da0073e9SAndroid Build Coastguard Worker floating_and_complex_types_and, integral_types, floating_types_and, 29*da0073e9SAndroid Build Coastguard Worker) 30*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse 31*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.opinfo.refs import ( 32*da0073e9SAndroid Build Coastguard Worker ElementwiseBinaryPythonRefInfo, 33*da0073e9SAndroid Build Coastguard Worker ReductionPythonRefInfo 34*da0073e9SAndroid Build Coastguard Worker) 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Workerdef _op_supports_any_sparse(op): 37*da0073e9SAndroid Build Coastguard Worker return (op.supports_sparse 38*da0073e9SAndroid Build Coastguard Worker or op.supports_sparse_csr 39*da0073e9SAndroid Build Coastguard Worker or op.supports_sparse_csc 40*da0073e9SAndroid Build Coastguard Worker or op.supports_sparse_bsr 41*da0073e9SAndroid Build Coastguard Worker or op.supports_sparse_bsc) 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Workerreduction_ops_with_sparse_support = [ 46*da0073e9SAndroid Build Coastguard Worker op for op in reduction_ops if 'masked.' 
not in op.name and 47*da0073e9SAndroid Build Coastguard Worker _op_supports_any_sparse(op) and not isinstance(op, ReductionPythonRefInfo)] 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Workerbinary_ufuncs_with_sparse_support = [ 50*da0073e9SAndroid Build Coastguard Worker op for op in binary_ufuncs if _op_supports_any_sparse(op) and 51*da0073e9SAndroid Build Coastguard Worker not isinstance(op, ElementwiseBinaryPythonRefInfo)] 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Workerlike_fns_with_sparse_support = [op for op in op_db if _op_supports_any_sparse(op) and '_like' in op.name] 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Workerif TEST_SCIPY: 56*da0073e9SAndroid Build Coastguard Worker import scipy.sparse 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Worker# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for 59*da0073e9SAndroid Build Coastguard Worker# sharding on sandcastle. 
This line silences flake warnings 60*da0073e9SAndroid Build Coastguard Workerload_tests = load_tests 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker# batched grad doesn't support sparse 63*da0073e9SAndroid Build Coastguard Workergradcheck = functools.partial(gradcheck, check_batched_grad=False) 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard WorkerCUSPARSE_SPMM_COMPLEX128_SUPPORTED = ( 66*da0073e9SAndroid Build Coastguard Worker IS_WINDOWS and torch.version.cuda and version.parse(torch.version.cuda) > version.parse("11.2") 67*da0073e9SAndroid Build Coastguard Worker) or (not IS_WINDOWS and not TEST_WITH_ROCM) 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard WorkerHIPSPARSE_SPMM_COMPLEX128_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("6.0") 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Workerdef all_sparse_layouts(test_name='layout', include_strided=False): 72*da0073e9SAndroid Build Coastguard Worker return parametrize(test_name, [ 73*da0073e9SAndroid Build Coastguard Worker subtest(torch.strided, name='Strided'), 74*da0073e9SAndroid Build Coastguard Worker subtest(torch.sparse_coo, name='SparseCOO'), 75*da0073e9SAndroid Build Coastguard Worker subtest(torch.sparse_csr, name='SparseCSR'), 76*da0073e9SAndroid Build Coastguard Worker subtest(torch.sparse_csc, name='SparseCSC'), 77*da0073e9SAndroid Build Coastguard Worker subtest(torch.sparse_bsr, name='SparseBSR'), 78*da0073e9SAndroid Build Coastguard Worker subtest(torch.sparse_bsc, name='SparseBSC'), 79*da0073e9SAndroid Build Coastguard Worker ][(0 if include_strided else 1):]) 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Workerdef gradcheck_semantics(test_name='gradcheck'): 82*da0073e9SAndroid Build Coastguard Worker gradcheck_sparse = functools.partial(gradcheck, masked=False) 
83*da0073e9SAndroid Build Coastguard Worker gradcheck_masked = functools.partial(gradcheck, masked=True) 84*da0073e9SAndroid Build Coastguard Worker gradcheck_sparse.masked = False 85*da0073e9SAndroid Build Coastguard Worker gradcheck_masked.masked = True 86*da0073e9SAndroid Build Coastguard Worker return parametrize(test_name, [ 87*da0073e9SAndroid Build Coastguard Worker subtest(gradcheck_sparse, name='sparse'), 88*da0073e9SAndroid Build Coastguard Worker subtest(gradcheck_masked, name='masked')]) 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Workerclass CrossRefSparseFakeMode(torch._subclasses.CrossRefFakeMode): 92*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 93*da0073e9SAndroid Build Coastguard Worker super().__init__( 94*da0073e9SAndroid Build Coastguard Worker self.ignore_op, check_strides=False, 95*da0073e9SAndroid Build Coastguard Worker check_aliasing=False, 96*da0073e9SAndroid Build Coastguard Worker ) # TODO: enable stride/alias checking 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker # empty_like excluded for now due to sparse complex 99*da0073e9SAndroid Build Coastguard Worker # aten._to_dense.default this one is getting called with csc 100*da0073e9SAndroid Build Coastguard Worker @staticmethod 101*da0073e9SAndroid Build Coastguard Worker def ignore_op(func): 102*da0073e9SAndroid Build Coastguard Worker return func in ( 103*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.empty_like.default, 104*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.set_.source_Storage_storage_offset, 105*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.sspaddmm.out, 106*da0073e9SAndroid Build Coastguard Worker torch.ops.aten._spdiags.default, 107*da0073e9SAndroid Build Coastguard Worker torch.ops.aten._to_dense.default, 108*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.indices.default, 
109*da0073e9SAndroid Build Coastguard Worker torch.ops.aten._indices.default, 110*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.values.default, 111*da0073e9SAndroid Build Coastguard Worker torch.ops.aten._values.default, 112*da0073e9SAndroid Build Coastguard Worker ) 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard Workerclass TestSparseLegacyAndDeprecation(TestCase): 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 117*da0073e9SAndroid Build Coastguard Worker def test_legacy_warnings(self): 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker def f1(): 120*da0073e9SAndroid Build Coastguard Worker "torch.sparse.SparseTensor() is deprecated."\ 121*da0073e9SAndroid Build Coastguard Worker " Please use torch.sparse_coo_tensor((0,), dtype=)" 122*da0073e9SAndroid Build Coastguard Worker x_ref = torch.sparse_coo_tensor((0,), dtype=torch.float64) 123*da0073e9SAndroid Build Coastguard Worker x = torch.sparse.DoubleTensor() 124*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x_ref) 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker def f2(): 127*da0073e9SAndroid Build Coastguard Worker "torch.sparse.SparseTensor(cdata=x._cdata) is deprecated."\ 128*da0073e9SAndroid Build Coastguard Worker " Please use torch.sparse_coo_tensor(x._indices(), x._values(), x.shape)" 129*da0073e9SAndroid Build Coastguard Worker x_ref = torch.tensor([[1, 2], [3, 4]], dtype=torch.float64).to_sparse() 130*da0073e9SAndroid Build Coastguard Worker x = torch.sparse.DoubleTensor(cdata=x_ref._cdata) 131*da0073e9SAndroid Build Coastguard Worker y = torch.sparse_coo_tensor(x._indices(), x._values(), x.shape) 132*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x_ref) 133*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, x_ref) 134*da0073e9SAndroid Build 
Coastguard Worker 135*da0073e9SAndroid Build Coastguard Worker def f3(): 136*da0073e9SAndroid Build Coastguard Worker "torch.sparse.SparseTensor(indices, values, *, device=) is deprecated."\ 137*da0073e9SAndroid Build Coastguard Worker " Please use torch.sparse_coo_tensor(indices, values, dtype=, device=)" 138*da0073e9SAndroid Build Coastguard Worker x_ref = torch.sparse_coo_tensor([[0, 0, 1, 1], [0, 1, 0, 1]], [1, 2, 3, 4], dtype=torch.float64) 139*da0073e9SAndroid Build Coastguard Worker x = torch.sparse.DoubleTensor(torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]]), 140*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 2, 3, 4], dtype=torch.float64)) 141*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x_ref) 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker def f4(): 144*da0073e9SAndroid Build Coastguard Worker "torch.sparse.SparseTensor(indices, values, shape, *, device=) is deprecated."\ 145*da0073e9SAndroid Build Coastguard Worker " Please use torch.sparse_coo_tensor(indices, values, shape, dtype=, device=)" 146*da0073e9SAndroid Build Coastguard Worker x_ref = torch.sparse_coo_tensor([[0, 0, 1, 1], [0, 1, 0, 1]], [1, 2, 3, 4], (2, 3), dtype=torch.float64) 147*da0073e9SAndroid Build Coastguard Worker x = torch.sparse.DoubleTensor(torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]]), 148*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 2, 3, 4], dtype=torch.float64), (2, 3)) 149*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x_ref) 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker def f5(): 152*da0073e9SAndroid Build Coastguard Worker "torch.sparse.SparseTensor(shape, *, device=) is deprecated."\ 153*da0073e9SAndroid Build Coastguard Worker " Please use torch.sparse_coo_tensor(shape, dtype=, device=)" 154*da0073e9SAndroid Build Coastguard Worker x_ref = torch.sparse_coo_tensor((2, 3), dtype=torch.float64) 155*da0073e9SAndroid Build Coastguard Worker x = 
torch.sparse.DoubleTensor(2, 3) 156*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x_ref) 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker for test_f in [f1, f2, f3, f4, f5]: 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker with self.assertWarns(UserWarning, msg=test_f.__doc__) as cm: 161*da0073e9SAndroid Build Coastguard Worker test_f() 162*da0073e9SAndroid Build Coastguard Worker test_f() 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker # Check warn-once: 165*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(cm.warnings), 1) 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Workerclass TestSparseBase(TestCase): 169*da0073e9SAndroid Build Coastguard Worker def run(self, result=None): 170*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_CROSSREF: 171*da0073e9SAndroid Build Coastguard Worker with CrossRefSparseFakeMode(): 172*da0073e9SAndroid Build Coastguard Worker return super().run(result) 173*da0073e9SAndroid Build Coastguard Worker else: 174*da0073e9SAndroid Build Coastguard Worker return super().run(result) 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Workerclass TestSparse(TestSparseBase): 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker def setUp(self): 179*da0073e9SAndroid Build Coastguard Worker TestCase.setUp(self) 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Worker self.index_tensor = lambda *args, **kwargs: torch.tensor(*args, **kwargs, dtype=torch.int64) 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker def sparse_empty_factory(*args, **kwargs): 184*da0073e9SAndroid Build Coastguard Worker kwargs['layout'] = kwargs.get('layout', torch.sparse_coo) 185*da0073e9SAndroid 
Build Coastguard Worker return torch.empty(*args, **kwargs) 186*da0073e9SAndroid Build Coastguard Worker self.sparse_empty = sparse_empty_factory 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker def sparse_tensor_factory(*args, **kwargs): 189*da0073e9SAndroid Build Coastguard Worker return torch.sparse_coo_tensor(*args, **kwargs) 190*da0073e9SAndroid Build Coastguard Worker self.sparse_tensor = sparse_tensor_factory 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Worker def _gen_sparse(self, sparse_dim, nnz, with_size, dtype, device, coalesced): 193*da0073e9SAndroid Build Coastguard Worker if isinstance(with_size, Number): 194*da0073e9SAndroid Build Coastguard Worker with_size = [with_size] * sparse_dim 195*da0073e9SAndroid Build Coastguard Worker 196*da0073e9SAndroid Build Coastguard Worker x, i, v = self.genSparseTensor(with_size, sparse_dim, nnz, not coalesced, dtype=dtype, device=device) 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Worker if not coalesced: 199*da0073e9SAndroid Build Coastguard Worker self.assert_uncoalesced(x) 200*da0073e9SAndroid Build Coastguard Worker 201*da0073e9SAndroid Build Coastguard Worker return x, i, v 202*da0073e9SAndroid Build Coastguard Worker 203*da0073e9SAndroid Build Coastguard Worker def assert_uncoalesced(self, x): 204*da0073e9SAndroid Build Coastguard Worker """ 205*da0073e9SAndroid Build Coastguard Worker Test if a CPU tensor is uncoalesced. This is used to ensure 206*da0073e9SAndroid Build Coastguard Worker correctness of the uncoalesced tensor generation algorithm. 
207*da0073e9SAndroid Build Coastguard Worker """ 208*da0073e9SAndroid Build Coastguard Worker assert not x.is_coalesced() 209*da0073e9SAndroid Build Coastguard Worker existing_indices = set() 210*da0073e9SAndroid Build Coastguard Worker indices = x._indices() 211*da0073e9SAndroid Build Coastguard Worker for i in range(x._nnz()): 212*da0073e9SAndroid Build Coastguard Worker index = str(indices[:, i]) 213*da0073e9SAndroid Build Coastguard Worker if index in existing_indices: 214*da0073e9SAndroid Build Coastguard Worker return True 215*da0073e9SAndroid Build Coastguard Worker else: 216*da0073e9SAndroid Build Coastguard Worker existing_indices.add(index) 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker def randn(self, *args, **kwargs): 219*da0073e9SAndroid Build Coastguard Worker """ 220*da0073e9SAndroid Build Coastguard Worker Variant of torch.randn that also works in the TEST_CUDA case. 221*da0073e9SAndroid Build Coastguard Worker """ 222*da0073e9SAndroid Build Coastguard Worker # TODO: Put this in torch.cuda.randn 223*da0073e9SAndroid Build Coastguard Worker return torch.empty(*args, **kwargs).normal_() 224*da0073e9SAndroid Build Coastguard Worker 225*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 226*da0073e9SAndroid Build Coastguard Worker def test_print_coalesced(self, device, dtype): 227*da0073e9SAndroid Build Coastguard Worker self._test_print(device, dtype, True) 228*da0073e9SAndroid Build Coastguard Worker 229*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 230*da0073e9SAndroid Build Coastguard Worker def test_print_uncoalesced(self, device, dtype): 231*da0073e9SAndroid Build Coastguard Worker self._test_print(device, dtype, False) 232*da0073e9SAndroid Build Coastguard Worker 233*da0073e9SAndroid Build Coastguard Worker def _test_print(self, device, dtype, coalesced): 234*da0073e9SAndroid Build Coastguard Worker shape_sparse_dim_nnz = [ 235*da0073e9SAndroid Build Coastguard Worker ((), 0, 
2), 236*da0073e9SAndroid Build Coastguard Worker ((0,), 0, 10), 237*da0073e9SAndroid Build Coastguard Worker ((2,), 0, 3), 238*da0073e9SAndroid Build Coastguard Worker ((100, 3), 1, 3), 239*da0073e9SAndroid Build Coastguard Worker ((100, 20, 3), 2, 0), 240*da0073e9SAndroid Build Coastguard Worker ((10, 0, 3), 0, 3), 241*da0073e9SAndroid Build Coastguard Worker ((10, 0, 3), 0, 0), 242*da0073e9SAndroid Build Coastguard Worker ] 243*da0073e9SAndroid Build Coastguard Worker printed = [] 244*da0073e9SAndroid Build Coastguard Worker for shape, sparse_dim, nnz in shape_sparse_dim_nnz: 245*da0073e9SAndroid Build Coastguard Worker indices_shape = torch.Size((sparse_dim, nnz)) 246*da0073e9SAndroid Build Coastguard Worker values_shape = torch.Size((nnz,) + shape[sparse_dim:]) 247*da0073e9SAndroid Build Coastguard Worker printed.append(f"# shape: {torch.Size(shape)}") 248*da0073e9SAndroid Build Coastguard Worker printed.append(f"# nnz: {nnz}") 249*da0073e9SAndroid Build Coastguard Worker printed.append(f"# sparse_dim: {sparse_dim}") 250*da0073e9SAndroid Build Coastguard Worker printed.append(f"# indices shape: {indices_shape}") 251*da0073e9SAndroid Build Coastguard Worker printed.append(f"# values shape: {values_shape}") 252*da0073e9SAndroid Build Coastguard Worker 253*da0073e9SAndroid Build Coastguard Worker indices = torch.arange(indices_shape.numel(), dtype=self.index_tensor(0).dtype, 254*da0073e9SAndroid Build Coastguard Worker device=device).view(indices_shape) 255*da0073e9SAndroid Build Coastguard Worker for d in range(sparse_dim): 256*da0073e9SAndroid Build Coastguard Worker indices[d].clamp_(max=(shape[d] - 1)) # make it valid index 257*da0073e9SAndroid Build Coastguard Worker if not coalesced and indices.numel() > 0: 258*da0073e9SAndroid Build Coastguard Worker indices[:, -1] = indices[:, 0] # make it uncoalesced 259*da0073e9SAndroid Build Coastguard Worker values_numel = values_shape.numel() 260*da0073e9SAndroid Build Coastguard Worker values = 
torch.arange(values_numel, dtype=dtype, 261*da0073e9SAndroid Build Coastguard Worker device=device).view(values_shape).div_(values_numel / 2.) 262*da0073e9SAndroid Build Coastguard Worker sp_tensor = self.sparse_tensor(indices, values, shape, dtype=dtype, device=device) 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker dtypes = [torch.int32] 265*da0073e9SAndroid Build Coastguard Worker if values.dtype == torch.double: 266*da0073e9SAndroid Build Coastguard Worker dtypes.append(torch.float) 267*da0073e9SAndroid Build Coastguard Worker else: 268*da0073e9SAndroid Build Coastguard Worker dtypes.append(torch.double) 269*da0073e9SAndroid Build Coastguard Worker for dtype in dtypes: 270*da0073e9SAndroid Build Coastguard Worker printed.append(f"########## {dtype} ##########") 271*da0073e9SAndroid Build Coastguard Worker x = sp_tensor.detach().to(dtype) 272*da0073e9SAndroid Build Coastguard Worker printed.append("# sparse tensor") 273*da0073e9SAndroid Build Coastguard Worker printed.append(str(x)) 274*da0073e9SAndroid Build Coastguard Worker if x.dtype.is_floating_point: 275*da0073e9SAndroid Build Coastguard Worker printed.append("# after requires_grad_") 276*da0073e9SAndroid Build Coastguard Worker printed.append(str(x.requires_grad_())) 277*da0073e9SAndroid Build Coastguard Worker printed.append("# after addition") 278*da0073e9SAndroid Build Coastguard Worker printed.append(str(x + x)) 279*da0073e9SAndroid Build Coastguard Worker printed.append("# _indices") 280*da0073e9SAndroid Build Coastguard Worker printed.append(str(x._indices())) 281*da0073e9SAndroid Build Coastguard Worker printed.append("# _values") 282*da0073e9SAndroid Build Coastguard Worker printed.append(str(x._values())) 283*da0073e9SAndroid Build Coastguard Worker printed.append('') 284*da0073e9SAndroid Build Coastguard Worker self.assertExpected('\n'.join(printed)) 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 
287*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 288*da0073e9SAndroid Build Coastguard Worker def test_basic(self, device, dtype, coalesced): 289*da0073e9SAndroid Build Coastguard Worker def test_shape(sparse_dims, nnz, with_size): 290*da0073e9SAndroid Build Coastguard Worker if isinstance(with_size, Number): 291*da0073e9SAndroid Build Coastguard Worker with_size = [with_size] * sparse_dims 292*da0073e9SAndroid Build Coastguard Worker x, i, v = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced) 293*da0073e9SAndroid Build Coastguard Worker self.assertEqual(i, x._indices()) 294*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v, x._values()) 295*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.ndimension(), len(with_size)) 296*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.coalesce()._nnz(), nnz if x.is_coalesced() else nnz // 2) 297*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(x.size()), with_size) 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker # Test .indices() and .values() 300*da0073e9SAndroid Build Coastguard Worker if not coalesced: 301*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot get indices on an uncoalesced tensor"): 302*da0073e9SAndroid Build Coastguard Worker x.indices() 303*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot get values on an uncoalesced tensor"): 304*da0073e9SAndroid Build Coastguard Worker x.values() 305*da0073e9SAndroid Build Coastguard Worker else: 306*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.indices(), x._indices()) 307*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.values(), x._values()) 308*da0073e9SAndroid Build Coastguard Worker 309*da0073e9SAndroid Build Coastguard Worker test_shape(3, 10, 100) 310*da0073e9SAndroid Build Coastguard Worker test_shape(3, 10, [100, 100, 
100]) 311*da0073e9SAndroid Build Coastguard Worker test_shape(3, 10, [100, 100, 100, 5, 5, 5, 0]) 312*da0073e9SAndroid Build Coastguard Worker test_shape(3, 0, [0, 0, 100, 5, 5, 5, 0]) 313*da0073e9SAndroid Build Coastguard Worker 314*da0073e9SAndroid Build Coastguard Worker # Make sure that coalesce handles duplicate indices correctly 315*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([[9, 0, 0, 0, 8, 1, 1, 1, 2, 7, 2, 2, 3, 4, 6, 9]], device=device) 316*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([[idx**2, idx] for idx in range(i.size(1))], dtype=dtype, device=device) 317*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([10, 2]), dtype=dtype, device=device) 318*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.coalesce()._nnz(), 9) 319*da0073e9SAndroid Build Coastguard Worker 320*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 321*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble, torch.bfloat16) 322*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.bfloat16: 1e-2}) 323*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") 324*da0073e9SAndroid Build Coastguard Worker def test_coalesce(self, device, dtype, coalesced): 325*da0073e9SAndroid Build Coastguard Worker 326*da0073e9SAndroid Build Coastguard Worker def _test_coalesce(t): 327*da0073e9SAndroid Build Coastguard Worker tc = t.coalesce() 328*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tc.to_dense(), t.to_dense()) 329*da0073e9SAndroid Build Coastguard Worker self.assertTrue(tc.is_coalesced()) 330*da0073e9SAndroid Build Coastguard Worker # Our code below doesn't work when nnz is 0, because 331*da0073e9SAndroid Build Coastguard Worker # then it's a 0D tensor, not a 2D tensor. 
332*da0073e9SAndroid Build Coastguard Worker if t._nnz() == 0: 333*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t._indices(), tc._indices()) 334*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t._values(), tc._values()) 335*da0073e9SAndroid Build Coastguard Worker return tc 336*da0073e9SAndroid Build Coastguard Worker 337*da0073e9SAndroid Build Coastguard Worker value_map: Dict[Any, Any] = {} 338*da0073e9SAndroid Build Coastguard Worker for idx, val in zip(t._indices().t(), t._values()): 339*da0073e9SAndroid Build Coastguard Worker idx_tup = tuple(idx.tolist()) 340*da0073e9SAndroid Build Coastguard Worker if idx_tup in value_map: 341*da0073e9SAndroid Build Coastguard Worker value_map[idx_tup] += val 342*da0073e9SAndroid Build Coastguard Worker else: 343*da0073e9SAndroid Build Coastguard Worker value_map[idx_tup] = val.clone() if isinstance(val, torch.Tensor) else val 344*da0073e9SAndroid Build Coastguard Worker 345*da0073e9SAndroid Build Coastguard Worker new_indices = sorted(value_map.keys()) 346*da0073e9SAndroid Build Coastguard Worker _new_values = [value_map[idx] for idx in new_indices] 347*da0073e9SAndroid Build Coastguard Worker if t._values().ndimension() < 2: 348*da0073e9SAndroid Build Coastguard Worker new_values = t._values().new(_new_values) 349*da0073e9SAndroid Build Coastguard Worker else: 350*da0073e9SAndroid Build Coastguard Worker new_values = torch.stack(_new_values) 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker new_indices = t._indices().new(new_indices).t() 353*da0073e9SAndroid Build Coastguard Worker tg = t.new(new_indices, new_values, t.size()) 354*da0073e9SAndroid Build Coastguard Worker 355*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tc._indices(), tg._indices()) 356*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tc._values(), tg._values()) 357*da0073e9SAndroid Build Coastguard Worker 358*da0073e9SAndroid Build Coastguard Worker if t.is_coalesced(): 
359*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tc._indices(), t._indices()) 360*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tc._values(), t._values()) 361*da0073e9SAndroid Build Coastguard Worker 362*da0073e9SAndroid Build Coastguard Worker for empty_i, empty_v, empty_nnz in itertools.product([True, False], repeat=3): 363*da0073e9SAndroid Build Coastguard Worker sparse_size = [] if empty_i else [2, 1] 364*da0073e9SAndroid Build Coastguard Worker dense_size = [1, 0, 2] if empty_v else [1, 2] 365*da0073e9SAndroid Build Coastguard Worker nnz = 0 if empty_nnz else 5 366*da0073e9SAndroid Build Coastguard Worker 367*da0073e9SAndroid Build Coastguard Worker t, _, _ = self._gen_sparse(len(sparse_size), nnz, sparse_size + dense_size, dtype, device, coalesced) 368*da0073e9SAndroid Build Coastguard Worker _test_coalesce(t) # this tests correctness 369*da0073e9SAndroid Build Coastguard Worker 370*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 371*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/89395") 372*da0073e9SAndroid Build Coastguard Worker def test_coalesce_reference_cycle(self, device, dtype): 373*da0073e9SAndroid Build Coastguard Worker # Test coalesce doesn't create autograd graph cycles (gh-52253) 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard Worker # Sanity check that the helper class works as expected 376*da0073e9SAndroid Build Coastguard Worker t = torch.rand(2) 377*da0073e9SAndroid Build Coastguard Worker t_ref = torch._C._WeakTensorRef(t) 378*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t_ref.expired()) 379*da0073e9SAndroid Build Coastguard Worker 380*da0073e9SAndroid Build Coastguard Worker del t 381*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t_ref.expired()) 382*da0073e9SAndroid Build Coastguard Worker 383*da0073e9SAndroid Build Coastguard Worker def test_sparse_sum(): 384*da0073e9SAndroid Build 
Coastguard Worker i = torch.tensor([[0], [4]], dtype=torch.long, device=device) 385*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([[[-0.4567, -1.8797, 0.0380, 1.4316]]], 386*da0073e9SAndroid Build Coastguard Worker dtype=dtype, device=device) 387*da0073e9SAndroid Build Coastguard Worker S = torch.sparse_coo_tensor(i, v) 388*da0073e9SAndroid Build Coastguard Worker S = S.coalesce() 389*da0073e9SAndroid Build Coastguard Worker S.requires_grad_(True) 390*da0073e9SAndroid Build Coastguard Worker S2 = S.coalesce() 391*da0073e9SAndroid Build Coastguard Worker self.assertTrue(S2.is_coalesced()) 392*da0073e9SAndroid Build Coastguard Worker return torch._C._WeakTensorRef(S2) 393*da0073e9SAndroid Build Coastguard Worker 394*da0073e9SAndroid Build Coastguard Worker ref = test_sparse_sum() 395*da0073e9SAndroid Build Coastguard Worker self.assertTrue(ref.expired()) 396*da0073e9SAndroid Build Coastguard Worker 397*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 398*da0073e9SAndroid Build Coastguard Worker def test_ctor_large_sizes(self, device, dtype): 399*da0073e9SAndroid Build Coastguard Worker # Test that integer overflow is detected when computing numel 400*da0073e9SAndroid Build Coastguard Worker # of a sparse tensor with large dimensions (gh-57416). Notice 401*da0073e9SAndroid Build Coastguard Worker # that numel is computed internally when constructing a 402*da0073e9SAndroid Build Coastguard Worker # tensor, hence the overflow may appear during the tensor 403*da0073e9SAndroid Build Coastguard Worker # construction step. 
404*da0073e9SAndroid Build Coastguard Worker N = 100000 405*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor([[N, N - 1]] * 4, dtype=torch.int64, device=device) 406*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([1, 2], dtype=dtype, device=device) 407*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, 408*da0073e9SAndroid Build Coastguard Worker lambda: torch.sparse_coo_tensor( 409*da0073e9SAndroid Build Coastguard Worker indices, values, (N + 1,) * 4, device=device)) 410*da0073e9SAndroid Build Coastguard Worker 411*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 412*da0073e9SAndroid Build Coastguard Worker def test_ctor_size_checks(self, device, dtype): 413*da0073e9SAndroid Build Coastguard Worker indices = self.index_tensor([ 414*da0073e9SAndroid Build Coastguard Worker [0, 0, 0], 415*da0073e9SAndroid Build Coastguard Worker [0, 3, 0], 416*da0073e9SAndroid Build Coastguard Worker [0, 0, 0], 417*da0073e9SAndroid Build Coastguard Worker [0, 0, 0], 418*da0073e9SAndroid Build Coastguard Worker ], device=device) 419*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([2, 1, 3, 4], dtype=dtype, device=device) 420*da0073e9SAndroid Build Coastguard Worker 421*da0073e9SAndroid Build Coastguard Worker # indices inconsistent with size 422*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 423*da0073e9SAndroid Build Coastguard Worker RuntimeError, 424*da0073e9SAndroid Build Coastguard Worker lambda: self.sparse_tensor(indices, values, torch.Size([2, 1, 1]))) 425*da0073e9SAndroid Build Coastguard Worker 426*da0073e9SAndroid Build Coastguard Worker # values inconsistent with size 427*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([ 428*da0073e9SAndroid Build Coastguard Worker [2, 1, 2, 1], 429*da0073e9SAndroid Build Coastguard Worker [1, 0, 5, 2], 430*da0073e9SAndroid Build Coastguard Worker ], dtype=dtype, device=device) 431*da0073e9SAndroid Build Coastguard 
Worker self.assertRaises( 432*da0073e9SAndroid Build Coastguard Worker RuntimeError, 433*da0073e9SAndroid Build Coastguard Worker lambda: self.sparse_tensor(indices, values, torch.Size([2, 4, 2, 1]))) 434*da0073e9SAndroid Build Coastguard Worker 435*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 436*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 437*da0073e9SAndroid Build Coastguard Worker def test_ctor_is_coalesced_with_gradcheck(self, device, dtype, coalesced): 438*da0073e9SAndroid Build Coastguard Worker for sparse_size, nnz in (((3, 3), 5), ((2, 3, 1, 5), 11)): 439*da0073e9SAndroid Build Coastguard Worker t, _, _ = self._gen_sparse(len(sparse_size), nnz, sparse_size, dtype, device, coalesced) 440*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.is_coalesced(), coalesced) 441*da0073e9SAndroid Build Coastguard Worker 442*da0073e9SAndroid Build Coastguard Worker def func(indices, values, shape, is_coalesced): 443*da0073e9SAndroid Build Coastguard Worker s = torch.sparse_coo_tensor(indices, values, shape, check_invariants=True, is_coalesced=is_coalesced) 444*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.is_coalesced(), is_coalesced) 445*da0073e9SAndroid Build Coastguard Worker return s.to_dense(masked_grad=False) 446*da0073e9SAndroid Build Coastguard Worker 447*da0073e9SAndroid Build Coastguard Worker if coalesced: 448*da0073e9SAndroid Build Coastguard Worker torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, False)) 449*da0073e9SAndroid Build Coastguard Worker torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, True)) 450*da0073e9SAndroid Build Coastguard Worker else: 451*da0073e9SAndroid Build Coastguard Worker torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, False)) 452*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 453*da0073e9SAndroid Build Coastguard Worker 
"cannot set is_coalesced to true if indices correspond to uncoalesced COO tensor"): 454*da0073e9SAndroid Build Coastguard Worker torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), t.shape, True)) 455*da0073e9SAndroid Build Coastguard Worker 456*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types_and(torch.float16, torch.bfloat16)) 457*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error") 458*da0073e9SAndroid Build Coastguard Worker @gradcheck_semantics() 459*da0073e9SAndroid Build Coastguard Worker def test_to_dense_with_gradcheck(self, device, dtype, gradcheck): 460*da0073e9SAndroid Build Coastguard Worker 461*da0073e9SAndroid Build Coastguard Worker def test_tensor(x, res): 462*da0073e9SAndroid Build Coastguard Worker x.to_dense() # Tests triple to_dense for memory corruption 463*da0073e9SAndroid Build Coastguard Worker x.to_dense() 464*da0073e9SAndroid Build Coastguard Worker x.to_dense() 465*da0073e9SAndroid Build Coastguard Worker dense_x = x.to_dense() 466*da0073e9SAndroid Build Coastguard Worker safe_dense_x = self.safeToDense(x) 467*da0073e9SAndroid Build Coastguard Worker dense_x = dense_x.to(res.dtype) 468*da0073e9SAndroid Build Coastguard Worker safe_dense_x = safe_dense_x.to(res.dtype) 469*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, dense_x) 470*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, safe_dense_x) 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker # Only run autograd test for float64 473*da0073e9SAndroid Build Coastguard Worker if x.dtype != torch.float64: 474*da0073e9SAndroid Build Coastguard Worker return 475*da0073e9SAndroid Build Coastguard Worker 476*da0073e9SAndroid Build Coastguard Worker def fn(x): 477*da0073e9SAndroid Build Coastguard Worker return x.to_dense(masked_grad=gradcheck.masked) 478*da0073e9SAndroid Build Coastguard Worker 
x.requires_grad_(True) 479*da0073e9SAndroid Build Coastguard Worker gradcheck(fn, (x,)) 480*da0073e9SAndroid Build Coastguard Worker 481*da0073e9SAndroid Build Coastguard Worker for value_type in [torch.double, torch.cdouble]: 482*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 483*da0073e9SAndroid Build Coastguard Worker [0, 1, 2, 2], 484*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 3], 485*da0073e9SAndroid Build Coastguard Worker [0, 0, 1, 4], 486*da0073e9SAndroid Build Coastguard Worker ], device=device) 487*da0073e9SAndroid Build Coastguard Worker # we don't have to_dense for half types on CPU because it is implemented 488*da0073e9SAndroid Build Coastguard Worker # with a slower add_ operation 489*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([2, 1, 3, 4], dtype=dtype, device=device) 490*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3, 4, 5]), dtype=value_type, device=device) 491*da0073e9SAndroid Build Coastguard Worker res = torch.tensor([ 492*da0073e9SAndroid Build Coastguard Worker [[2, 0, 0, 0, 0], 493*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0, 0], 494*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0, 0], 495*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0, 0]], 496*da0073e9SAndroid Build Coastguard Worker [[1, 0, 0, 0, 0], 497*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0, 0], 498*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0, 0], 499*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0, 0]], 500*da0073e9SAndroid Build Coastguard Worker [[0, 3, 0, 0, 0], 501*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0, 0], 502*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0, 0], 503*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0, 4]], 504*da0073e9SAndroid Build Coastguard Worker ], dtype=dtype, device=device) 505*da0073e9SAndroid Build Coastguard Worker 506*da0073e9SAndroid Build Coastguard Worker test_tensor(x, res) 507*da0073e9SAndroid Build 
Coastguard Worker test_tensor(res, res) 508*da0073e9SAndroid Build Coastguard Worker 509*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 510*da0073e9SAndroid Build Coastguard Worker [0, 1, 2, 2], 511*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 3], 512*da0073e9SAndroid Build Coastguard Worker [0, 0, 1, 4], 513*da0073e9SAndroid Build Coastguard Worker ], device=device) 514*da0073e9SAndroid Build Coastguard Worker v = torch.empty(4, 0, dtype=dtype, device=device) 515*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 0]), dtype=value_type, device=device) 516*da0073e9SAndroid Build Coastguard Worker res = torch.empty((3, 4, 5, 0), dtype=dtype, device=device) 517*da0073e9SAndroid Build Coastguard Worker test_tensor(x, res) 518*da0073e9SAndroid Build Coastguard Worker 519*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 520*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float16, torch.bfloat16, torch.float64, torch.int, torch.cfloat, torch.cdouble) 521*da0073e9SAndroid Build Coastguard Worker def test_to_sparse(self, device, dtype, coalesced): 522*da0073e9SAndroid Build Coastguard Worker shape = [5, 2, 10, 4] 523*da0073e9SAndroid Build Coastguard Worker max_nnz = 1 524*da0073e9SAndroid Build Coastguard Worker for value_type in [torch.double, torch.cdouble]: 525*da0073e9SAndroid Build Coastguard Worker for dim, dim_sz in enumerate(shape, 1): 526*da0073e9SAndroid Build Coastguard Worker max_nnz *= dim_sz 527*da0073e9SAndroid Build Coastguard Worker rnnz = torch.randint(2, max_nnz, (1,)).item() 528*da0073e9SAndroid Build Coastguard Worker for nnz in [0, 1, rnnz]: 529*da0073e9SAndroid Build Coastguard Worker expected, _, _ = self._gen_sparse(dim, nnz, shape, dtype=value_type, device=device, 530*da0073e9SAndroid Build Coastguard Worker coalesced=coalesced) 531*da0073e9SAndroid Build Coastguard Worker expected = expected.to(dtype) 532*da0073e9SAndroid Build Coastguard Worker 533*da0073e9SAndroid 
Build Coastguard Worker d = expected.to_dense() 534*da0073e9SAndroid Build Coastguard Worker result = d.to_sparse(dim) 535*da0073e9SAndroid Build Coastguard Worker self.assertEqual(d, result.to_dense()) 536*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected.size(), result.size()) 537*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dim, result.sparse_dim()) 538*da0073e9SAndroid Build Coastguard Worker 539*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 540*da0073e9SAndroid Build Coastguard Worker def test_sparse_bool(self, device, dtype): 541*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([True, False], dtype=dtype, device=device).to(torch.bool) 542*da0073e9SAndroid Build Coastguard Worker b = a.to_sparse().to_dense() 543*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 544*da0073e9SAndroid Build Coastguard Worker 545*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/108667") 546*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 547*da0073e9SAndroid Build Coastguard Worker def test_scalar(self, device, dtype): 548*da0073e9SAndroid Build Coastguard Worker # tensor with value 549*da0073e9SAndroid Build Coastguard Worker a = self.sparse_tensor(self.index_tensor([], device=device).unsqueeze(1), 12.3, [], dtype=dtype, device=device) 550*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, a._values().numel()) 551*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, a.clone()) 552*da0073e9SAndroid Build Coastguard Worker a_coalesced = a.coalesce() 553*da0073e9SAndroid Build Coastguard Worker self.assertTrue(a_coalesced.is_coalesced()) 554*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor(12.3, dtype=dtype, device=device), a.to_dense()) 555*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, a.to_dense().to_sparse()) 556*da0073e9SAndroid Build Coastguard Worker 
557*da0073e9SAndroid Build Coastguard Worker # tensor with multiple values 558*da0073e9SAndroid Build Coastguard Worker a = self.sparse_tensor(self.index_tensor([], device=device).unsqueeze(1).expand(0, 2), 559*da0073e9SAndroid Build Coastguard Worker [12.3, 12.3], [], dtype=dtype, device=device) 560*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, a._values().numel()) 561*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, a.clone()) 562*da0073e9SAndroid Build Coastguard Worker a_coalesced = a.coalesce() 563*da0073e9SAndroid Build Coastguard Worker self.assertTrue(a_coalesced.is_coalesced()) 564*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor(12.3 * 2, dtype=dtype, device=device), a.to_dense()) 565*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.coalesce(), a.coalesce().to_dense().to_sparse()) 566*da0073e9SAndroid Build Coastguard Worker 567*da0073e9SAndroid Build Coastguard Worker # tensor without value 568*da0073e9SAndroid Build Coastguard Worker a = self.sparse_empty((), dtype=dtype, device=device) 569*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, a._values().numel()) 570*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, a.clone()) 571*da0073e9SAndroid Build Coastguard Worker a_coalesced = a.coalesce() 572*da0073e9SAndroid Build Coastguard Worker self.assertTrue(a_coalesced.is_coalesced()) 573*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor(0, dtype=dtype, device=device), a.to_dense()) 574*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, a.to_dense().to_sparse()) 575*da0073e9SAndroid Build Coastguard Worker 576*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 577*da0073e9SAndroid Build Coastguard Worker def test_shared(self, device, dtype): 578*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([[2]], device=device) 579*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([5], dtype=dtype, 
device=device) 580*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3])) 581*da0073e9SAndroid Build Coastguard Worker v[0] = 6 582*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor([0, 0, 6], dtype=dtype, device=device), self.safeToDense(x)) 583*da0073e9SAndroid Build Coastguard Worker i[0][0] = 0 584*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor([6, 0, 0], dtype=dtype, device=device), self.safeToDense(x)) 585*da0073e9SAndroid Build Coastguard Worker 586*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([[2]], device=device) 587*da0073e9SAndroid Build Coastguard Worker v = torch.empty((1, 0), dtype=dtype, device=device) 588*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3, 0])) 589*da0073e9SAndroid Build Coastguard Worker i[0][0] = 0 590*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty((3, 0), dtype=dtype, device=device), self.safeToDense(x)) 591*da0073e9SAndroid Build Coastguard Worker 592*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 593*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error") 594*da0073e9SAndroid Build Coastguard Worker @gradcheck_semantics() 595*da0073e9SAndroid Build Coastguard Worker def test_to_dense_hybrid(self, device, dtype, gradcheck): 596*da0073e9SAndroid Build Coastguard Worker 597*da0073e9SAndroid Build Coastguard Worker def test_tensor(x, res): 598*da0073e9SAndroid Build Coastguard Worker x.to_dense() # Tests double to_dense for memory corruption 599*da0073e9SAndroid Build Coastguard Worker x.to_dense() 600*da0073e9SAndroid Build Coastguard Worker x.to_dense() 601*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, x.to_dense()) 602*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, self.safeToDense(x)) 603*da0073e9SAndroid Build Coastguard Worker 
604*da0073e9SAndroid Build Coastguard Worker def fn(x): 605*da0073e9SAndroid Build Coastguard Worker return x.to_dense(masked_grad=gradcheck.masked) 606*da0073e9SAndroid Build Coastguard Worker x.requires_grad_(True) 607*da0073e9SAndroid Build Coastguard Worker gradcheck(fn, (x,)) 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 610*da0073e9SAndroid Build Coastguard Worker [0, 1, 2, 2], 611*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 3], 612*da0073e9SAndroid Build Coastguard Worker ], device=device) 613*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([[2, 3], [1, 2], [3, 4], [4, 5]], dtype=dtype, device=device) 614*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3, 4, 2])) 615*da0073e9SAndroid Build Coastguard Worker res = torch.tensor([ 616*da0073e9SAndroid Build Coastguard Worker [[2, 3], 617*da0073e9SAndroid Build Coastguard Worker [0, 0], 618*da0073e9SAndroid Build Coastguard Worker [0, 0], 619*da0073e9SAndroid Build Coastguard Worker [0, 0]], 620*da0073e9SAndroid Build Coastguard Worker [[1, 2], 621*da0073e9SAndroid Build Coastguard Worker [0, 0], 622*da0073e9SAndroid Build Coastguard Worker [0, 0], 623*da0073e9SAndroid Build Coastguard Worker [0, 0]], 624*da0073e9SAndroid Build Coastguard Worker [[3, 4], 625*da0073e9SAndroid Build Coastguard Worker [0, 0], 626*da0073e9SAndroid Build Coastguard Worker [0, 0], 627*da0073e9SAndroid Build Coastguard Worker [4, 5]], 628*da0073e9SAndroid Build Coastguard Worker ], dtype=dtype, device=device) 629*da0073e9SAndroid Build Coastguard Worker test_tensor(x, res) 630*da0073e9SAndroid Build Coastguard Worker 631*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 632*da0073e9SAndroid Build Coastguard Worker [0, 1, 2, 2], 633*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 3], 634*da0073e9SAndroid Build Coastguard Worker ], device=device) 635*da0073e9SAndroid Build Coastguard Worker v = 
torch.empty((4, 2, 0), dtype=dtype, device=device) 636*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3, 4, 2, 0])) 637*da0073e9SAndroid Build Coastguard Worker res = torch.empty((3, 4, 2, 0), dtype=dtype, device=device) 638*da0073e9SAndroid Build Coastguard Worker test_tensor(x, res) 639*da0073e9SAndroid Build Coastguard Worker 640*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 641*da0073e9SAndroid Build Coastguard Worker def test_contig(self, device, dtype): 642*da0073e9SAndroid Build Coastguard Worker def test_tensor(x, exp_i, exp_v): 643*da0073e9SAndroid Build Coastguard Worker x = x.coalesce() 644*da0073e9SAndroid Build Coastguard Worker self.assertEqual(exp_i, x._indices()) 645*da0073e9SAndroid Build Coastguard Worker self.assertEqual(exp_v, x._values()) 646*da0073e9SAndroid Build Coastguard Worker 647*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 648*da0073e9SAndroid Build Coastguard Worker [1, 0, 35, 14, 39, 6, 71, 66, 40, 27], 649*da0073e9SAndroid Build Coastguard Worker [92, 31, 62, 50, 22, 65, 89, 74, 56, 34], 650*da0073e9SAndroid Build Coastguard Worker ], device=device) 651*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype, device=device) 652*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([100, 100])) 653*da0073e9SAndroid Build Coastguard Worker exp_i = self.index_tensor([ 654*da0073e9SAndroid Build Coastguard Worker [0, 1, 6, 14, 27, 35, 39, 40, 66, 71], 655*da0073e9SAndroid Build Coastguard Worker [31, 92, 65, 50, 34, 62, 22, 56, 74, 89], 656*da0073e9SAndroid Build Coastguard Worker ], device=device) 657*da0073e9SAndroid Build Coastguard Worker exp_v = torch.tensor([2, 1, 6, 4, 10, 3, 5, 9, 8, 7], dtype=dtype, device=device) 658*da0073e9SAndroid Build Coastguard Worker test_tensor(x, exp_i, exp_v) 659*da0073e9SAndroid Build Coastguard Worker 660*da0073e9SAndroid Build Coastguard 
Worker i = self.index_tensor([ 661*da0073e9SAndroid Build Coastguard Worker [2, 0, 2, 1], 662*da0073e9SAndroid Build Coastguard Worker [0, 0, 3, 0], 663*da0073e9SAndroid Build Coastguard Worker [1, 0, 4, 0], 664*da0073e9SAndroid Build Coastguard Worker ], device=device) 665*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([3, 2, 4, 1], dtype=dtype, device=device) 666*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3, 4, 5])) 667*da0073e9SAndroid Build Coastguard Worker exp_i = self.index_tensor([ 668*da0073e9SAndroid Build Coastguard Worker [0, 1, 2, 2], 669*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 3], 670*da0073e9SAndroid Build Coastguard Worker [0, 0, 1, 4], 671*da0073e9SAndroid Build Coastguard Worker ], device=device) 672*da0073e9SAndroid Build Coastguard Worker exp_v = torch.tensor([2, 1, 3, 4], dtype=dtype, device=device) 673*da0073e9SAndroid Build Coastguard Worker test_tensor(x, exp_i, exp_v) 674*da0073e9SAndroid Build Coastguard Worker 675*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 676*da0073e9SAndroid Build Coastguard Worker [2, 0, 2, 1], 677*da0073e9SAndroid Build Coastguard Worker [0, 0, 3, 0], 678*da0073e9SAndroid Build Coastguard Worker [1, 0, 4, 0], 679*da0073e9SAndroid Build Coastguard Worker ], device=device) 680*da0073e9SAndroid Build Coastguard Worker v = torch.empty([4, 0], dtype=dtype, device=device) 681*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 0])) 682*da0073e9SAndroid Build Coastguard Worker exp_i = self.index_tensor([ 683*da0073e9SAndroid Build Coastguard Worker [0, 1, 2, 2], 684*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 3], 685*da0073e9SAndroid Build Coastguard Worker [0, 0, 1, 4], 686*da0073e9SAndroid Build Coastguard Worker ], device=device) 687*da0073e9SAndroid Build Coastguard Worker exp_v = torch.empty([4, 0], dtype=dtype, device=device) 688*da0073e9SAndroid Build Coastguard Worker test_tensor(x, exp_i, 
exp_v) 689*da0073e9SAndroid Build Coastguard Worker 690*da0073e9SAndroid Build Coastguard Worker # Duplicate indices 691*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 692*da0073e9SAndroid Build Coastguard Worker [0, 0, 2, 0], 693*da0073e9SAndroid Build Coastguard Worker [0, 0, 3, 0], 694*da0073e9SAndroid Build Coastguard Worker [0, 0, 4, 0], 695*da0073e9SAndroid Build Coastguard Worker ], device=device) 696*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([3, 2, 4, 1], dtype=dtype, device=device) 697*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3, 4, 5])) 698*da0073e9SAndroid Build Coastguard Worker exp_i = self.index_tensor([ 699*da0073e9SAndroid Build Coastguard Worker [0, 2], 700*da0073e9SAndroid Build Coastguard Worker [0, 3], 701*da0073e9SAndroid Build Coastguard Worker [0, 4], 702*da0073e9SAndroid Build Coastguard Worker ], device=device) 703*da0073e9SAndroid Build Coastguard Worker exp_v = torch.tensor([6, 4], dtype=dtype, device=device) 704*da0073e9SAndroid Build Coastguard Worker test_tensor(x, exp_i, exp_v) 705*da0073e9SAndroid Build Coastguard Worker 706*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 707*da0073e9SAndroid Build Coastguard Worker [0, 0, 2, 0], 708*da0073e9SAndroid Build Coastguard Worker [0, 0, 3, 0], 709*da0073e9SAndroid Build Coastguard Worker [0, 0, 4, 0], 710*da0073e9SAndroid Build Coastguard Worker ], device=device) 711*da0073e9SAndroid Build Coastguard Worker v = torch.empty([4, 0], dtype=dtype, device=device) 712*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 0])) 713*da0073e9SAndroid Build Coastguard Worker exp_i = self.index_tensor([ 714*da0073e9SAndroid Build Coastguard Worker [0, 2], 715*da0073e9SAndroid Build Coastguard Worker [0, 3], 716*da0073e9SAndroid Build Coastguard Worker [0, 4], 717*da0073e9SAndroid Build Coastguard Worker ], device=device) 718*da0073e9SAndroid Build Coastguard Worker exp_v = 
torch.empty([2, 0], dtype=dtype, device=device) 719*da0073e9SAndroid Build Coastguard Worker test_tensor(x, exp_i, exp_v) 720*da0073e9SAndroid Build Coastguard Worker 721*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 722*da0073e9SAndroid Build Coastguard Worker def test_contig_hybrid(self, device, dtype): 723*da0073e9SAndroid Build Coastguard Worker def test_tensor(x, exp_i, exp_v): 724*da0073e9SAndroid Build Coastguard Worker x = x.coalesce() 725*da0073e9SAndroid Build Coastguard Worker self.assertEqual(exp_i, x._indices()) 726*da0073e9SAndroid Build Coastguard Worker self.assertEqual(exp_v, x._values()) 727*da0073e9SAndroid Build Coastguard Worker 728*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 729*da0073e9SAndroid Build Coastguard Worker [1, 0, 35, 14, 39, 6, 71, 66, 40, 27], 730*da0073e9SAndroid Build Coastguard Worker [92, 31, 62, 50, 22, 65, 89, 74, 56, 34], 731*da0073e9SAndroid Build Coastguard Worker ], device=device) 732*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([ 733*da0073e9SAndroid Build Coastguard Worker [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], 734*da0073e9SAndroid Build Coastguard Worker [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], 735*da0073e9SAndroid Build Coastguard Worker ], dtype=dtype, device=device) 736*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([100, 100, 2])) 737*da0073e9SAndroid Build Coastguard Worker exp_i = self.index_tensor([ 738*da0073e9SAndroid Build Coastguard Worker [0, 1, 6, 14, 27, 35, 39, 40, 66, 71], 739*da0073e9SAndroid Build Coastguard Worker [31, 92, 65, 50, 34, 62, 22, 56, 74, 89], 740*da0073e9SAndroid Build Coastguard Worker ], device=device) 741*da0073e9SAndroid Build Coastguard Worker exp_v = torch.tensor([ 742*da0073e9SAndroid Build Coastguard Worker [2, 3], [1, 2], [6, 7], [4, 5], [10, 11], 743*da0073e9SAndroid Build Coastguard Worker [3, 4], [5, 6], [9, 10], [8, 9], [7, 8], 744*da0073e9SAndroid Build Coastguard Worker 
], dtype=dtype, device=device) 745*da0073e9SAndroid Build Coastguard Worker test_tensor(x, exp_i, exp_v) 746*da0073e9SAndroid Build Coastguard Worker 747*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 748*da0073e9SAndroid Build Coastguard Worker [2, 0, 2, 1], 749*da0073e9SAndroid Build Coastguard Worker [0, 0, 3, 0], 750*da0073e9SAndroid Build Coastguard Worker [1, 0, 4, 0], 751*da0073e9SAndroid Build Coastguard Worker ], device=device) 752*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([[3, 3, 3], [2, 2, 2], [4, 4, 4], [1, 1, 1]], dtype=dtype, device=device) 753*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 3])) 754*da0073e9SAndroid Build Coastguard Worker exp_i = self.index_tensor([ 755*da0073e9SAndroid Build Coastguard Worker [0, 1, 2, 2], 756*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 3], 757*da0073e9SAndroid Build Coastguard Worker [0, 0, 1, 4], 758*da0073e9SAndroid Build Coastguard Worker ], device=device) 759*da0073e9SAndroid Build Coastguard Worker exp_v = torch.tensor([[2, 2, 2], [1, 1, 1], [3, 3, 3], [4, 4, 4]], dtype=dtype, device=device) 760*da0073e9SAndroid Build Coastguard Worker test_tensor(x, exp_i, exp_v) 761*da0073e9SAndroid Build Coastguard Worker 762*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 763*da0073e9SAndroid Build Coastguard Worker [2, 0, 2, 1], 764*da0073e9SAndroid Build Coastguard Worker [0, 0, 3, 0], 765*da0073e9SAndroid Build Coastguard Worker [1, 0, 4, 0], 766*da0073e9SAndroid Build Coastguard Worker ], device=device) 767*da0073e9SAndroid Build Coastguard Worker v = torch.empty([4, 3, 0], dtype=dtype, device=device) 768*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 3, 0])) 769*da0073e9SAndroid Build Coastguard Worker exp_i = self.index_tensor([ 770*da0073e9SAndroid Build Coastguard Worker [0, 1, 2, 2], 771*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 3], 772*da0073e9SAndroid Build 
Coastguard Worker [0, 0, 1, 4], 773*da0073e9SAndroid Build Coastguard Worker ], device=device) 774*da0073e9SAndroid Build Coastguard Worker exp_v = torch.empty([4, 3, 0], dtype=dtype, device=device) 775*da0073e9SAndroid Build Coastguard Worker test_tensor(x, exp_i, exp_v) 776*da0073e9SAndroid Build Coastguard Worker 777*da0073e9SAndroid Build Coastguard Worker # Duplicate indices 778*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 779*da0073e9SAndroid Build Coastguard Worker [0, 0, 2, 0], 780*da0073e9SAndroid Build Coastguard Worker [0, 0, 3, 0], 781*da0073e9SAndroid Build Coastguard Worker [0, 0, 4, 0], 782*da0073e9SAndroid Build Coastguard Worker ], device=device) 783*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([[3, 2, 3], [2, 1, 1], [4, 3, 4], [1, 1, 1]], dtype=dtype, device=device) 784*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 3])) 785*da0073e9SAndroid Build Coastguard Worker exp_i = self.index_tensor([ 786*da0073e9SAndroid Build Coastguard Worker [0, 2], 787*da0073e9SAndroid Build Coastguard Worker [0, 3], 788*da0073e9SAndroid Build Coastguard Worker [0, 4], 789*da0073e9SAndroid Build Coastguard Worker ], device=device) 790*da0073e9SAndroid Build Coastguard Worker exp_v = torch.tensor([[6, 4, 5], [4, 3, 4]], dtype=dtype, device=device) 791*da0073e9SAndroid Build Coastguard Worker test_tensor(x, exp_i, exp_v) 792*da0073e9SAndroid Build Coastguard Worker 793*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 794*da0073e9SAndroid Build Coastguard Worker [0, 0, 2, 0], 795*da0073e9SAndroid Build Coastguard Worker [0, 0, 3, 0], 796*da0073e9SAndroid Build Coastguard Worker [0, 0, 4, 0], 797*da0073e9SAndroid Build Coastguard Worker ], device=device) 798*da0073e9SAndroid Build Coastguard Worker v = torch.empty([4, 3, 0], dtype=dtype, device=device) 799*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3, 4, 5, 3, 0])) 800*da0073e9SAndroid Build 
Coastguard Worker exp_i = self.index_tensor([ 801*da0073e9SAndroid Build Coastguard Worker [0, 2], 802*da0073e9SAndroid Build Coastguard Worker [0, 3], 803*da0073e9SAndroid Build Coastguard Worker [0, 4], 804*da0073e9SAndroid Build Coastguard Worker ], device=device) 805*da0073e9SAndroid Build Coastguard Worker exp_v = torch.empty([2, 3, 0], dtype=dtype, device=device) 806*da0073e9SAndroid Build Coastguard Worker test_tensor(x, exp_i, exp_v) 807*da0073e9SAndroid Build Coastguard Worker 808*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 809*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 810*da0073e9SAndroid Build Coastguard Worker def test_clone(self, device, dtype, coalesced): 811*da0073e9SAndroid Build Coastguard Worker def test_shape(sparse_dims, nnz, with_size): 812*da0073e9SAndroid Build Coastguard Worker x = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0] 813*da0073e9SAndroid Build Coastguard Worker if not coalesced: 814*da0073e9SAndroid Build Coastguard Worker self.assertFalse(x.is_coalesced()) 815*da0073e9SAndroid Build Coastguard Worker y = x.clone() 816*da0073e9SAndroid Build Coastguard Worker self.assertFalse(y.is_coalesced()) 817*da0073e9SAndroid Build Coastguard Worker x = x.coalesce() 818*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.is_coalesced()) 819*da0073e9SAndroid Build Coastguard Worker y = x.clone() 820*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.is_coalesced()) 821*da0073e9SAndroid Build Coastguard Worker 822*da0073e9SAndroid Build Coastguard Worker test_shape(4, 20, 5) 823*da0073e9SAndroid Build Coastguard Worker test_shape(3, 10, [100, 100, 100, 5, 5, 5, 0]) 824*da0073e9SAndroid Build Coastguard Worker test_shape(3, 0, [0, 0, 100, 5, 5, 5, 0]) 825*da0073e9SAndroid Build Coastguard Worker 826*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 827*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble, torch.bfloat16) 
@precisionOverride({torch.bfloat16: 2e-2})
def test_Sparse_to_Sparse_copy_(self, device, dtype, coalesced):
    """Exercise ``copy_`` from one sparse COO tensor into another.

    Covers value copying, keeping the destination dtype, rejecting a
    size-mismatched source and a dense source, and gradient flow back to the
    source tensor.
    """
    n_sparse = 3
    n_nonzero = 10
    shape = [2, 3, 4, 5]  # hybrid sparse

    def make(count):
        # one freshly generated sparse tensor with `count` nonzeros
        return self._gen_sparse(n_sparse, count, shape, dtype, device, coalesced)[0]

    dst = make(n_nonzero)
    src = make(n_nonzero + 10)

    # plain copy: the destination must end up holding the source values
    expected = src.to_dense()
    dst.copy_(src)
    self.assertEqual(expected, dst.to_dense())

    # the destination keeps its own dtype whatever dtype the source has;
    # note `src` is converted cumulatively (float16, then float64)
    dst = dst.to(torch.float32)
    for src_dtype in (torch.float16, torch.float64):
        src = src.to(src_dtype)
        dst_dtype_before = dst.dtype
        dst.copy_(src)
        self.assertEqual(dst_dtype_before, dst.dtype)

    # no broadcasting: a source with a different size is rejected
    with self.assertRaises(RuntimeError):
        dst.copy_(src.narrow_copy(0, 0, 1))

    # copying a dense tensor into a sparse one is rejected
    with self.assertRaises(RuntimeError):
        dst.copy_(torch.randn(5, 5))

    # autograd: the gradient reaches the source, not the destination
    dst = make(n_nonzero)
    src = make(n_nonzero + 10)
    src.requires_grad_(True)
    dst.copy_(src)
    out = dst * 2
    grad_seed = src.clone()
    out.backward(grad_seed)
    self.assertEqual((grad_seed * 2).to_dense(), src.grad.to_dense())
    self.assertEqual(None, dst.grad)

@coalescedonoff
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@dtypes(torch.double, torch.cdouble)
def test_Sparse_to_Sparse_copy_multi_gpu(self, device, dtype, coalesced):
    """Exercise sparse ``copy_`` when source and destination are on different devices."""
    n_sparse = 3
    n_nonzero = 10
    shape = [2, 3, 4, 5]  # hybrid sparse
    dst, _, _ = self._gen_sparse(n_sparse, n_nonzero, shape, dtype, device, coalesced)
    src, _, _ = self._gen_sparse(n_sparse, n_nonzero + 10, shape, dtype, device, coalesced)
    dst = dst.to('cuda:0')

    def check_copy_keeps_device(target, source):
        # copy_ must move the values over but leave `target` on its own device
        device_before = target.device
        target.copy_(source)
        self.assertEqual(source.to('cuda:0').to_dense(), target.to_dense())
        self.assertEqual(device_before, target.device)

    # source on another GPU first, then on the CPU
    for source_device in ('cuda:1', 'cpu'):
        check_copy_keeps_device(dst, src.to(source_device))

    # autograd across devices
    src = src.to('cuda:1')
    src.requires_grad_(True)
    dst.copy_(src)
    out = dst * 2
    grad_seed = src.clone().to('cuda:0')
    out.backward(grad_seed)
    self.assertEqual((grad_seed * 2).to_dense(), src.grad.to('cuda:0').to_dense())
    self.assertEqual(None, dst.grad)
@onlyCUDA
def test_cuda_empty(self, device):
    """Move empty sparse COO tensors CPU -> ``device`` -> CPU.

    Both moves must keep ``sparse_dim()`` and ``dense_dim()``. This is checked
    for several dtypes and for a shape with a zero-sized dimension.
    """
    def test_tensor(x):
        y = x.to(device)
        self.assertEqual(x.sparse_dim(), y.sparse_dim())
        self.assertEqual(x.dense_dim(), y.dense_dim())
        x = y.cpu()
        self.assertEqual(y.sparse_dim(), x.sparse_dim())
        self.assertEqual(y.dense_dim(), x.dense_dim())

    # each dtype is exercised exactly once
    for empty_dtype in (torch.float32, torch.float16, torch.float64):
        test_tensor(torch.sparse_coo_tensor((2, 3, 4), dtype=empty_dtype))

    # zero-sized trailing dimension
    test_tensor(torch.sparse_coo_tensor((2, 3, 4, 0), dtype=torch.float32))

@coalescedonoff
@dtypes(torch.double, torch.cdouble)
def test_transpose(self, device, dtype, coalesced):
    """Compare sparse ``transpose_``/``transpose`` against the dense result."""
    def test_shape(sparse_dims, nnz, with_size):
        x = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
        y = self.safeToDense(x)

        for i, j in itertools.combinations(range(4), 2):
            # in-place transpose on the sparse side
            x = x.transpose_(i, j)
            y = y.transpose(i, j)
            self.assertEqual(self.safeToDense(x), y)

            # out-of-place transpose on the sparse side
            x = x.transpose(i, j)
            y = y.transpose(i, j)
            self.assertEqual(self.safeToDense(x), y)

    test_shape(4, 6, 3)
    test_shape(4, 3, [7, 7, 7, 3, 3, 3, 0])
    test_shape(4, 0, [0, 0, 7, 3, 3, 3, 0])

@coalescedonoff
@dtypes(torch.double, torch.cdouble)
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
@gradcheck_semantics()
def test_permute(self, device, dtype, coalesced, gradcheck):
    """Check sparse ``permute``: error cases, values, coalesced flag and gradients."""
    # trivial checks
    s = torch.rand(3, 3, 3, device=device, dtype=dtype).to_sparse()
    with self.assertRaisesRegex(RuntimeError, "does not match the length"):
        s.permute(dims=(1, 0))
    with self.assertRaisesRegex(RuntimeError, "duplicate dims"):
        s.permute(dims=(1, 1, 1))
    # Calling permute on a sparse tensor with an empty tuple used to segfault,
    # see https://github.com/pytorch/pytorch/issues/116325
    x = torch.rand((), device=device, dtype=dtype).to_sparse()
    x.permute(())
    self.assertEqual(len(x.values()), 1)

    def test_shape(sparse_dims, nnz, with_size):
        ndim = len(with_size)
        valid_sparse_dims = torch.arange(-ndim, -ndim + sparse_dims)
        valid_dense_dims = torch.arange(-ndim + sparse_dims, 0)

        for dims in itertools.permutations(range(-ndim, 0)):
            s = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
            d = self.safeToDense(s)

            dims_sparse, _ = torch.tensor(dims[:sparse_dims]).sort()
            dims_dense, _ = torch.tensor(dims[sparse_dims:]).sort()

            if (valid_sparse_dims == dims_sparse).all() and (valid_dense_dims == dims_dense).all():
                # if valid permutation, test for correctness
                s_permuted = s.permute(dims)
                self.assertEqual(s_permuted, d.permute(dims))

                # if s is coalesced, and perm does not touch 0-dim,
                # the result has to be coalesced as well
                # NOTE(review): `dims` is drawn from range(-ndim, 0), so every
                # entry is negative and `dims[0] == 0` never holds here; only
                # the `else` branch runs. Confirm whether the intended check
                # is `dims[0] == -ndim`.
                if dims[0] == 0:
                    self.assertEqual(s_permuted.is_coalesced(), s.is_coalesced())
                else:
                    self.assertFalse(s_permuted.is_coalesced())

                gradcheck(lambda t: t.permute(dims).to_dense(masked_grad=gradcheck.masked), s.requires_grad_())
            else:
                # otherwise check if exception is thrown
                fail_message = "transpositions between sparse and dense dimensions are not allowed"
                with self.assertRaisesRegex(RuntimeError, fail_message):
                    s.permute(dims)

    test_shape(2, 3, [2, 3, 4, 5])
    test_shape(2, 3, [2, 2, 0])
    # if nnz=0, it is not true that t == t.to_dense().to_sparse()
    # unless t.sparse_dim == t.dim (i.e. t is not hybrid)
    test_shape(3, 0, [0, 0, 2])

@coalescedonoff
@onlyCPU
@dtypes(torch.double)
def test_coalesce_transpose_mm(self, device, dtype, coalesced):
    """Check the coalesced flag after ``t()`` and ``mm`` with the transposed tensor."""
    def test_shape(di, dj, dk, nnz):
        x, _, _ = self._gen_sparse(2, nnz, [dj, di], dtype, device, coalesced)
        y = torch.randn(dj, dk, dtype=dtype, device=device)

        x_coalesced = x.coalesce()
        self.assertTrue(x_coalesced.is_coalesced())

        x_coalesced_t = x_coalesced.t()
        # Transpose is `coalesced`-preserving if the indices tensor is empty.
        self.assertEqual(x_coalesced_t.is_coalesced(), di * nnz == 0)

        res = torch.mm(x_coalesced_t, y)
        expected = torch.mm(self.safeToDense(x_coalesced_t), y)
        self.assertEqual(res, expected)

    test_shape(10, 20, 30, 20)
    test_shape(0, 20, 30, 0)
    test_shape(10, 0, 30, 0)
    test_shape(10, 20, 0, 0)
    test_shape(10, 20, 0, 20)

@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1166")
@dtypes(torch.double, torch.cdouble)
def test_t_empty(self, device, dtype):
    """Check in-place and out-of-place ``t`` on empty 2-D sparse tensors."""
    def test_in_place(x):
        shape_original = x.shape
        x.t_()
        self.assertEqual(torch.Size([shape_original[1], shape_original[0]]), x.size())
        self.assertEqual(0, x._indices().numel())
        self.assertEqual(0, x._values().numel())
        self.assertEqual(x.sparse_dim(), 2)
        self.assertEqual(x.dense_dim(), 0)

    def test_not_in_place(x):
        shape_original = x.shape
        y = x.t()
        self.assertEqual(torch.Size([shape_original[1], shape_original[0]]), y.size())
        self.assertEqual(0, y._indices().numel())
        self.assertEqual(0, y._values().numel())
        # check the transposed result, not the (unchanged) input
        self.assertEqual(y.sparse_dim(), 2)
        self.assertEqual(y.dense_dim(), 0)

    x = self.sparse_empty(2, 3, dtype=dtype, device=device)
    test_in_place(x)
    test_not_in_place(x)

    x = self.sparse_empty(2, 0, dtype=dtype, device=device)
    test_in_place(x)
    test_not_in_place(x)

@coalescedonoff
@dtypes(torch.double, torch.cdouble)
def test_add_zeros(self, device, dtype, coalesced):
    """Adding an all-zero sparse tensor on either side leaves the operand unchanged."""
    def test_shape(sparse_dims, nnz, sizes):
        x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
        zeros = torch.sparse_coo_tensor(sizes, device=x.device)
        r1 = zeros + x
        r2 = x + zeros
        self.assertEqual(r1, x)
        self.assertEqual(r2, x)

    test_shape(1, 20, [1])
    test_shape(4, 20, [3, 17, 19, 5])
    test_shape(2, 20, [3, 17, 19, 5])
    test_shape(2, 20, [3, 17, 19, 0])

@dtypes(torch.double, torch.cdouble)
def test_add_sub_nnz(self, device, dtype):
    """In-place add/sub of a sparse tensor with itself must not grow nnz."""
    # nnz should not grow unbounded (gh-34964)
    x = torch.randn(10, dtype=dtype, device=device).to_sparse()
    x.add_(x)
    x.add_(x)
    self.assertLessEqual(x._nnz(), 10)

    x.sub_(2 * x)
    x.sub_(2 * x)
    self.assertLessEqual(x._nnz(), 10)

@coalescedonoff
@dtypes(torch.double, torch.cdouble)
def test_cat(self, device, dtype, coalesced):
    """Compare sparse ``torch.cat`` against the dense result, plus error cases."""
    # shapes: list of tuples (sparse_dims, nnz, sizes)
    def test_shapes(shapes, dim, fail_message=None):
        inputs = [self._gen_sparse(shape[0], shape[1], shape[2], dtype, device, coalesced)[0]
                  for shape in shapes]
        if fail_message:
            with self.assertRaisesRegex(RuntimeError, fail_message):
                torch.cat(inputs, dim)
        else:
            result = torch.cat(inputs, dim)
            dense_result = torch.cat([t.to_dense() for t in inputs], dim)
            self.assertEqual(dense_result, result.to_dense())

    test_shapes(
        [(3, 10, [2, 3, 4]), (3, 10, [2, 1, 4]), (3, 10, [2, 4, 4])], 1)

    # mismatched sizes
    test_shapes([(3, 10, [2, 3, 4]), (3, 10, [2, 1, 4])], 0,
                "All tensors must have the same shape: \\[2, 3, 4].*\\[2, 1, 4]")
    # hybrid sparse/dense
    test_shapes(
        [(2, 10, [2, 3, 4]), (2, 10, [2, 1, 4]), (2, 10, [2, 4, 4])], 1)
    # cat along dense dim
    test_shapes([(2, 10, [2, 3, 4]), (2, 10, [2, 3, 7])], 2)
    test_shapes([(1, 10, [2, 3, 4]), (1, 10, [2, 3, 4])], 1)
    test_shapes([(1, 10, [2, 3, 4]), (1, 10, [2, 3, 4])], 2)
    # mismatched dimensions
    test_shapes([(2, 10, [2, 3, 4]), (3, 10, [2, 3, 4])], 0,
                "All tensors must have the same.*2, 1, but tensor at position 1 has 3, 0.")
    # wrapped dimension
    test_shapes(
        [(3, 10, [2, 3, 4]), (3, 10, [2, 1, 4]), (3, 10, [2, 4, 4])], -2)

    # sparse with dense
    sp = self._gen_sparse(3, 10, [2, 3, 4], dtype, device, coalesced)[0]
    dn = sp.to_dense()
    with self.assertRaisesRegex(RuntimeError,
                                "Concatenating sparse tensors, but a dense tensor was found at position 1."):
        torch.cat((sp, dn))

@coalescedonoff
@dtypes(torch.double, torch.cdouble)
def test_unsqueeze(self, device, dtype, coalesced):
    """Compare sparse ``torch.unsqueeze`` against the dense result, plus bounds checks."""
    def test_shape(sparse_dims, nnz, sizes, unsqueeze_dim, fail_message=None):
        x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
        if fail_message:
            with self.assertRaisesRegex(IndexError, fail_message):
                torch.unsqueeze(x, unsqueeze_dim)
        else:
            result = torch.unsqueeze(x, unsqueeze_dim)
            dense_result = torch.unsqueeze(x.to_dense(), unsqueeze_dim)
            self.assertEqual(dense_result, result.to_dense())

    # basic case
    test_shape(3, 10, [5, 7, 11], 0)

    # hybrid sparse/dense, unsqueeze along sparse dim
    test_shape(3, 10, [5, 7, 11, 13, 17], 0)
    test_shape(3, 10, [5, 7, 11, 13, 17], 3)

    # unsqueeze along dense dimensions
    test_shape(3, 10, [5, 7, 11, 13, 17], 4)
    test_shape(3, 10, [5, 7, 11, 13, 17], 5)

    # wrapped dimensions
    test_shape(3, 10, [5, 7, 11, 13, 17], -1)
    test_shape(3, 10, [5, 7, 11, 13, 17], -6)

    # bounds
    test_shape(3, 10, [5, 7, 11, 13, 17], -7, "Dimension out of range")
    test_shape(3, 10, [5, 7, 11, 13, 17], 6, "Dimension out of range")

@coalescedonoff
@dtypes(torch.double, torch.cdouble)
def test_select(self, device, dtype, coalesced):
    """Compare sparse ``torch.select`` against the dense result, plus an out-of-range index."""
    def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=None):
        x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
        if fail_message:
            with self.assertRaisesRegex(IndexError, fail_message):
                torch.select(x, select_dim, select_index)
        else:
            result = torch.select(x, select_dim, select_index)
            if result.is_sparse:
                result = result.to_dense()
            dense_result = torch.select(x.to_dense(), select_dim, select_index)
            self.assertEqual(dense_result, result)

    sizes = [5, 7, 11, 13, 17]
    # hybrid sparse/dense, select sparse dim, result is dense
    for i in range(sizes[0]):
        test_shape(1, 10, sizes, 0, i)
    test_shape(1, 10, sizes, 0, sizes[0] + 1, r'select[(][)][:] index \d out of range.*')

    # hybrid sparse/dense, select sparse dim, result is sparse
    for d in range(3):
        for i in range(sizes[d]):
            test_shape(3, 10, sizes, d, i)

    # hybrid sparse/dense, select dense dim, result is sparse
    for d in range(1, 3):
        for i in range(sizes[d]):
            test_shape(1, 10, sizes, d, i)

@dtypes(*integral_types())
def test_select_no_type_promotion(self, device, dtype):
    """Indexing an integral sparse tensor must keep its dtype."""
    # see https://github.com/pytorch/pytorch/issues/82150
    idx = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 1]])
    val = torch.ones(6, dtype=dtype)
    s = torch.sparse_coo_tensor(idx, val, size=(3, 3))

    for t in (s, s * torch.tensor(0, dtype=dtype)):
        # empty checks
        self.assertEqual(t.dtype, t[2].dtype)
        self.assertEqual(t.dtype, t[0, 1].dtype)
        # sum should not promote
        self.assertEqual(t.dtype, t[0, 0].dtype)
        self.assertEqual(t.dtype, t[1, 1].dtype)

@coalescedonoff
@dtypes(torch.double, torch.cdouble)
def test_index_select(self, device, dtype, coalesced):
    """Compare sparse ``torch.index_select`` against the dense result on every dim."""
    def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=None):
        if isinstance(select_index, int):
            select_index = [select_index]
        if isinstance(select_index, list):
            select_index = torch.tensor(select_index, device=device, dtype=torch.long)
        x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced)
        if fail_message:
            with self.assertRaisesRegex(IndexError, fail_message):
                torch.index_select(x, select_dim, select_index)
        else:
            result = torch.index_select(x, select_dim, select_index)
            if result.is_sparse:
                result = result.to_dense()
            dense_result = torch.index_select(x.to_dense(), select_dim, select_index)
            self.assertEqual(dense_result, result)

    sizes = [5, 7, 11, 13, 17]
    for d in range(len(sizes)):
        for index in [0, sizes[d] - 1, [0, sizes[d] // 2, sizes[d] - 1]]:
            test_shape(1, 10, sizes, d, index)
            test_shape(len(sizes) // 2, 10, sizes, d, index)
            test_shape(len(sizes), 10, sizes, d, index)

def _test_index_select_exhaustive_index(self, sizes, dims, device, dtype, coalesced):
    """Compare sparse vs dense ``index_select`` over all index tuples of length ``sizes[d]``.

    Each dim ``d`` in ``dims`` is checked on a fully populated sparse tensor and on
    one with only 2 generated nonzeros.
    """
    t = make_tensor(sizes, dtype=dtype, device=device)
    t_sparse = t.to_sparse().coalesce() if coalesced else t.to_sparse()
    t_small_sparse, _, _ = self._gen_sparse(len(sizes), 2, sizes, dtype, device, coalesced)
    t_small = t_small_sparse.to_dense()
    for d in dims:
        # NOTE: indices are negative
        idx_dim_d_range = list(range(-sizes[d], 0))
        for idx_len in range(sizes[d], sizes[d] + 1):
            # creates all possible valid indices into dim d of length idx_len
            for idx in itertools.product(*itertools.repeat(idx_dim_d_range, idx_len)):
                t_idx = torch.tensor(idx, dtype=torch.long, device=device)

                # NOTE: index_select for dense does not support negative indices,
                # hence + sizes[d]. See https://github.com/pytorch/pytorch/issues/76347

                # tests the nnz > sizes[d] branch
                dense_result = t.index_select(d, t_idx + sizes[d])
                sparse_result = t_sparse.index_select(d, t_idx)
                self.assertEqual(dense_result, sparse_result)

                # tests the nnz <= sizes[d] branch
                small_dense_result = t_small.index_select(d, t_idx + sizes[d])
                small_sparse_result = t_small_sparse.index_select(d, t_idx)
                self.assertEqual(small_dense_result, small_sparse_result)

@coalescedonoff
@dtypes(torch.double, torch.cdouble)
def test_index_select_exhaustive_index_small(self, device, dtype, coalesced):
    """Exhaustive index_select check on a small tensor, all dims."""
    # will trigger brute-force algo
    self._test_index_select_exhaustive_index((3, 3, 4), range(3), device, dtype, coalesced)

@coalescedonoff
@dtypes(torch.double, torch.cdouble)
def test_index_select_exhaustive_index_large(self, device, dtype, coalesced):
    """Exhaustive index_select check on a larger tensor, dims 2 and 3 only."""
    # will trigger more sophisticated algos
    self._test_index_select_exhaustive_index((100, 50, 3, 3), (2, 3), device, dtype, coalesced)

@coalescedonoff
@dtypes(torch.double, torch.cdouble)
def test_index_select_empty_and_non_contiguous_index(self, device, dtype, coalesced):
    """Sparse ``index_select`` with an empty index and with a non-contiguous index."""
    # empty index
    idx_empty = torch.tensor([], dtype=torch.long, device=device)
    t = make_tensor((5, 5), dtype=dtype, device=device)
    res_dense = t.index_select(0, idx_empty)
    res_sparse = t.to_sparse().index_select(0, idx_empty)
    self.assertEqual(res_dense, res_sparse)

    # non-contiguous index
    idx = torch.randint(low=0, high=5, size=(10, 2), device=device)[:, 0]

    def run_test(sizes):
        # case nnz > size[d]
        t = make_tensor(sizes, dtype=dtype, device=device)
        res_dense = t.index_select(0, idx)
        res_sparse = t.to_sparse().index_select(0, idx)
        self.assertEqual(res_dense, res_sparse)

        # case nnz <= size[d]
        t_small_sparse, _, _ = self._gen_sparse(len(sizes), 2, sizes, dtype, device, coalesced)
        res_sparse = t_small_sparse.index_select(0, idx)
        res_dense = t_small_sparse.to_dense().index_select(0, idx)
        self.assertEqual(res_dense, res_sparse)

    # brute-force
    run_test((10, 10))
    # more sophisticated algos
    run_test((10, 100, 100))

@onlyCPU
@coalescedonoff
@dtypes(torch.double, torch.cdouble)
def test_index_select_parallelization(self, device, dtype, coalesced):
    """
    Test with sizes that will trigger parallelization (i.e. with sizes
    that are >= at::internal::GRAIN_SIZE)
    """
    def run_test(nnz, size):
        t_sparse, _, _ = self._gen_sparse(1, nnz, (size,), dtype, device, coalesced)
        t_dense = t_sparse.to_dense()

        # idx_small to (sort) and (binary) search into t_sparse
        idx_small = torch.randint(size, (nnz // 2,), device=device)
        # idx_large to (sort) and (binary) search into idx_large
        # NOTE: when coalesced=True, the (binary) search will be
        # done over t_sparse anyway, as it is already sorted.
        idx_large = torch.randint(size, (nnz * 2,), device=device)
        for idx in (idx_small, idx_large):
            res_dense = t_dense.index_select(0, idx)
            res_sparse = t_sparse.index_select(0, idx)
            self.assertEqual(res_dense, res_sparse)

    # NOTE: GRAIN_SIZE = 32768
    # case nnz <= size[d]
    tlen = 70000  # > 2 * GRAIN_SIZE
    run_test(tlen, tlen)

    # case nnz > size[d]
    run_test(tlen, tlen // 2)
1330*da0073e9SAndroid Build Coastguard Worker @onlyCPU 1331*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 1332*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 1333*da0073e9SAndroid Build Coastguard Worker def test_mm(self, device, dtype, coalesced): 1334*da0073e9SAndroid Build Coastguard Worker def test_shape(di, dj, dk, nnz): 1335*da0073e9SAndroid Build Coastguard Worker x, _, _ = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced) 1336*da0073e9SAndroid Build Coastguard Worker t = torch.randn(di, dk, dtype=dtype, device=device) 1337*da0073e9SAndroid Build Coastguard Worker y = torch.randn(dj, dk, dtype=dtype, device=device) 1338*da0073e9SAndroid Build Coastguard Worker alpha = random.random() 1339*da0073e9SAndroid Build Coastguard Worker beta = random.random() 1340*da0073e9SAndroid Build Coastguard Worker 1341*da0073e9SAndroid Build Coastguard Worker res = torch.addmm(t, x, y, beta=beta, alpha=alpha) 1342*da0073e9SAndroid Build Coastguard Worker expected = torch.addmm(t, self.safeToDense(x), y, beta=beta, alpha=alpha) 1343*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 1344*da0073e9SAndroid Build Coastguard Worker 1345*da0073e9SAndroid Build Coastguard Worker res = torch.addmm(t, x, y) 1346*da0073e9SAndroid Build Coastguard Worker expected = torch.addmm(t, self.safeToDense(x), y) 1347*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 1348*da0073e9SAndroid Build Coastguard Worker 1349*da0073e9SAndroid Build Coastguard Worker res = torch.mm(x, y) 1350*da0073e9SAndroid Build Coastguard Worker expected = torch.mm(self.safeToDense(x), y) 1351*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 1352*da0073e9SAndroid Build Coastguard Worker 1353*da0073e9SAndroid Build Coastguard Worker test_shape(10, 100, 100, 20) 1354*da0073e9SAndroid Build Coastguard Worker test_shape(100, 1000, 200, 20) 1355*da0073e9SAndroid Build Coastguard Worker test_shape(64, 
10000, 300, 20) 1356*da0073e9SAndroid Build Coastguard Worker test_shape(0, 100, 100, 0) 1357*da0073e9SAndroid Build Coastguard Worker test_shape(10, 0, 100, 0) 1358*da0073e9SAndroid Build Coastguard Worker test_shape(10, 100, 0, 0) 1359*da0073e9SAndroid Build Coastguard Worker test_shape(10, 100, 0, 20) 1360*da0073e9SAndroid Build Coastguard Worker 1361*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 1362*da0073e9SAndroid Build Coastguard Worker IS_WINDOWS and TEST_CUDA, 1363*da0073e9SAndroid Build Coastguard Worker "bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1" 1364*da0073e9SAndroid Build Coastguard Worker ) 1365*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 1366*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1367*da0073e9SAndroid Build Coastguard Worker def test_bmm(self, device, dtype, coalesced): 1368*da0073e9SAndroid Build Coastguard Worker def test_shape(num_mats, dim_i, dim_j, dim_k, nnz): 1369*da0073e9SAndroid Build Coastguard Worker a_list = [] 1370*da0073e9SAndroid Build Coastguard Worker b_list = [] 1371*da0073e9SAndroid Build Coastguard Worker for mat_idx in range(num_mats): 1372*da0073e9SAndroid Build Coastguard Worker a_mat = self._gen_sparse(2, nnz, [dim_i, dim_j], dtype, device, coalesced)[0] 1373*da0073e9SAndroid Build Coastguard Worker b_mat = torch.randn([dim_j, dim_k], dtype=dtype, device=device) 1374*da0073e9SAndroid Build Coastguard Worker a_list.append(a_mat) 1375*da0073e9SAndroid Build Coastguard Worker b_list.append(b_mat) 1376*da0073e9SAndroid Build Coastguard Worker 1377*da0073e9SAndroid Build Coastguard Worker a = torch.stack(a_list) 1378*da0073e9SAndroid Build Coastguard Worker b = torch.stack(b_list) 1379*da0073e9SAndroid Build Coastguard Worker ab = a.bmm(b) 1380*da0073e9SAndroid Build Coastguard Worker 1381*da0073e9SAndroid Build Coastguard Worker # Compare each matrix against result from mm() 1382*da0073e9SAndroid Build Coastguard Worker for mat_idx in 
range(num_mats): 1383*da0073e9SAndroid Build Coastguard Worker a_mat = a_list[mat_idx] 1384*da0073e9SAndroid Build Coastguard Worker b_mat = b_list[mat_idx] 1385*da0073e9SAndroid Build Coastguard Worker ab_mat_bmm = ab[mat_idx] 1386*da0073e9SAndroid Build Coastguard Worker ab_mat_mm = a_mat.mm(b_mat) 1387*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ab_mat_bmm, ab_mat_mm) 1388*da0073e9SAndroid Build Coastguard Worker 1389*da0073e9SAndroid Build Coastguard Worker test_shape(10, 10, 100, 99, 20) 1390*da0073e9SAndroid Build Coastguard Worker test_shape(10, 100, 1000, 200, 20) 1391*da0073e9SAndroid Build Coastguard Worker test_shape(10, 64, 10000, 300, 20) 1392*da0073e9SAndroid Build Coastguard Worker test_shape(10, 0, 100, 99, 0) 1393*da0073e9SAndroid Build Coastguard Worker test_shape(10, 10, 0, 100, 0) 1394*da0073e9SAndroid Build Coastguard Worker test_shape(10, 10, 100, 0, 0) 1395*da0073e9SAndroid Build Coastguard Worker test_shape(10, 10, 100, 0, 20) 1396*da0073e9SAndroid Build Coastguard Worker test_shape(10, 10, 100, 0, 20) 1397*da0073e9SAndroid Build Coastguard Worker 1398*da0073e9SAndroid Build Coastguard Worker a = torch.rand([10, 23, 32], dtype=dtype, device=device) 1399*da0073e9SAndroid Build Coastguard Worker a[3] = torch.zeros(23, 32, dtype=dtype, device=device) 1400*da0073e9SAndroid Build Coastguard Worker a[6] = torch.zeros(23, 32, dtype=dtype, device=device) 1401*da0073e9SAndroid Build Coastguard Worker a = a.to_sparse() 1402*da0073e9SAndroid Build Coastguard Worker b = torch.rand([10, 32, 10], dtype=dtype, device=device) 1403*da0073e9SAndroid Build Coastguard Worker b[4] = torch.zeros(32, 10, dtype=dtype, device=device) 1404*da0073e9SAndroid Build Coastguard Worker b[6] = torch.zeros(32, 10, dtype=dtype, device=device) 1405*da0073e9SAndroid Build Coastguard Worker ab = a.bmm(b) 1406*da0073e9SAndroid Build Coastguard Worker for mat_idx in range(ab.size(0)): 1407*da0073e9SAndroid Build Coastguard Worker ab_mat = ab[mat_idx] 
1408*da0073e9SAndroid Build Coastguard Worker ab_mat_check = a[mat_idx].mm(b[mat_idx]) 1409*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ab_mat, ab_mat_check) 1410*da0073e9SAndroid Build Coastguard Worker 1411*da0073e9SAndroid Build Coastguard Worker ab_traspose_check = b.transpose(1, 2).to_sparse().bmm( 1412*da0073e9SAndroid Build Coastguard Worker a.transpose(1, 2).to_dense() 1413*da0073e9SAndroid Build Coastguard Worker ).transpose(1, 2) 1414*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ab, ab_traspose_check) 1415*da0073e9SAndroid Build Coastguard Worker 1416*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1417*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 1418*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1419*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 1420*da0073e9SAndroid Build Coastguard Worker IS_WINDOWS, 1421*da0073e9SAndroid Build Coastguard Worker "bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1" 1422*da0073e9SAndroid Build Coastguard Worker ) 1423*da0073e9SAndroid Build Coastguard Worker def test_bmm_deterministic(self, device, dtype, coalesced): 1424*da0073e9SAndroid Build Coastguard Worker def test_shape(num_mats, dim_i, dim_j, dim_k, nnz): 1425*da0073e9SAndroid Build Coastguard Worker a_list = [] 1426*da0073e9SAndroid Build Coastguard Worker b_list = [] 1427*da0073e9SAndroid Build Coastguard Worker for mat_idx in range(num_mats): 1428*da0073e9SAndroid Build Coastguard Worker a_list.append(self._gen_sparse(2, nnz, [dim_i, dim_j], dtype, device, coalesced)[0]) 1429*da0073e9SAndroid Build Coastguard Worker b_list.append(torch.randn([dim_j, dim_k], dtype=dtype, device=device)) 1430*da0073e9SAndroid Build Coastguard Worker 1431*da0073e9SAndroid Build Coastguard Worker a = torch.stack(a_list).cuda() 1432*da0073e9SAndroid Build Coastguard Worker b = torch.stack(b_list).cuda() 1433*da0073e9SAndroid Build Coastguard Worker with 
DeterministicGuard(torch.are_deterministic_algorithms_enabled()): 1434*da0073e9SAndroid Build Coastguard Worker torch.use_deterministic_algorithms(False) 1435*da0073e9SAndroid Build Coastguard Worker ab_nondeterministic = torch.bmm(a, b) 1436*da0073e9SAndroid Build Coastguard Worker torch.use_deterministic_algorithms(True) 1437*da0073e9SAndroid Build Coastguard Worker ab_deterministic = torch.bmm(a, b) 1438*da0073e9SAndroid Build Coastguard Worker diff_abs = (ab_deterministic - ab_nondeterministic).abs() 1439*da0073e9SAndroid Build Coastguard Worker diff_rel = diff_abs / ab_deterministic.abs() 1440*da0073e9SAndroid Build Coastguard Worker diff_rel[torch.isnan(diff_rel)] = 0 1441*da0073e9SAndroid Build Coastguard Worker 1442*da0073e9SAndroid Build Coastguard Worker # deterministic and non-deterministic results should either be 1443*da0073e9SAndroid Build Coastguard Worker # equal or within a small relative difference 1444*da0073e9SAndroid Build Coastguard Worker equal_abs_or_rel = diff_abs.eq(0).logical_or(diff_rel.lt(0.001)) 1445*da0073e9SAndroid Build Coastguard Worker self.assertTrue(equal_abs_or_rel.all()) 1446*da0073e9SAndroid Build Coastguard Worker 1447*da0073e9SAndroid Build Coastguard Worker test_shape(10, 10, 100, 99, 20) 1448*da0073e9SAndroid Build Coastguard Worker test_shape(10, 100, 1000, 200, 20) 1449*da0073e9SAndroid Build Coastguard Worker test_shape(10, 64, 10000, 300, 20) 1450*da0073e9SAndroid Build Coastguard Worker test_shape(10, 0, 100, 99, 0) 1451*da0073e9SAndroid Build Coastguard Worker test_shape(10, 10, 0, 100, 0) 1452*da0073e9SAndroid Build Coastguard Worker test_shape(10, 10, 100, 0, 0) 1453*da0073e9SAndroid Build Coastguard Worker test_shape(10, 10, 100, 0, 20) 1454*da0073e9SAndroid Build Coastguard Worker test_shape(10, 10, 100, 0, 20) 1455*da0073e9SAndroid Build Coastguard Worker 1456*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1457*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 1458*da0073e9SAndroid Build Coastguard 
Worker not IS_WINDOWS or not TEST_WITH_ROCM, 1459*da0073e9SAndroid Build Coastguard Worker "this test ensures bmm sparse-dense CUDA gives an error when run on Windows with CUDA < 11.0" 1460*da0073e9SAndroid Build Coastguard Worker ) 1461*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1462*da0073e9SAndroid Build Coastguard Worker def test_bmm_windows_error(self, device, dtype): 1463*da0073e9SAndroid Build Coastguard Worker a = torch.rand(2, 2, 2, dtype=dtype).to_sparse().cuda() 1464*da0073e9SAndroid Build Coastguard Worker b = torch.rand(2, 2, 2, dtype=dtype).cuda() 1465*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1466*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1467*da0073e9SAndroid Build Coastguard Worker "bmm sparse-dense CUDA is not supported on Windows with cuda before 11.0"): 1468*da0073e9SAndroid Build Coastguard Worker ab = a.bmm(b) 1469*da0073e9SAndroid Build Coastguard Worker 1470*da0073e9SAndroid Build Coastguard Worker @onlyCPU 1471*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 1472*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 1473*da0073e9SAndroid Build Coastguard Worker def test_saddmm(self, device, dtype, coalesced): 1474*da0073e9SAndroid Build Coastguard Worker def test_shape(di, dj, dk, nnz): 1475*da0073e9SAndroid Build Coastguard Worker x = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced)[0] 1476*da0073e9SAndroid Build Coastguard Worker t = self._gen_sparse(2, nnz, [di, dk], dtype, device, coalesced)[0] 1477*da0073e9SAndroid Build Coastguard Worker y = torch.randn(dj, dk, dtype=dtype, device=device) 1478*da0073e9SAndroid Build Coastguard Worker alpha = random.random() 1479*da0073e9SAndroid Build Coastguard Worker beta = random.random() 1480*da0073e9SAndroid Build Coastguard Worker 1481*da0073e9SAndroid Build Coastguard Worker res = torch.saddmm(t, x, y, beta=beta, alpha=alpha) 1482*da0073e9SAndroid Build Coastguard Worker expected = 
torch.addmm(self.safeToDense(t), self.safeToDense(x), y, beta=beta, alpha=alpha) 1483*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(res), expected) 1484*da0073e9SAndroid Build Coastguard Worker 1485*da0073e9SAndroid Build Coastguard Worker res = torch.saddmm(t, x, y) 1486*da0073e9SAndroid Build Coastguard Worker expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y) 1487*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(res), expected) 1488*da0073e9SAndroid Build Coastguard Worker 1489*da0073e9SAndroid Build Coastguard Worker res = torch.smm(x, y) 1490*da0073e9SAndroid Build Coastguard Worker expected = torch.mm(self.safeToDense(x), y) 1491*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(res), expected) 1492*da0073e9SAndroid Build Coastguard Worker 1493*da0073e9SAndroid Build Coastguard Worker test_shape(7, 5, 3, 20) 1494*da0073e9SAndroid Build Coastguard Worker test_shape(1000, 100, 100, 20) 1495*da0073e9SAndroid Build Coastguard Worker test_shape(3000, 64, 300, 20) 1496*da0073e9SAndroid Build Coastguard Worker test_shape(0, 100, 100, 0) 1497*da0073e9SAndroid Build Coastguard Worker test_shape(1000, 0, 100, 0) 1498*da0073e9SAndroid Build Coastguard Worker test_shape(1000, 100, 0, 0) 1499*da0073e9SAndroid Build Coastguard Worker 1500*da0073e9SAndroid Build Coastguard Worker @onlyCPU 1501*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 1502*da0073e9SAndroid Build Coastguard Worker # adding a graph break before self.assertFalse(weight._indices().is_contiguous()) 1503*da0073e9SAndroid Build Coastguard Worker # makes the test pass so some existent sparse related bug 1504*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("skip") 1505*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 1506*da0073e9SAndroid Build Coastguard Worker def test_sspaddmm(self, device, dtype, coalesced): 1507*da0073e9SAndroid Build Coastguard Worker 
1508*da0073e9SAndroid Build Coastguard Worker def test_shape(di, dj, dk, nnz): 1509*da0073e9SAndroid Build Coastguard Worker x = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced)[0] 1510*da0073e9SAndroid Build Coastguard Worker t = self._gen_sparse(2, nnz, [di, dk], dtype, device, coalesced)[0] 1511*da0073e9SAndroid Build Coastguard Worker y = torch.randn(dj, dk, dtype=dtype, device=device) 1512*da0073e9SAndroid Build Coastguard Worker alpha = random.random() 1513*da0073e9SAndroid Build Coastguard Worker beta = random.random() 1514*da0073e9SAndroid Build Coastguard Worker 1515*da0073e9SAndroid Build Coastguard Worker res = t.sspaddmm(x, y, beta=beta, alpha=alpha) 1516*da0073e9SAndroid Build Coastguard Worker expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y, beta=beta, alpha=alpha) 1517*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(res), expected) 1518*da0073e9SAndroid Build Coastguard Worker 1519*da0073e9SAndroid Build Coastguard Worker res = t.sspaddmm(x, y) 1520*da0073e9SAndroid Build Coastguard Worker expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y) 1521*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(res), expected) 1522*da0073e9SAndroid Build Coastguard Worker 1523*da0073e9SAndroid Build Coastguard Worker test_shape(7, 5, 3, 20) 1524*da0073e9SAndroid Build Coastguard Worker test_shape(1000, 100, 100, 20) 1525*da0073e9SAndroid Build Coastguard Worker test_shape(3000, 64, 300, 20) 1526*da0073e9SAndroid Build Coastguard Worker test_shape(0, 100, 100, 0) 1527*da0073e9SAndroid Build Coastguard Worker test_shape(1000, 0, 100, 0) 1528*da0073e9SAndroid Build Coastguard Worker test_shape(1000, 100, 0, 0) 1529*da0073e9SAndroid Build Coastguard Worker 1530*da0073e9SAndroid Build Coastguard Worker # Test code from issue https://github.com/pytorch/pytorch/issues/45113 1531*da0073e9SAndroid Build Coastguard Worker batch_size, input_size, hidden_size = 5, 3, 7 
1532*da0073e9SAndroid Build Coastguard Worker 1533*da0073e9SAndroid Build Coastguard Worker # Create coalesced sparse tensor with non-contiguous indices 1534*da0073e9SAndroid Build Coastguard Worker weight = torch.randn(hidden_size, input_size, dtype=dtype, device=device).to_sparse() 1535*da0073e9SAndroid Build Coastguard Worker self.assertTrue(weight.is_coalesced()) 1536*da0073e9SAndroid Build Coastguard Worker non_contig_indices = weight.indices().mT.contiguous().mT 1537*da0073e9SAndroid Build Coastguard Worker weight = torch.sparse_coo_tensor( 1538*da0073e9SAndroid Build Coastguard Worker indices=non_contig_indices, values=weight.values(), size=weight.shape) 1539*da0073e9SAndroid Build Coastguard Worker weight._coalesced_(True) 1540*da0073e9SAndroid Build Coastguard Worker self.assertFalse(weight._indices().is_contiguous()) 1541*da0073e9SAndroid Build Coastguard Worker # Create un/coalesced sparse tensor 1542*da0073e9SAndroid Build Coastguard Worker bias = torch.randn((hidden_size, 1), dtype=dtype, device=device).to_sparse() 1543*da0073e9SAndroid Build Coastguard Worker bias = torch.cat([bias] * batch_size, dim=1) 1544*da0073e9SAndroid Build Coastguard Worker 1545*da0073e9SAndroid Build Coastguard Worker if coalesced: 1546*da0073e9SAndroid Build Coastguard Worker bias = bias.coalesce() 1547*da0073e9SAndroid Build Coastguard Worker 1548*da0073e9SAndroid Build Coastguard Worker x = torch.randn(input_size, batch_size, dtype=dtype, device=device) 1549*da0073e9SAndroid Build Coastguard Worker res = bias.sspaddmm(weight, x) 1550*da0073e9SAndroid Build Coastguard Worker 1551*da0073e9SAndroid Build Coastguard Worker true_result = (bias.to_dense() + torch.matmul(weight.to_dense(), x)).to_sparse() 1552*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(res), self.safeToDense(true_result)) 1553*da0073e9SAndroid Build Coastguard Worker 1554*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 1555*da0073e9SAndroid Build Coastguard Worker 
@precisionOverride({torch.bfloat16: 5e-2, torch.float16: 5e-2}) 1556*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble, torch.bfloat16, torch.float16) 1557*da0073e9SAndroid Build Coastguard Worker def test_sparse_addmm(self, device, dtype, coalesced): 1558*da0073e9SAndroid Build Coastguard Worker if (dtype is torch.bfloat16 or dtype is torch.float16) and device.startswith("cuda"): 1559*da0073e9SAndroid Build Coastguard Worker self.skipTest('addmm_sparse_cuda is not implemented for BFloat16 and Half') 1560*da0073e9SAndroid Build Coastguard Worker 1561*da0073e9SAndroid Build Coastguard Worker 1562*da0073e9SAndroid Build Coastguard Worker def test_shape(m, n, p, nnz, broadcast, alpha_beta=None): 1563*da0073e9SAndroid Build Coastguard Worker if alpha_beta is None: 1564*da0073e9SAndroid Build Coastguard Worker alpha = random.random() 1565*da0073e9SAndroid Build Coastguard Worker beta = random.random() 1566*da0073e9SAndroid Build Coastguard Worker else: 1567*da0073e9SAndroid Build Coastguard Worker alpha, beta = alpha_beta 1568*da0073e9SAndroid Build Coastguard Worker if broadcast: 1569*da0073e9SAndroid Build Coastguard Worker D1 = make_tensor((), dtype=dtype, device=device, requires_grad=True) 1570*da0073e9SAndroid Build Coastguard Worker else: 1571*da0073e9SAndroid Build Coastguard Worker D1 = make_tensor([n, p], dtype=dtype, device=device, requires_grad=True) 1572*da0073e9SAndroid Build Coastguard Worker D2 = make_tensor([m, p], dtype=dtype, device=device, requires_grad=True) 1573*da0073e9SAndroid Build Coastguard Worker S = self._gen_sparse(2, nnz, [n, m], dtype, device, coalesced)[0] 1574*da0073e9SAndroid Build Coastguard Worker S_dense = S.to_dense().requires_grad_(True) 1575*da0073e9SAndroid Build Coastguard Worker S.requires_grad_(True) 1576*da0073e9SAndroid Build Coastguard Worker Y = torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha) 1577*da0073e9SAndroid Build Coastguard Worker Y_dense = torch.addmm(D1, S_dense, D2, beta=beta, 
alpha=alpha) 1578*da0073e9SAndroid Build Coastguard Worker self.assertEqual(Y, Y_dense) 1579*da0073e9SAndroid Build Coastguard Worker 1580*da0073e9SAndroid Build Coastguard Worker if dtype not in {torch.double, torch.cdouble}: 1581*da0073e9SAndroid Build Coastguard Worker # gradcheck will likely fail with low-precision input dtypes. 1582*da0073e9SAndroid Build Coastguard Worker return 1583*da0073e9SAndroid Build Coastguard Worker 1584*da0073e9SAndroid Build Coastguard Worker def fn(S, D1, D2, beta=beta, alpha=alpha): 1585*da0073e9SAndroid Build Coastguard Worker return torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha) 1586*da0073e9SAndroid Build Coastguard Worker gradcheck(fn, (S, D1, D2), masked=True) 1587*da0073e9SAndroid Build Coastguard Worker 1588*da0073e9SAndroid Build Coastguard Worker test_shape(7, 8, 9, 20, False, None) 1589*da0073e9SAndroid Build Coastguard Worker test_shape(7, 8, 9, 20, True, None) 1590*da0073e9SAndroid Build Coastguard Worker test_shape(7, 8, 9, 20, False, (1, 0)) 1591*da0073e9SAndroid Build Coastguard Worker test_shape(7, 8, 9, 20, True, (1, 0)) 1592*da0073e9SAndroid Build Coastguard Worker test_shape(7, 8, 9, 20, False, (1, 1)) 1593*da0073e9SAndroid Build Coastguard Worker test_shape(7, 8, 9, 20, True, (1, 1)) 1594*da0073e9SAndroid Build Coastguard Worker 1595*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 1596*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1597*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error") 1598*da0073e9SAndroid Build Coastguard Worker def test_sparse_mm(self, device, dtype, coalesced): 1599*da0073e9SAndroid Build Coastguard Worker def test_shape(d1, d2, d3, nnz, transposed): 1600*da0073e9SAndroid Build Coastguard Worker if transposed: 1601*da0073e9SAndroid Build Coastguard Worker D = torch.randn(d3, d2, dtype=dtype, 1602*da0073e9SAndroid Build Coastguard Worker device=device).t_().requires_grad_(True) 
1603*da0073e9SAndroid Build Coastguard Worker else: 1604*da0073e9SAndroid Build Coastguard Worker D = torch.randn(d2, d3, dtype=dtype, device=device).requires_grad_(True) 1605*da0073e9SAndroid Build Coastguard Worker S = self._gen_sparse(2, nnz, [d1, d2], dtype, device, coalesced)[0] 1606*da0073e9SAndroid Build Coastguard Worker S_dense = S.to_dense().requires_grad_(True) 1607*da0073e9SAndroid Build Coastguard Worker S.requires_grad_(True) 1608*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sparse.mm(S, D), torch.mm(S_dense, D)) 1609*da0073e9SAndroid Build Coastguard Worker 1610*da0073e9SAndroid Build Coastguard Worker def fn(S, D): 1611*da0073e9SAndroid Build Coastguard Worker return torch.sparse.mm(S, D) 1612*da0073e9SAndroid Build Coastguard Worker gradcheck(fn, (S, D), masked=True) 1613*da0073e9SAndroid Build Coastguard Worker 1614*da0073e9SAndroid Build Coastguard Worker test_shape(7, 8, 9, 20, False) 1615*da0073e9SAndroid Build Coastguard Worker test_shape(7, 8, 9, 20, True) 1616*da0073e9SAndroid Build Coastguard Worker 1617*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 1618*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1619*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error") 1620*da0073e9SAndroid Build Coastguard Worker @gradcheck_semantics() 1621*da0073e9SAndroid Build Coastguard Worker def test_sparse_mul(self, device, dtype, coalesced, gradcheck): 1622*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/79914 1623*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True) 1624*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True) 1625*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(masked_grad=gradcheck.masked), 
[a, b]) 1626*da0073e9SAndroid Build Coastguard Worker 1627*da0073e9SAndroid Build Coastguard Worker def test_shape(sparse_dims, nnz, with_shape): 1628*da0073e9SAndroid Build Coastguard Worker a = self._gen_sparse(sparse_dims, nnz, with_shape, dtype, device, coalesced)[0].requires_grad_(True) 1629*da0073e9SAndroid Build Coastguard Worker b = self._gen_sparse(sparse_dims, nnz, with_shape, dtype, device, coalesced)[0].requires_grad_(True) 1630*da0073e9SAndroid Build Coastguard Worker 1631*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a * b).to_dense(), a.to_dense() * b.to_dense(), masked=True) 1632*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x, y: (x * y).to_dense(), [a, b]) 1633*da0073e9SAndroid Build Coastguard Worker # Issues with 0-dim indices/values 1634*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(), [a, b], masked=True) 1635*da0073e9SAndroid Build Coastguard Worker 1636*da0073e9SAndroid Build Coastguard Worker # TODO: Re-enable these 1637*da0073e9SAndroid Build Coastguard Worker # test_shape(2, 3, [2, 3, 4, 5]) 1638*da0073e9SAndroid Build Coastguard Worker # test_shape(2, 3, [2, 2, 0]) 1639*da0073e9SAndroid Build Coastguard Worker 1640*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 1641*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1642*da0073e9SAndroid Build Coastguard Worker def test_dsmm(self, device, dtype, coalesced): 1643*da0073e9SAndroid Build Coastguard Worker def test_shape(di, dj, dk, nnz): 1644*da0073e9SAndroid Build Coastguard Worker x = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced)[0] 1645*da0073e9SAndroid Build Coastguard Worker y = self.randn(dj, dk, dtype=dtype, device=device) 1646*da0073e9SAndroid Build Coastguard Worker 1647*da0073e9SAndroid Build Coastguard Worker res = torch.dsmm(x, y) 1648*da0073e9SAndroid Build Coastguard Worker expected = torch.mm(self.safeToDense(x), y) 1649*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(res, expected) 1650*da0073e9SAndroid Build Coastguard Worker 1651*da0073e9SAndroid Build Coastguard Worker test_shape(7, 5, 3, 20) 1652*da0073e9SAndroid Build Coastguard Worker test_shape(1000, 100, 100, 20) 1653*da0073e9SAndroid Build Coastguard Worker test_shape(3000, 64, 300, 20) 1654*da0073e9SAndroid Build Coastguard Worker test_shape(0, 100, 100, 0) 1655*da0073e9SAndroid Build Coastguard Worker test_shape(1000, 0, 100, 0) 1656*da0073e9SAndroid Build Coastguard Worker test_shape(1000, 100, 0, 0) 1657*da0073e9SAndroid Build Coastguard Worker test_shape(1000, 100, 0, 20) 1658*da0073e9SAndroid Build Coastguard Worker 1659*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 1660*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1661*da0073e9SAndroid Build Coastguard Worker def test_hsmm(self, device, dtype, coalesced): 1662*da0073e9SAndroid Build Coastguard Worker def test_shape(di, dj, dk, nnz): 1663*da0073e9SAndroid Build Coastguard Worker x = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced)[0] 1664*da0073e9SAndroid Build Coastguard Worker y = self.randn(dj, dk, dtype=dtype, device=device) 1665*da0073e9SAndroid Build Coastguard Worker 1666*da0073e9SAndroid Build Coastguard Worker res = torch.hsmm(x, y) 1667*da0073e9SAndroid Build Coastguard Worker expected = torch.mm(self.safeToDense(x), y) 1668*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.to_dense(), expected) 1669*da0073e9SAndroid Build Coastguard Worker 1670*da0073e9SAndroid Build Coastguard Worker test_shape(7, 5, 3, 20) 1671*da0073e9SAndroid Build Coastguard Worker test_shape(1000, 100, 100, 20) 1672*da0073e9SAndroid Build Coastguard Worker test_shape(3000, 64, 300, 20) 1673*da0073e9SAndroid Build Coastguard Worker test_shape(0, 100, 100, 0) 1674*da0073e9SAndroid Build Coastguard Worker test_shape(1000, 0, 100, 0) 1675*da0073e9SAndroid Build Coastguard Worker test_shape(1000, 100, 0, 0) 1676*da0073e9SAndroid Build Coastguard Worker 
test_shape(1000, 100, 0, 20) 1677*da0073e9SAndroid Build Coastguard Worker 1678*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 1679*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1680*da0073e9SAndroid Build Coastguard Worker def test_spadd(self, device, dtype, coalesced): 1681*da0073e9SAndroid Build Coastguard Worker 1682*da0073e9SAndroid Build Coastguard Worker def _test_spadd_shape(nnz, shape_i, shape_v=None): 1683*da0073e9SAndroid Build Coastguard Worker shape = shape_i + (shape_v or []) 1684*da0073e9SAndroid Build Coastguard Worker x, _, _ = self._gen_sparse(len(shape_i), nnz, shape, dtype, device, coalesced) 1685*da0073e9SAndroid Build Coastguard Worker y = self.randn(*shape, dtype=dtype, device=device) 1686*da0073e9SAndroid Build Coastguard Worker r = random.random() 1687*da0073e9SAndroid Build Coastguard Worker 1688*da0073e9SAndroid Build Coastguard Worker res = torch.add(y, x, alpha=r) 1689*da0073e9SAndroid Build Coastguard Worker expected = y + r * self.safeToDense(x) 1690*da0073e9SAndroid Build Coastguard Worker 1691*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 1692*da0073e9SAndroid Build Coastguard Worker 1693*da0073e9SAndroid Build Coastguard Worker # Non contiguous dense tensor 1694*da0073e9SAndroid Build Coastguard Worker s = list(shape) 1695*da0073e9SAndroid Build Coastguard Worker s[0] = shape[-1] 1696*da0073e9SAndroid Build Coastguard Worker s[-1] = shape[0] 1697*da0073e9SAndroid Build Coastguard Worker y = self.randn(*s, dtype=dtype, device=device) 1698*da0073e9SAndroid Build Coastguard Worker y.transpose_(0, len(s) - 1) 1699*da0073e9SAndroid Build Coastguard Worker r = random.random() 1700*da0073e9SAndroid Build Coastguard Worker 1701*da0073e9SAndroid Build Coastguard Worker res = torch.add(y, x, alpha=r) 1702*da0073e9SAndroid Build Coastguard Worker expected = y + r * self.safeToDense(x) 1703*da0073e9SAndroid Build Coastguard Worker 1704*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(res, expected) 1705*da0073e9SAndroid Build Coastguard Worker 1706*da0073e9SAndroid Build Coastguard Worker x, i, v = self._gen_sparse(len(shape_i), nnz, shape, dtype, device, coalesced) 1707*da0073e9SAndroid Build Coastguard Worker nnz = i.size(1) 1708*da0073e9SAndroid Build Coastguard Worker 1709*da0073e9SAndroid Build Coastguard Worker # Non contiguous sparse indices tensor 1710*da0073e9SAndroid Build Coastguard Worker x_ = self.sparse_tensor(i[:, ::2], v[:(nnz + 1) // 2], x.shape, dtype=dtype, device=device) 1711*da0073e9SAndroid Build Coastguard Worker res = torch.add(y, x_, alpha=r) 1712*da0073e9SAndroid Build Coastguard Worker expected = y + r * self.safeToDense(x_) 1713*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 1714*da0073e9SAndroid Build Coastguard Worker 1715*da0073e9SAndroid Build Coastguard Worker # Non contiguous sparse values tensor 1716*da0073e9SAndroid Build Coastguard Worker 1717*da0073e9SAndroid Build Coastguard Worker x_ = self.sparse_tensor(i[:, :(nnz + 1) // 2], v[::2], x.shape, dtype=dtype, device=device) 1718*da0073e9SAndroid Build Coastguard Worker res = torch.add(y, x_, alpha=r) 1719*da0073e9SAndroid Build Coastguard Worker expected = y + r * self.safeToDense(x_) 1720*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 1721*da0073e9SAndroid Build Coastguard Worker 1722*da0073e9SAndroid Build Coastguard Worker # Non contiguous sparse indices and values tensors 1723*da0073e9SAndroid Build Coastguard Worker x_ = self.sparse_tensor(i[:, 1::2], v[1::2], x.shape, dtype=dtype, device=device) 1724*da0073e9SAndroid Build Coastguard Worker res = torch.add(y, x_, alpha=r) 1725*da0073e9SAndroid Build Coastguard Worker expected = y + r * self.safeToDense(x_) 1726*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 1727*da0073e9SAndroid Build Coastguard Worker 1728*da0073e9SAndroid Build Coastguard Worker def _test_spadd(): 1729*da0073e9SAndroid Build Coastguard 
Worker _test_spadd_shape(10, [5, 6]) 1730*da0073e9SAndroid Build Coastguard Worker _test_spadd_shape(10, [10, 10, 10]) 1731*da0073e9SAndroid Build Coastguard Worker _test_spadd_shape(10, [50, 30, 20]) 1732*da0073e9SAndroid Build Coastguard Worker _test_spadd_shape(10, [5, 5, 5, 5, 5, 5]) 1733*da0073e9SAndroid Build Coastguard Worker _test_spadd_shape(0, [0, 30, 20]) 1734*da0073e9SAndroid Build Coastguard Worker _test_spadd_shape(0, [50, 0, 20]) 1735*da0073e9SAndroid Build Coastguard Worker _test_spadd_shape(0, [50, 30, 0]) 1736*da0073e9SAndroid Build Coastguard Worker 1737*da0073e9SAndroid Build Coastguard Worker def _test_spadd_hybrid(): 1738*da0073e9SAndroid Build Coastguard Worker _test_spadd_shape(10, [5, 6], [2, 3]) 1739*da0073e9SAndroid Build Coastguard Worker _test_spadd_shape(10, [10, 10, 10], [3]) 1740*da0073e9SAndroid Build Coastguard Worker _test_spadd_shape(10, [50, 30, 20], [2]) 1741*da0073e9SAndroid Build Coastguard Worker _test_spadd_shape(10, [5, 5, 5, 5, 5, 5], [2]) 1742*da0073e9SAndroid Build Coastguard Worker _test_spadd_shape(0, [0, 30, 20], [2, 0]) 1743*da0073e9SAndroid Build Coastguard Worker _test_spadd_shape(0, [50, 0, 20], [2, 0]) 1744*da0073e9SAndroid Build Coastguard Worker _test_spadd_shape(0, [50, 30, 0], [2, 0]) 1745*da0073e9SAndroid Build Coastguard Worker _test_spadd_shape(10, [50, 30, 20], [2, 0]) 1746*da0073e9SAndroid Build Coastguard Worker 1747*da0073e9SAndroid Build Coastguard Worker _test_spadd() 1748*da0073e9SAndroid Build Coastguard Worker _test_spadd_hybrid() 1749*da0073e9SAndroid Build Coastguard Worker 1750*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 1751*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 1752*da0073e9SAndroid Build Coastguard Worker def test_sparse_add_out_bfloat16(self, device, dtype, coalesced): 1753*da0073e9SAndroid Build Coastguard Worker # fp32 1754*da0073e9SAndroid Build Coastguard Worker x, _, _ = self._gen_sparse(3, 5, 10, dtype, device, coalesced) 1755*da0073e9SAndroid 
Build Coastguard Worker y, _, _ = self._gen_sparse(3, 5, 10, dtype, device, coalesced) 1756*da0073e9SAndroid Build Coastguard Worker res_fp32 = torch.add(x, y) 1757*da0073e9SAndroid Build Coastguard Worker 1758*da0073e9SAndroid Build Coastguard Worker # bfloat16 1759*da0073e9SAndroid Build Coastguard Worker x = x.bfloat16() 1760*da0073e9SAndroid Build Coastguard Worker y = y.bfloat16() 1761*da0073e9SAndroid Build Coastguard Worker res_bf16 = torch.add(x, y) 1762*da0073e9SAndroid Build Coastguard Worker res_bf16 = res_bf16.float() # to compare with reference 1763*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_fp32, res_bf16, atol=1e-2, rtol=0) 1764*da0073e9SAndroid Build Coastguard Worker 1765*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 1766*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 1767*da0073e9SAndroid Build Coastguard Worker def test_norm(self, device, dtype, coalesced): 1768*da0073e9SAndroid Build Coastguard Worker def test_shape(sparse_dims, nnz, with_size): 1769*da0073e9SAndroid Build Coastguard Worker x, _, _ = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced) 1770*da0073e9SAndroid Build Coastguard Worker y = x.coalesce() 1771*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.norm(), y._values().norm()) 1772*da0073e9SAndroid Build Coastguard Worker 1773*da0073e9SAndroid Build Coastguard Worker test_shape(3, 10, 100) 1774*da0073e9SAndroid Build Coastguard Worker test_shape(4, 10, [100, 100, 100, 5, 5, 5, 0]) 1775*da0073e9SAndroid Build Coastguard Worker test_shape(4, 0, [0, 0, 100, 5, 5, 5, 0]) 1776*da0073e9SAndroid Build Coastguard Worker 1777*da0073e9SAndroid Build Coastguard Worker # Unsupported arguments should error 1778*da0073e9SAndroid Build Coastguard Worker kwarg_error_pairs = [ 1779*da0073e9SAndroid Build Coastguard Worker ({'keepdim': True}, 1780*da0073e9SAndroid Build Coastguard Worker RuntimeError, r'norm_sparse currently does not support 
keepdim=True'), 1781*da0073e9SAndroid Build Coastguard Worker ({'dim': 0}, 1782*da0073e9SAndroid Build Coastguard Worker RuntimeError, r'norm_sparse currently only supports full reductions'), 1783*da0073e9SAndroid Build Coastguard Worker ({'dtype': torch.double, 'p': 'fro'}, 1784*da0073e9SAndroid Build Coastguard Worker ValueError, r'dtype argument is not supported in frobenius norm'), 1785*da0073e9SAndroid Build Coastguard Worker ({'dtype': torch.double, 'p': 0}, 1786*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"norm_sparse currently does not support 'dtype' argument") 1787*da0073e9SAndroid Build Coastguard Worker ] 1788*da0073e9SAndroid Build Coastguard Worker x = self._gen_sparse(3, 10, 100, dtype, device, coalesced)[0] 1789*da0073e9SAndroid Build Coastguard Worker for kwargs, err, msg in kwarg_error_pairs: 1790*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(err, msg): 1791*da0073e9SAndroid Build Coastguard Worker x.norm(**kwargs) 1792*da0073e9SAndroid Build Coastguard Worker 1793*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 1794*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1795*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_CROSSREF, "fallback triggers cuda device error") 1796*da0073e9SAndroid Build Coastguard Worker def test_sparse_sum(self, device, dtype, coalesced): 1797*da0073e9SAndroid Build Coastguard Worker 1798*da0073e9SAndroid Build Coastguard Worker def run_tests(S, td=None): 1799*da0073e9SAndroid Build Coastguard Worker D = S.coalesce().to_dense().detach().requires_grad_(True) 1800*da0073e9SAndroid Build Coastguard Worker if td is None: 1801*da0073e9SAndroid Build Coastguard Worker S_sum = torch.sparse.sum(S) 1802*da0073e9SAndroid Build Coastguard Worker D_sum = D.sum() 1803*da0073e9SAndroid Build Coastguard Worker self.assertEqual(S_sum.item(), D_sum.item()) 1804*da0073e9SAndroid Build Coastguard Worker 1805*da0073e9SAndroid Build Coastguard Worker def fn(S): 
1806*da0073e9SAndroid Build Coastguard Worker return torch.sparse.sum(S) 1807*da0073e9SAndroid Build Coastguard Worker gradcheck(fn, (S,), masked=True) 1808*da0073e9SAndroid Build Coastguard Worker else: 1809*da0073e9SAndroid Build Coastguard Worker S_sum = torch.sparse.sum(S, td) 1810*da0073e9SAndroid Build Coastguard Worker D_sum = D.sum(td) 1811*da0073e9SAndroid Build Coastguard Worker self.assertEqual(S_sum.to_dense() if S_sum.is_sparse else S_sum, D_sum) 1812*da0073e9SAndroid Build Coastguard Worker 1813*da0073e9SAndroid Build Coastguard Worker def fn(S): 1814*da0073e9SAndroid Build Coastguard Worker res = torch.sparse.sum(S, td) 1815*da0073e9SAndroid Build Coastguard Worker return res.to_dense(masked_grad=True) 1816*da0073e9SAndroid Build Coastguard Worker gradcheck(fn, (S,), masked=True) 1817*da0073e9SAndroid Build Coastguard Worker 1818*da0073e9SAndroid Build Coastguard Worker nnz = 10 1819*da0073e9SAndroid Build Coastguard Worker sparse_dims = 2 1820*da0073e9SAndroid Build Coastguard Worker with_size = [5, 5, 1, 4] # use a dense dim = 1 to test for squeeze 1821*da0073e9SAndroid Build Coastguard Worker test_dims = [] 1822*da0073e9SAndroid Build Coastguard Worker for i in range(1, 5): 1823*da0073e9SAndroid Build Coastguard Worker test_dims += itertools.combinations(range(len(with_size)), i) 1824*da0073e9SAndroid Build Coastguard Worker 1825*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/16501 1826*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[1., 0., 0., 1.], 1827*da0073e9SAndroid Build Coastguard Worker [0., 1., 0., 0.], 1828*da0073e9SAndroid Build Coastguard Worker [0., 1., 1., 0.], 1829*da0073e9SAndroid Build Coastguard Worker [0., 1., 0., 2.]], dtype=dtype, device=device).to_sparse() 1830*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sparse.sum(x, dim=0), torch.sparse.sum(x, dim=-2)) 1831*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sum(x.to_dense(), dim=0), 
torch.sparse.sum(x, dim=0).to_dense()) 1832*da0073e9SAndroid Build Coastguard Worker 1833*da0073e9SAndroid Build Coastguard Worker S = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0] 1834*da0073e9SAndroid Build Coastguard Worker 1835*da0073e9SAndroid Build Coastguard Worker # dim out of range 1836*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, lambda: torch.sparse.sum(S, 5)) 1837*da0073e9SAndroid Build Coastguard Worker 1838*da0073e9SAndroid Build Coastguard Worker # dim 0 appears multiple times in the list of dims 1839*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.sparse.sum(S, [0, 0])) 1840*da0073e9SAndroid Build Coastguard Worker 1841*da0073e9SAndroid Build Coastguard Worker # sum an empty tensor 1842*da0073e9SAndroid Build Coastguard Worker empty_S = torch.sparse_coo_tensor(size=with_size, dtype=dtype, device=device) 1843*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sparse.sum(empty_S, [0]).to_dense(), torch.sum(empty_S.to_dense(), [0])) 1844*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sparse.sum(empty_S), torch.tensor(0, dtype=dtype, device=device)) 1845*da0073e9SAndroid Build Coastguard Worker empty_S.requires_grad_(True) 1846*da0073e9SAndroid Build Coastguard Worker empty_S_sum = torch.sparse.sum(empty_S) 1847*da0073e9SAndroid Build Coastguard Worker empty_S_sum.backward() 1848*da0073e9SAndroid Build Coastguard Worker self.assertEqual(empty_S.grad.to_dense(), empty_S.clone().detach().to_dense()) 1849*da0073e9SAndroid Build Coastguard Worker 1850*da0073e9SAndroid Build Coastguard Worker # test values().sum() 1851*da0073e9SAndroid Build Coastguard Worker S = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0] 1852*da0073e9SAndroid Build Coastguard Worker run_tests(S.requires_grad_(True)) 1853*da0073e9SAndroid Build Coastguard Worker 1854*da0073e9SAndroid Build Coastguard Worker for test_dim in test_dims: 
1855*da0073e9SAndroid Build Coastguard Worker S = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0] 1856*da0073e9SAndroid Build Coastguard Worker run_tests(S.requires_grad_(True), test_dim) 1857*da0073e9SAndroid Build Coastguard Worker 1858*da0073e9SAndroid Build Coastguard Worker def _test_basic_ops_shape(self, nnz_x1, nnz_x2, shape_i, shape_v, dtype, device, coalesced): 1859*da0073e9SAndroid Build Coastguard Worker shape = shape_i + (shape_v) 1860*da0073e9SAndroid Build Coastguard Worker x1, _, _ = self._gen_sparse(len(shape_i), nnz_x1, shape, dtype, device, coalesced) 1861*da0073e9SAndroid Build Coastguard Worker x2, _, _ = self._gen_sparse(len(shape_i), nnz_x2, shape, dtype, device, coalesced) 1862*da0073e9SAndroid Build Coastguard Worker 1863*da0073e9SAndroid Build Coastguard Worker y1 = x1 + x2 1864*da0073e9SAndroid Build Coastguard Worker y2 = x1.clone() 1865*da0073e9SAndroid Build Coastguard Worker y2.add_(x2) 1866*da0073e9SAndroid Build Coastguard Worker expected = self.safeToDense(x1) + self.safeToDense(x2) 1867*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y1), expected) 1868*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y2), expected) 1869*da0073e9SAndroid Build Coastguard Worker 1870*da0073e9SAndroid Build Coastguard Worker y1 = x1 - x2 1871*da0073e9SAndroid Build Coastguard Worker y2 = x1.clone() 1872*da0073e9SAndroid Build Coastguard Worker y2.sub_(x2) 1873*da0073e9SAndroid Build Coastguard Worker expected = self.safeToDense(x1) - self.safeToDense(x2) 1874*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y1), expected) 1875*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y2), expected) 1876*da0073e9SAndroid Build Coastguard Worker 1877*da0073e9SAndroid Build Coastguard Worker y1 = x1 * x2 1878*da0073e9SAndroid Build Coastguard Worker y2 = x1.clone() 1879*da0073e9SAndroid Build Coastguard Worker y2.mul_(x2) 
1880*da0073e9SAndroid Build Coastguard Worker expected = self.safeToDense(x1) * self.safeToDense(x2) 1881*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y1), expected) 1882*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y2), expected) 1883*da0073e9SAndroid Build Coastguard Worker 1884*da0073e9SAndroid Build Coastguard Worker y1 = x1 * 37.5 1885*da0073e9SAndroid Build Coastguard Worker y2 = x1.clone() 1886*da0073e9SAndroid Build Coastguard Worker y2.mul_(37.5) 1887*da0073e9SAndroid Build Coastguard Worker expected = self.safeToDense(x1) * 37.5 1888*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y1), expected) 1889*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y2), expected) 1890*da0073e9SAndroid Build Coastguard Worker 1891*da0073e9SAndroid Build Coastguard Worker y1 = x1 / 37.5 1892*da0073e9SAndroid Build Coastguard Worker y2 = x1.clone() 1893*da0073e9SAndroid Build Coastguard Worker y2.div_(37.5) 1894*da0073e9SAndroid Build Coastguard Worker expected = self.safeToDense(x1) / 37.5 1895*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y1), expected) 1896*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y2), expected) 1897*da0073e9SAndroid Build Coastguard Worker 1898*da0073e9SAndroid Build Coastguard Worker y1 = x1 // 37.5 1899*da0073e9SAndroid Build Coastguard Worker y2 = x1.clone() 1900*da0073e9SAndroid Build Coastguard Worker y2.floor_divide_(37.5) 1901*da0073e9SAndroid Build Coastguard Worker expected = self.safeToDense(x1) // 37.5 1902*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y1), expected) 1903*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y2), expected) 1904*da0073e9SAndroid Build Coastguard Worker 1905*da0073e9SAndroid Build Coastguard Worker # TODO: add back inplace support 1906*da0073e9SAndroid Build Coastguard Worker y1 = x1 ** 2 
1907*da0073e9SAndroid Build Coastguard Worker y2 = x1.clone() 1908*da0073e9SAndroid Build Coastguard Worker y2 = y2.pow(2) 1909*da0073e9SAndroid Build Coastguard Worker expected = self.safeToDense(x1) ** 2 1910*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y1), expected) 1911*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y2), expected) 1912*da0073e9SAndroid Build Coastguard Worker 1913*da0073e9SAndroid Build Coastguard Worker y = x1.clone() 1914*da0073e9SAndroid Build Coastguard Worker y.zero_() 1915*da0073e9SAndroid Build Coastguard Worker expected = torch.zeros(x1.size(), dtype=dtype, device=device) 1916*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y), expected) 1917*da0073e9SAndroid Build Coastguard Worker 1918*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x1.is_coalesced(), coalesced) 1919*da0073e9SAndroid Build Coastguard Worker y = x1.coalesce() 1920*da0073e9SAndroid Build Coastguard Worker z = x1.coalesce() 1921*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x1.is_coalesced(), coalesced) 1922*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.is_coalesced()) 1923*da0073e9SAndroid Build Coastguard Worker y._values().add_(1) 1924*da0073e9SAndroid Build Coastguard Worker if not x1.is_coalesced(): 1925*da0073e9SAndroid Build Coastguard Worker # check that coalesce is out of place if the original tensor is not 1926*da0073e9SAndroid Build Coastguard Worker # coalesced. 1927*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z._values() + 1, y._values()) 1928*da0073e9SAndroid Build Coastguard Worker else: 1929*da0073e9SAndroid Build Coastguard Worker # check that coalesce is in-place if the original tensor is 1930*da0073e9SAndroid Build Coastguard Worker # coalesced. 
1931*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z._values(), y._values()) 1932*da0073e9SAndroid Build Coastguard Worker 1933*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 1934*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1935*da0073e9SAndroid Build Coastguard Worker def test_basic_ops(self, device, dtype, coalesced): 1936*da0073e9SAndroid Build Coastguard Worker 1937*da0073e9SAndroid Build Coastguard Worker def _test_basic_ops(): 1938*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(9, 12, [5, 6], [], dtype, device, coalesced) 1939*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(9, 12, [10, 10, 10], [], dtype, device, coalesced) 1940*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(9, 12, [50, 30, 20], [], dtype, device, coalesced) 1941*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(9, 12, [5, 5, 5, 5, 5, 5], [], dtype, device, coalesced) 1942*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(0, 12, [10, 10, 10], [], dtype, device, coalesced) 1943*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(9, 0, [10, 10, 10], [], dtype, device, coalesced) 1944*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(0, 0, [10, 10, 10], [], dtype, device, coalesced) 1945*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(0, 0, [10, 10, 0], [], dtype, device, coalesced) 1946*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(0, 0, [], [], dtype, device, coalesced) 1947*da0073e9SAndroid Build Coastguard Worker 1948*da0073e9SAndroid Build Coastguard Worker def _test_basic_ops_hybrid(): 1949*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(9, 12, [5, 6], [2, 3], dtype, device, coalesced) 1950*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(9, 12, [10, 10, 10], [3], dtype, device, coalesced) 1951*da0073e9SAndroid Build Coastguard Worker 
self._test_basic_ops_shape(9, 12, [50, 30, 20], [2], dtype, device, coalesced) 1952*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(9, 12, [5, 5, 5, 5, 5, 5], [2], dtype, device, coalesced) 1953*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(0, 12, [10, 10, 10], [2], dtype, device, coalesced) 1954*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(9, 0, [10, 10, 10], [2], dtype, device, coalesced) 1955*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(0, 0, [10, 10, 10], [2], dtype, device, coalesced) 1956*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(9, 12, [10, 10, 10], [2, 0], dtype, device, coalesced) 1957*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(0, 12, [10, 10, 10], [2, 0], dtype, device, coalesced) 1958*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(9, 0, [10, 10, 10], [2, 0], dtype, device, coalesced) 1959*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(0, 0, [10, 10, 10], [2, 0], dtype, device, coalesced) 1960*da0073e9SAndroid Build Coastguard Worker self._test_basic_ops_shape(0, 0, [10, 10, 0], [2, 0], dtype, device, coalesced) 1961*da0073e9SAndroid Build Coastguard Worker 1962*da0073e9SAndroid Build Coastguard Worker _test_basic_ops() 1963*da0073e9SAndroid Build Coastguard Worker _test_basic_ops_hybrid() 1964*da0073e9SAndroid Build Coastguard Worker 1965*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 1966*da0073e9SAndroid Build Coastguard Worker def test_add_dense_sparse_mismatch(self, device, dtype): 1967*da0073e9SAndroid Build Coastguard Worker def test_shape(dense_size, sparse_dims_shape, dense_dims_shape, sparse_size): 1968*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(dense_size, dtype=dtype, device=device) 1969*da0073e9SAndroid Build Coastguard Worker sparse_y = self.sparse_tensor(torch.zeros(sparse_dims_shape, dtype=torch.int64, device=device), 
1970*da0073e9SAndroid Build Coastguard Worker torch.randn(dense_dims_shape, dtype=dtype, device=device), 1971*da0073e9SAndroid Build Coastguard Worker torch.Size(sparse_size)) 1972*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1973*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1974*da0073e9SAndroid Build Coastguard Worker "add: expected 'self' and 'other' to have same size"): 1975*da0073e9SAndroid Build Coastguard Worker x + sparse_y 1976*da0073e9SAndroid Build Coastguard Worker 1977*da0073e9SAndroid Build Coastguard Worker test_shape([3, 4], [1, 4], [4, 4, 4], [3, 4, 4]) 1978*da0073e9SAndroid Build Coastguard Worker test_shape([3, 4, 0], [1, 4], [4, 4, 4, 0], [3, 4, 4, 0]) 1979*da0073e9SAndroid Build Coastguard Worker 1980*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a TorchDynamo suitable test") 1981*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 1982*da0073e9SAndroid Build Coastguard Worker def test_add_noncontiguous(self, device, dtype): 1983*da0073e9SAndroid Build Coastguard Worker indices = self.index_tensor([[1, 2], [0, 2]], device=device) 1984*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([1.], dtype=dtype, device=device).expand(2, 3, 4, 5) 1985*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(indices, values, dtype=dtype, device=device) 1986*da0073e9SAndroid Build Coastguard Worker assert not x._values().is_contiguous() 1987*da0073e9SAndroid Build Coastguard Worker y = x + x 1988*da0073e9SAndroid Build Coastguard Worker expected = self.safeToDense(x) + self.safeToDense(x) 1989*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y), expected) 1990*da0073e9SAndroid Build Coastguard Worker 1991*da0073e9SAndroid Build Coastguard Worker def _test_sparse_mask_shape(self, nnz_x1, nnz_x2, shape_i, shape_v, dtype, device, coalesced): 1992*da0073e9SAndroid Build Coastguard Worker shape = shape_i + (shape_v or []) 
1993*da0073e9SAndroid Build Coastguard Worker x1, _, _ = self._gen_sparse(len(shape_i), nnz_x1, shape, dtype, device, coalesced) 1994*da0073e9SAndroid Build Coastguard Worker x2, _, _ = self._gen_sparse(len(shape_i), nnz_x2, shape, dtype, device, coalesced) 1995*da0073e9SAndroid Build Coastguard Worker 1996*da0073e9SAndroid Build Coastguard Worker y1 = x1 + x2 1997*da0073e9SAndroid Build Coastguard Worker y2 = x1.clone() 1998*da0073e9SAndroid Build Coastguard Worker y2.add_(x2) 1999*da0073e9SAndroid Build Coastguard Worker expected = self.safeToDense(x1) + self.safeToDense(x2) 2000*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y1), expected) 2001*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(y2), expected) 2002*da0073e9SAndroid Build Coastguard Worker 2003*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 2004*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 2005*da0073e9SAndroid Build Coastguard Worker def test_sparse_mask(self, device, dtype, coalesced): 2006*da0073e9SAndroid Build Coastguard Worker def _test_sparse_mask_fixed(): 2007*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 2008*da0073e9SAndroid Build Coastguard Worker [1, 3, 0, 4], 2009*da0073e9SAndroid Build Coastguard Worker [2, 1, 2, 3], 2010*da0073e9SAndroid Build Coastguard Worker ], device=device) 2011*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([1, 2, 3, 4], dtype=dtype, device=device) 2012*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([5, 4]), dtype=dtype, device=device).coalesce() 2013*da0073e9SAndroid Build Coastguard Worker dense = torch.tensor([ 2014*da0073e9SAndroid Build Coastguard Worker [1, 2, 3, 4], 2015*da0073e9SAndroid Build Coastguard Worker [5, 6, 7, 8], 2016*da0073e9SAndroid Build Coastguard Worker [9, 10, 11, 12], 2017*da0073e9SAndroid Build Coastguard Worker [13, 14, 15, 16], 2018*da0073e9SAndroid Build Coastguard Worker 
[17, 18, 19, 20], 2019*da0073e9SAndroid Build Coastguard Worker ], dtype=dtype, device=device) 2020*da0073e9SAndroid Build Coastguard Worker exp_v = torch.tensor([7, 14, 3, 20], dtype=dtype, device=device) 2021*da0073e9SAndroid Build Coastguard Worker res_dense_lhs = dense.sparse_mask(x) 2022*da0073e9SAndroid Build Coastguard Worker sparse = dense.to_sparse() 2023*da0073e9SAndroid Build Coastguard Worker res_sparse_lhs = sparse.sparse_mask(x) 2024*da0073e9SAndroid Build Coastguard Worker expected = self.sparse_tensor(i, exp_v, torch.Size([5, 4]), dtype=dtype, device=device) 2025*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_dense_lhs.coalesce(), expected.coalesce()) 2026*da0073e9SAndroid Build Coastguard Worker # check no side effects for the coalesce flag. 2027*da0073e9SAndroid Build Coastguard Worker self.assertTrue(sparse.is_coalesced()) 2028*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_sparse_lhs.coalesce(), expected.coalesce()) 2029*da0073e9SAndroid Build Coastguard Worker 2030*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 2031*da0073e9SAndroid Build Coastguard Worker [1, 3, 0, 4], 2032*da0073e9SAndroid Build Coastguard Worker [2, 1, 2, 3], 2033*da0073e9SAndroid Build Coastguard Worker ], device=device) 2034*da0073e9SAndroid Build Coastguard Worker v = torch.empty([4, 0], dtype=dtype, device=device) 2035*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([5, 4, 0])).coalesce() 2036*da0073e9SAndroid Build Coastguard Worker dense = torch.empty([5, 4, 0], dtype=dtype, device=device) 2037*da0073e9SAndroid Build Coastguard Worker exp_v = torch.empty([4, 0], dtype=dtype, device=device) 2038*da0073e9SAndroid Build Coastguard Worker res_dense_lhs = dense.sparse_mask(x) 2039*da0073e9SAndroid Build Coastguard Worker sparse = dense.to_sparse(2) 2040*da0073e9SAndroid Build Coastguard Worker res_sparse_lhs = sparse.sparse_mask(x) 2041*da0073e9SAndroid Build Coastguard Worker expected = 
self.sparse_tensor(i, exp_v, torch.Size([5, 4, 0]), dtype=dtype, device=device) 2042*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_dense_lhs.coalesce(), expected.coalesce()) 2043*da0073e9SAndroid Build Coastguard Worker # check no side effects for the coalesce flag. 2044*da0073e9SAndroid Build Coastguard Worker self.assertTrue(sparse.is_coalesced()) 2045*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_sparse_lhs.coalesce(), expected.coalesce()) 2046*da0073e9SAndroid Build Coastguard Worker 2047*da0073e9SAndroid Build Coastguard Worker _test_sparse_mask_fixed() 2048*da0073e9SAndroid Build Coastguard Worker 2049*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(9, 12, [5, 6], [], dtype, device, coalesced) 2050*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(9, 12, [10, 10, 10], [], dtype, device, coalesced) 2051*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(9, 12, [50, 30, 20], [], dtype, device, coalesced) 2052*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(9, 12, [5, 5, 5, 5, 5, 5], [], dtype, device, coalesced) 2053*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(0, 12, [10, 10, 10], [], dtype, device, coalesced) 2054*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(9, 0, [10, 10, 10], [], dtype, device, coalesced) 2055*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(0, 0, [10, 10, 10], [], dtype, device, coalesced) 2056*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(0, 0, [10, 10, 0], [], dtype, device, coalesced) 2057*da0073e9SAndroid Build Coastguard Worker 2058*da0073e9SAndroid Build Coastguard Worker # check repetitions and matchings in the intersection 2059*da0073e9SAndroid Build Coastguard Worker lhs = torch.randint(0, 5, (100,), device=device) 2060*da0073e9SAndroid Build Coastguard Worker rhs = torch.randint(0, 5, (100,), device=device).to_sparse() 
2061*da0073e9SAndroid Build Coastguard Worker self.assertEqual(lhs.to_sparse().sparse_mask(rhs), lhs.sparse_mask(rhs)) 2062*da0073e9SAndroid Build Coastguard Worker 2063*da0073e9SAndroid Build Coastguard Worker # check coalesce 2064*da0073e9SAndroid Build Coastguard Worker sparse_c = torch.rand(3, 3, device=device).to_sparse() 2065*da0073e9SAndroid Build Coastguard Worker sparse_unc = torch.rand(3, 3, device=device).to_sparse()._coalesced_(False) 2066*da0073e9SAndroid Build Coastguard Worker for lhs, rhs in [(sparse_c, sparse_unc), (sparse_unc, sparse_c)]: 2067*da0073e9SAndroid Build Coastguard Worker res_all_sparse = lhs.sparse_mask(rhs) 2068*da0073e9SAndroid Build Coastguard Worker res_dense_sparse = lhs.to_dense().sparse_mask(rhs) 2069*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_all_sparse.coalesce(), res_dense_sparse.coalesce()) 2070*da0073e9SAndroid Build Coastguard Worker self.assertEqual(rhs.is_coalesced(), res_all_sparse.is_coalesced()) 2071*da0073e9SAndroid Build Coastguard Worker 2072*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 2073*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 2074*da0073e9SAndroid Build Coastguard Worker def test_sparse_mask_hybrid(self, device, dtype, coalesced): 2075*da0073e9SAndroid Build Coastguard Worker def _test_sparse_mask_hybrid_fixed(): 2076*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 2077*da0073e9SAndroid Build Coastguard Worker [1, 3, 0, 4], 2078*da0073e9SAndroid Build Coastguard Worker [2, 1, 2, 3], 2079*da0073e9SAndroid Build Coastguard Worker ]) 2080*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5]]) 2081*da0073e9SAndroid Build Coastguard Worker # TODO: This is also testing that, if coalesce is a no-op, 2082*da0073e9SAndroid Build Coastguard Worker # the indices don't get permuted. I don't know if we actually 2083*da0073e9SAndroid Build Coastguard Worker # want to give this invariant. 
2084*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([5, 4, 2])).coalesce() 2085*da0073e9SAndroid Build Coastguard Worker dense = torch.tensor([ 2086*da0073e9SAndroid Build Coastguard Worker [[1, 3], [2, 2], [3, 3], [4, 2]], 2087*da0073e9SAndroid Build Coastguard Worker [[5, 7], [6, 7], [7, 9], [8, 9]], 2088*da0073e9SAndroid Build Coastguard Worker [[9, 2], [10, 4], [11, 1], [12, 3]], 2089*da0073e9SAndroid Build Coastguard Worker [[13, 5], [14, 1], [15, 1], [16, 6]], 2090*da0073e9SAndroid Build Coastguard Worker [[17, 7], [18, 2], [19, 7], [20, 1]], 2091*da0073e9SAndroid Build Coastguard Worker ]) 2092*da0073e9SAndroid Build Coastguard Worker res_dense_lhs = dense.sparse_mask(x) 2093*da0073e9SAndroid Build Coastguard Worker sparse = dense.to_sparse(2) 2094*da0073e9SAndroid Build Coastguard Worker res_sparse_lhs = sparse.sparse_mask(x) 2095*da0073e9SAndroid Build Coastguard Worker exp_v = torch.tensor([[7, 9], [14, 1], [3, 3], [20, 1]]) 2096*da0073e9SAndroid Build Coastguard Worker expected = self.sparse_tensor(i, exp_v, torch.Size([5, 4, 2])) 2097*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_dense_lhs.coalesce(), expected.coalesce()) 2098*da0073e9SAndroid Build Coastguard Worker # check no side effects for the coalesce flag 2099*da0073e9SAndroid Build Coastguard Worker self.assertTrue(sparse.is_coalesced()) 2100*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_sparse_lhs.coalesce(), expected.coalesce()) 2101*da0073e9SAndroid Build Coastguard Worker 2102*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([ 2103*da0073e9SAndroid Build Coastguard Worker [1, 3, 0, 4], 2104*da0073e9SAndroid Build Coastguard Worker [2, 1, 2, 3], 2105*da0073e9SAndroid Build Coastguard Worker ]) 2106*da0073e9SAndroid Build Coastguard Worker v = torch.empty(4, 2, 0) 2107*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([5, 4, 2, 0])).coalesce() 2108*da0073e9SAndroid Build Coastguard 
Worker dense = torch.empty(5, 4, 2, 0) 2109*da0073e9SAndroid Build Coastguard Worker res_dense_lhs = dense.sparse_mask(x) 2110*da0073e9SAndroid Build Coastguard Worker sparse = dense.to_sparse(2) 2111*da0073e9SAndroid Build Coastguard Worker res_sparse_lhs = sparse.sparse_mask(x) 2112*da0073e9SAndroid Build Coastguard Worker exp_v = torch.empty(4, 2, 0) 2113*da0073e9SAndroid Build Coastguard Worker expected = self.sparse_tensor(i, exp_v, torch.Size([5, 4, 2, 0])) 2114*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_dense_lhs.coalesce(), expected.coalesce()) 2115*da0073e9SAndroid Build Coastguard Worker # check no side effects for the coalesce flag 2116*da0073e9SAndroid Build Coastguard Worker self.assertTrue(sparse.is_coalesced()) 2117*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_sparse_lhs.coalesce(), expected.coalesce()) 2118*da0073e9SAndroid Build Coastguard Worker 2119*da0073e9SAndroid Build Coastguard Worker _test_sparse_mask_hybrid_fixed() 2120*da0073e9SAndroid Build Coastguard Worker 2121*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(9, 12, [5, 6], [2, 3], dtype, device, coalesced) 2122*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(9, 12, [10, 10, 10], [3], dtype, device, coalesced) 2123*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(9, 12, [50, 30, 20], [2], dtype, device, coalesced) 2124*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(9, 12, [5, 5, 5, 5, 5, 5], [2], dtype, device, coalesced) 2125*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(0, 12, [10, 10, 10], [2], dtype, device, coalesced) 2126*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(9, 0, [10, 10, 10], [2], dtype, device, coalesced) 2127*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(0, 0, [10, 10, 10], [2], dtype, device, coalesced) 2128*da0073e9SAndroid Build Coastguard Worker 
self._test_sparse_mask_shape(9, 12, [10, 10, 10], [2, 0], dtype, device, coalesced) 2129*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(0, 12, [10, 10, 10], [2, 0], dtype, device, coalesced) 2130*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(9, 0, [10, 10, 10], [2, 0], dtype, device, coalesced) 2131*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(0, 0, [10, 10, 10], [2, 0], dtype, device, coalesced) 2132*da0073e9SAndroid Build Coastguard Worker self._test_sparse_mask_shape(0, 0, [10, 10, 0], [2, 0], dtype, device, coalesced) 2133*da0073e9SAndroid Build Coastguard Worker 2134*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 2135*da0073e9SAndroid Build Coastguard Worker @skipIfCrossRef 2136*da0073e9SAndroid Build Coastguard Worker def test_sparse_mask_backward(self, device, dtype): 2137*da0073e9SAndroid Build Coastguard Worker from itertools import product, repeat 2138*da0073e9SAndroid Build Coastguard Worker 2139*da0073e9SAndroid Build Coastguard Worker shape = (5, 5) 2140*da0073e9SAndroid Build Coastguard Worker sparse_dims = len(shape) 2141*da0073e9SAndroid Build Coastguard Worker nnzs = (0, 5, 15, 25) 2142*da0073e9SAndroid Build Coastguard Worker 2143*da0073e9SAndroid Build Coastguard Worker lhs_data = torch.arange(1, 26, device=device).reshape(shape).to(dtype).to_sparse(sparse_dims) 2144*da0073e9SAndroid Build Coastguard Worker rhs_data = lhs_data.clone() 2145*da0073e9SAndroid Build Coastguard Worker 2146*da0073e9SAndroid Build Coastguard Worker for nnz in nnzs: 2147*da0073e9SAndroid Build Coastguard Worker for lhs_is_coalesced, rhs_is_coalesced in product(*repeat((True, False), 2)): 2148*da0073e9SAndroid Build Coastguard Worker lhs = torch.sparse_coo_tensor( 2149*da0073e9SAndroid Build Coastguard Worker lhs_data._indices()[:, :nnz], 2150*da0073e9SAndroid Build Coastguard Worker lhs_data._values()[:nnz], 2151*da0073e9SAndroid Build Coastguard Worker lhs_data.shape 
2152*da0073e9SAndroid Build Coastguard Worker ).clone()._coalesced_(lhs_is_coalesced).requires_grad_(True) 2153*da0073e9SAndroid Build Coastguard Worker 2154*da0073e9SAndroid Build Coastguard Worker rhs = torch.sparse_coo_tensor( 2155*da0073e9SAndroid Build Coastguard Worker lhs_data._indices()[:, -nnz:], 2156*da0073e9SAndroid Build Coastguard Worker lhs_data._values()[-nnz:], 2157*da0073e9SAndroid Build Coastguard Worker lhs_data.shape 2158*da0073e9SAndroid Build Coastguard Worker ).clone()._coalesced_(rhs_is_coalesced) 2159*da0073e9SAndroid Build Coastguard Worker 2160*da0073e9SAndroid Build Coastguard Worker # To test masked semantics we need to make sure that 2161*da0073e9SAndroid Build Coastguard Worker # sparsity_pattern(lhs) == sparsity_pattern(lhs.grad). 2162*da0073e9SAndroid Build Coastguard Worker # lhs.sparse_mask(lhs_mask) accomplishes that. 2163*da0073e9SAndroid Build Coastguard Worker lhs_mask = lhs.detach().clone() 2164*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x: x.sparse_mask(lhs_mask).sparse_mask(rhs).to_dense(masked_grad=True), (lhs,), masked=True) 2165*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x: x.sparse_mask(rhs).to_dense(masked_grad=False), (lhs,), masked=False) 2166*da0073e9SAndroid Build Coastguard Worker 2167*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 2168*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 2169*da0073e9SAndroid Build Coastguard Worker def test_zeros(self, device, dtype, coalesced): 2170*da0073e9SAndroid Build Coastguard Worker def _test_zeros(nnzs, shape, out_shape_i, out_shape_v=None): 2171*da0073e9SAndroid Build Coastguard Worker out_shape = out_shape_i + (out_shape_v or []) 2172*da0073e9SAndroid Build Coastguard Worker for nnz in nnzs: 2173*da0073e9SAndroid Build Coastguard Worker out, _, _ = self._gen_sparse(len(out_shape_i), nnz, out_shape, dtype, device, coalesced) 2174*da0073e9SAndroid Build Coastguard Worker torch.zeros(*shape, out=out, 
dtype=dtype, device=device) 2175*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tuple(out.size()), tuple(shape)) 2176*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out._indices().numel() == out._values().numel() == 0) 2177*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out._nnz(), 0) 2178*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.sparse_dim(), len(shape)) 2179*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.dense_dim(), 0) 2180*da0073e9SAndroid Build Coastguard Worker 2181*da0073e9SAndroid Build Coastguard Worker def test_shape(i_shapes, v_shapes, shape, nnzs): 2182*da0073e9SAndroid Build Coastguard Worker for i_dim in range(1, len(i_shapes) + 1): 2183*da0073e9SAndroid Build Coastguard Worker for v_dim in range(len(v_shapes) + 1): 2184*da0073e9SAndroid Build Coastguard Worker _test_zeros(nnzs, shape, i_shapes[:i_dim], v_shapes[:v_dim]) 2185*da0073e9SAndroid Build Coastguard Worker test_shape([2, 3, 4], [3, 4, 5, 6], [2, 3, 4], [9, 12]) 2186*da0073e9SAndroid Build Coastguard Worker test_shape([0, 3, 4], [3, 4, 5, 6], [2, 3, 4], [0]) 2187*da0073e9SAndroid Build Coastguard Worker test_shape([2, 3, 4], [0, 4, 5, 6], [2, 3, 4], [9, 12]) 2188*da0073e9SAndroid Build Coastguard Worker test_shape([2, 3, 4], [3, 4, 5, 6], [2, 3, 0], [9, 12]) 2189*da0073e9SAndroid Build Coastguard Worker test_shape([0, 3, 4], [3, 4, 5, 6], [2, 3, 0], [0]) 2190*da0073e9SAndroid Build Coastguard Worker test_shape([2, 3, 4], [0, 4, 5, 6], [2, 3, 0], [9, 12]) 2191*da0073e9SAndroid Build Coastguard Worker 2192*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 2193*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 2194*da0073e9SAndroid Build Coastguard Worker def test_zeros_like(self, device, dtype, coalesced): 2195*da0073e9SAndroid Build Coastguard Worker def _test_zeros_like(nnzs, template_shape_i, template_shape_v=None): 2196*da0073e9SAndroid Build Coastguard Worker template_shape_v = 
template_shape_v or [] 2197*da0073e9SAndroid Build Coastguard Worker template_shape = template_shape_i + template_shape_v 2198*da0073e9SAndroid Build Coastguard Worker for nnz in nnzs: 2199*da0073e9SAndroid Build Coastguard Worker t, _, _ = self._gen_sparse(len(template_shape_i), nnz, template_shape, dtype, device, coalesced) 2200*da0073e9SAndroid Build Coastguard Worker res = torch.zeros_like(t) 2201*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tuple(res.size()), tuple(template_shape)) 2202*da0073e9SAndroid Build Coastguard Worker self.assertTrue(res._indices().numel() == res._values().numel() == 0) 2203*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res._nnz(), 0) 2204*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.sparse_dim(), len(template_shape_i)) 2205*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.dense_dim(), len(template_shape_v)) 2206*da0073e9SAndroid Build Coastguard Worker 2207*da0073e9SAndroid Build Coastguard Worker def test_shape(i_shapes, v_shapes, nnzs): 2208*da0073e9SAndroid Build Coastguard Worker for i_dim in range(1, len(i_shapes) + 1): 2209*da0073e9SAndroid Build Coastguard Worker for v_dim in range(len(v_shapes) + 1): 2210*da0073e9SAndroid Build Coastguard Worker _test_zeros_like(nnzs, i_shapes[:i_dim], v_shapes[:v_dim]) 2211*da0073e9SAndroid Build Coastguard Worker test_shape([2, 3, 4], [3, 4, 5, 6], [9, 12]) 2212*da0073e9SAndroid Build Coastguard Worker test_shape([0, 3, 4], [3, 4, 5, 6], [0]) 2213*da0073e9SAndroid Build Coastguard Worker test_shape([2, 3, 4], [0, 4, 5, 6], [9, 12]) 2214*da0073e9SAndroid Build Coastguard Worker test_shape([2, 3, 4], [3, 4, 5, 6], [9, 12]) 2215*da0073e9SAndroid Build Coastguard Worker test_shape([0, 3, 4], [3, 4, 5, 6], [0]) 2216*da0073e9SAndroid Build Coastguard Worker test_shape([2, 3, 4], [0, 4, 5, 6], [9, 12]) 2217*da0073e9SAndroid Build Coastguard Worker 2218*da0073e9SAndroid Build Coastguard Worker sparse_tensor, _, _ = self._gen_sparse(len([2, 3]), 
9, [2, 3] + [5, 6], dtype, device, coalesced) 2219*da0073e9SAndroid Build Coastguard Worker data = (sparse_tensor, sparse_tensor, sparse_tensor, sparse_tensor.unsqueeze(0)) 2220*da0073e9SAndroid Build Coastguard Worker mem_formats = [torch.channels_last, torch.contiguous_format, torch.preserve_format, torch.channels_last_3d] 2221*da0073e9SAndroid Build Coastguard Worker for x, mem_format in zip(data, mem_formats): 2222*da0073e9SAndroid Build Coastguard Worker 2223*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "memory format option is only supported by strided tensors"): 2224*da0073e9SAndroid Build Coastguard Worker result = torch.zeros_like(x, memory_format=mem_format) 2225*da0073e9SAndroid Build Coastguard Worker 2226*da0073e9SAndroid Build Coastguard Worker result = torch.zeros_like(x, layout=torch.strided, memory_format=mem_format) 2227*da0073e9SAndroid Build Coastguard Worker self.assertTrue(result.layout == torch.strided) 2228*da0073e9SAndroid Build Coastguard Worker 2229*da0073e9SAndroid Build Coastguard Worker dense_tensor = sparse_tensor.to_dense() 2230*da0073e9SAndroid Build Coastguard Worker result = torch.zeros_like(dense_tensor, layout=torch.sparse_coo) 2231*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dense_tensor.shape, result.shape) 2232*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.layout, torch.sparse_coo) 2233*da0073e9SAndroid Build Coastguard Worker 2234*da0073e9SAndroid Build Coastguard Worker sparse_zeros = torch.sparse_coo_tensor(dense_tensor.shape) 2235*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result._indices().shape, sparse_zeros._indices().shape) 2236*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result._values().shape, sparse_zeros._values().shape) 2237*da0073e9SAndroid Build Coastguard Worker 2238*da0073e9SAndroid Build Coastguard Worker def _assert_sparse_invars(self, t): 2239*da0073e9SAndroid Build Coastguard Worker # SparseTensor has the 
following invariants: 2240*da0073e9SAndroid Build Coastguard Worker # - sparse_dim + dense_dim = len(SparseTensor.shape) 2241*da0073e9SAndroid Build Coastguard Worker # - SparseTensor._indices().shape = (sparse_dim, nnz) 2242*da0073e9SAndroid Build Coastguard Worker # - SparseTensor._values().shape = (nnz, SparseTensor.shape[sparse_dim:]) 2243*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.sparse_dim() + t.dense_dim(), len(t.shape)) 2244*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tuple(t._indices().shape), (t.sparse_dim(), t._nnz())) 2245*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tuple(t._values().shape), (t._nnz(), ) + t.shape[t.sparse_dim():]) 2246*da0073e9SAndroid Build Coastguard Worker 2247*da0073e9SAndroid Build Coastguard Worker def _test_empty_like(self, sparse_tensor, dtype, device, coalesced): 2248*da0073e9SAndroid Build Coastguard Worker 2249*da0073e9SAndroid Build Coastguard Worker result = torch.empty_like(sparse_tensor) 2250*da0073e9SAndroid Build Coastguard Worker self.assertTrue(result.is_sparse) 2251*da0073e9SAndroid Build Coastguard Worker self._assert_sparse_invars(result) 2252*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, sparse_tensor.shape) 2253*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.dtype, sparse_tensor.dtype) 2254*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.device, sparse_tensor.device) 2255*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.sparse_dim(), sparse_tensor.sparse_dim()) 2256*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.dense_dim(), sparse_tensor.dense_dim()) 2257*da0073e9SAndroid Build Coastguard Worker 2258*da0073e9SAndroid Build Coastguard Worker sparse_tensor, _, _ = self._gen_sparse(len([2, 3]), 9, [2, 3] + [5, 6], dtype, device, coalesced) 2259*da0073e9SAndroid Build Coastguard Worker data = (sparse_tensor, sparse_tensor, sparse_tensor, sparse_tensor.unsqueeze(0)) 
2260*da0073e9SAndroid Build Coastguard Worker mem_formats = [torch.channels_last, torch.contiguous_format, torch.preserve_format, torch.channels_last_3d] 2261*da0073e9SAndroid Build Coastguard Worker for x, mem_format in zip(data, mem_formats): 2262*da0073e9SAndroid Build Coastguard Worker 2263*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "memory format option is only supported by strided tensors"): 2264*da0073e9SAndroid Build Coastguard Worker result = torch.empty_like(x, memory_format=mem_format) 2265*da0073e9SAndroid Build Coastguard Worker 2266*da0073e9SAndroid Build Coastguard Worker result = torch.empty_like(x, layout=torch.strided, memory_format=mem_format) 2267*da0073e9SAndroid Build Coastguard Worker self.assertTrue(result.layout == torch.strided) 2268*da0073e9SAndroid Build Coastguard Worker 2269*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2270*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"Could not run 'aten::empty_strided' with arguments from the 'Sparse(CPU|CUDA)' backend" 2271*da0073e9SAndroid Build Coastguard Worker ): 2272*da0073e9SAndroid Build Coastguard Worker dense_tensor = sparse_tensor.to_dense() 2273*da0073e9SAndroid Build Coastguard Worker result = torch.empty_like(dense_tensor, layout=torch.sparse_coo) 2274*da0073e9SAndroid Build Coastguard Worker 2275*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 2276*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 2277*da0073e9SAndroid Build Coastguard Worker def test_empty_like(self, device, dtype, coalesced): 2278*da0073e9SAndroid Build Coastguard Worker # tests https://github.com/pytorch/pytorch/issues/43699 2279*da0073e9SAndroid Build Coastguard Worker 2280*da0073e9SAndroid Build Coastguard Worker if coalesced: 2281*da0073e9SAndroid Build Coastguard Worker input_coalesced = torch.sparse_coo_tensor( 2282*da0073e9SAndroid Build Coastguard Worker indices=torch.tensor([[0, 1, 2]]), 
2283*da0073e9SAndroid Build Coastguard Worker values=torch.tensor([3.0, -4.0, 5.0]), 2284*da0073e9SAndroid Build Coastguard Worker size=[3, ], 2285*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 2286*da0073e9SAndroid Build Coastguard Worker device=device 2287*da0073e9SAndroid Build Coastguard Worker ).coalesce() 2288*da0073e9SAndroid Build Coastguard Worker self._test_empty_like(input_coalesced, dtype, device, coalesced) 2289*da0073e9SAndroid Build Coastguard Worker 2290*da0073e9SAndroid Build Coastguard Worker # hybrid sparse input 2291*da0073e9SAndroid Build Coastguard Worker input_coalesced = torch.sparse_coo_tensor( 2292*da0073e9SAndroid Build Coastguard Worker indices=torch.tensor([[1, 3], [2, 4]]), 2293*da0073e9SAndroid Build Coastguard Worker values=torch.tensor([[-1.0, 3.0], [-5.0, 7.0]]), 2294*da0073e9SAndroid Build Coastguard Worker size=[4, 5, 2], 2295*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 2296*da0073e9SAndroid Build Coastguard Worker device=device 2297*da0073e9SAndroid Build Coastguard Worker ).coalesce() 2298*da0073e9SAndroid Build Coastguard Worker self._test_empty_like(input_coalesced, dtype, device, coalesced) 2299*da0073e9SAndroid Build Coastguard Worker 2300*da0073e9SAndroid Build Coastguard Worker if not coalesced: 2301*da0073e9SAndroid Build Coastguard Worker # test uncoalesced input 2302*da0073e9SAndroid Build Coastguard Worker input_uncoalesced = torch.sparse_coo_tensor( 2303*da0073e9SAndroid Build Coastguard Worker indices=torch.tensor([[0], [1], [2], [0], [1], [2]]).transpose(1, 0), 2304*da0073e9SAndroid Build Coastguard Worker values=torch.tensor([2.0, -3.0, -4.0, 1.0, -1.0, 1.5]), 2305*da0073e9SAndroid Build Coastguard Worker size=[3, ], 2306*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 2307*da0073e9SAndroid Build Coastguard Worker device=device 2308*da0073e9SAndroid Build Coastguard Worker ) 2309*da0073e9SAndroid Build Coastguard Worker self._test_empty_like(input_uncoalesced, dtype, device, coalesced) 
2310*da0073e9SAndroid Build Coastguard Worker 2311*da0073e9SAndroid Build Coastguard Worker # test on empty sparse tensor 2312*da0073e9SAndroid Build Coastguard Worker input_uncoalesced = torch.sparse_coo_tensor( 2313*da0073e9SAndroid Build Coastguard Worker indices=torch.zeros([2, 0]), 2314*da0073e9SAndroid Build Coastguard Worker values=torch.zeros([0, 5, 5, 5, 5, 5, 5, 0]), 2315*da0073e9SAndroid Build Coastguard Worker size=[0, 0, 5, 5, 5, 5, 5, 5, 0], 2316*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 2317*da0073e9SAndroid Build Coastguard Worker device=device 2318*da0073e9SAndroid Build Coastguard Worker ) 2319*da0073e9SAndroid Build Coastguard Worker self._test_empty_like(input_uncoalesced, dtype, device, coalesced) 2320*da0073e9SAndroid Build Coastguard Worker 2321*da0073e9SAndroid Build Coastguard Worker def _test_narrow(self, input, narrow_args): 2322*da0073e9SAndroid Build Coastguard Worker expected = input.to_dense().narrow(*narrow_args) 2323*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, input.narrow_copy(*narrow_args).to_dense()) 2324*da0073e9SAndroid Build Coastguard Worker 2325*da0073e9SAndroid Build Coastguard Worker def _all_narrow_combs(self, shape): 2326*da0073e9SAndroid Build Coastguard Worker for dim, dim_sz in enumerate(shape): 2327*da0073e9SAndroid Build Coastguard Worker for start in range(dim_sz): 2328*da0073e9SAndroid Build Coastguard Worker for length in range(dim_sz - start): 2329*da0073e9SAndroid Build Coastguard Worker yield [dim, start, length] 2330*da0073e9SAndroid Build Coastguard Worker 2331*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 2332*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 2333*da0073e9SAndroid Build Coastguard Worker def test_narrow(self, device, dtype, coalesced): 2334*da0073e9SAndroid Build Coastguard Worker shape = [3, 3, 4, 2] 2335*da0073e9SAndroid Build Coastguard Worker input, _, _ = self._gen_sparse(4, 19, shape, dtype, device, coalesced) 
2336*da0073e9SAndroid Build Coastguard Worker for narrow_args in self._all_narrow_combs(shape): 2337*da0073e9SAndroid Build Coastguard Worker self._test_narrow(input, narrow_args) 2338*da0073e9SAndroid Build Coastguard Worker 2339*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: input.narrow_copy(-1, 0, 3)) # dim < 0 2340*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: input.narrow_copy(10, 0, 3)) # dim > input.dim() 2341*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: input.narrow_copy(0, shape[0] + 1, 3)) # start > size of dim 2342*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: input.narrow_copy(0, 2, shape[0])) # start+length > size of dim 2343*da0073e9SAndroid Build Coastguard Worker 2344*da0073e9SAndroid Build Coastguard Worker with_dense, _, _ = self._gen_sparse(2, 7, shape, dtype, device, coalesced) 2345*da0073e9SAndroid Build Coastguard Worker for narrow_args in self._all_narrow_combs(shape): 2346*da0073e9SAndroid Build Coastguard Worker self._test_narrow(with_dense, narrow_args) 2347*da0073e9SAndroid Build Coastguard Worker 2348*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: with_dense.narrow_copy(10, 0, 3)) # dim > sparseDim + denseDim 2349*da0073e9SAndroid Build Coastguard Worker 2350*da0073e9SAndroid Build Coastguard Worker def _test_log1p_tensor(self, sparse_tensor, coalesced): 2351*da0073e9SAndroid Build Coastguard Worker def is_integral(dtype): 2352*da0073e9SAndroid Build Coastguard Worker return dtype in integral_types() 2353*da0073e9SAndroid Build Coastguard Worker 2354*da0073e9SAndroid Build Coastguard Worker dense_tensor = sparse_tensor.to_dense() 2355*da0073e9SAndroid Build Coastguard Worker expected_output = dense_tensor.log1p() 2356*da0073e9SAndroid Build Coastguard Worker is_integral_dtype = is_integral(sparse_tensor.dtype) 2357*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(expected_output, sparse_tensor.log1p().to_dense()) 2358*da0073e9SAndroid Build Coastguard Worker if is_integral_dtype: 2359*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "result type .* can't be cast to"): 2360*da0073e9SAndroid Build Coastguard Worker sparse_tensor.coalesce().log1p_() 2361*da0073e9SAndroid Build Coastguard Worker else: 2362*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_output, sparse_tensor.coalesce().log1p_().to_dense()) 2363*da0073e9SAndroid Build Coastguard Worker 2364*da0073e9SAndroid Build Coastguard Worker if not coalesced: 2365*da0073e9SAndroid Build Coastguard Worker # test in-place op on uncoalesced input 2366*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "log1p_ requires coalesced input"): 2367*da0073e9SAndroid Build Coastguard Worker sparse_tensor.log1p_() 2368*da0073e9SAndroid Build Coastguard Worker 2369*da0073e9SAndroid Build Coastguard Worker if is_integral_dtype: 2370*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "only Tensors of floating point dtype can require gradients"): 2371*da0073e9SAndroid Build Coastguard Worker sparse_tensor.requires_grad_() 2372*da0073e9SAndroid Build Coastguard Worker 2373*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 2374*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types()) 2375*da0073e9SAndroid Build Coastguard Worker def test_log1p(self, device, dtype, coalesced): 2376*da0073e9SAndroid Build Coastguard Worker if coalesced: 2377*da0073e9SAndroid Build Coastguard Worker input_coalesced = torch.sparse_coo_tensor( 2378*da0073e9SAndroid Build Coastguard Worker indices=torch.tensor([[0], [1], [2]]).transpose(1, 0), 2379*da0073e9SAndroid Build Coastguard Worker values=torch.tensor([3.0, 4.0, 5.0]), 2380*da0073e9SAndroid Build Coastguard Worker size=[3, ], 2381*da0073e9SAndroid Build Coastguard Worker device=device, 2382*da0073e9SAndroid Build 
Coastguard Worker dtype=dtype 2383*da0073e9SAndroid Build Coastguard Worker ).coalesce() 2384*da0073e9SAndroid Build Coastguard Worker self._test_log1p_tensor(input_coalesced, coalesced) 2385*da0073e9SAndroid Build Coastguard Worker 2386*da0073e9SAndroid Build Coastguard Worker # hybrid sparse input 2387*da0073e9SAndroid Build Coastguard Worker input_coalesced = torch.sparse_coo_tensor( 2388*da0073e9SAndroid Build Coastguard Worker indices=torch.tensor([[1, 3], [2, 4]]), 2389*da0073e9SAndroid Build Coastguard Worker values=torch.tensor([[1.0, 3.0], [5.0, 7.0]]), 2390*da0073e9SAndroid Build Coastguard Worker size=[4, 5, 2], 2391*da0073e9SAndroid Build Coastguard Worker device=device, 2392*da0073e9SAndroid Build Coastguard Worker dtype=dtype 2393*da0073e9SAndroid Build Coastguard Worker ).coalesce() 2394*da0073e9SAndroid Build Coastguard Worker self._test_log1p_tensor(input_coalesced, coalesced) 2395*da0073e9SAndroid Build Coastguard Worker 2396*da0073e9SAndroid Build Coastguard Worker if not coalesced: 2397*da0073e9SAndroid Build Coastguard Worker # test uncoalesced input 2398*da0073e9SAndroid Build Coastguard Worker input_uncoalesced = torch.sparse_coo_tensor( 2399*da0073e9SAndroid Build Coastguard Worker indices=torch.tensor([[0], [1], [2], [0], [1], [2]]).transpose(1, 0), 2400*da0073e9SAndroid Build Coastguard Worker values=torch.tensor([2.0, 3.0, 4.0, 1.0, 1.0, 1.0]), 2401*da0073e9SAndroid Build Coastguard Worker size=[3, ], 2402*da0073e9SAndroid Build Coastguard Worker device=device, 2403*da0073e9SAndroid Build Coastguard Worker dtype=dtype 2404*da0073e9SAndroid Build Coastguard Worker ) 2405*da0073e9SAndroid Build Coastguard Worker self._test_log1p_tensor(input_uncoalesced, coalesced) 2406*da0073e9SAndroid Build Coastguard Worker 2407*da0073e9SAndroid Build Coastguard Worker # test on empty sparse tensor 2408*da0073e9SAndroid Build Coastguard Worker input_uncoalesced = torch.sparse_coo_tensor( 2409*da0073e9SAndroid Build Coastguard Worker 
indices=torch.zeros([2, 0]), 2410*da0073e9SAndroid Build Coastguard Worker values=torch.zeros([0, 5, 5, 5, 5, 5, 5, 0]), 2411*da0073e9SAndroid Build Coastguard Worker size=[0, 0, 5, 5, 5, 5, 5, 5, 0], 2412*da0073e9SAndroid Build Coastguard Worker device=device, 2413*da0073e9SAndroid Build Coastguard Worker dtype=dtype 2414*da0073e9SAndroid Build Coastguard Worker ) 2415*da0073e9SAndroid Build Coastguard Worker # empty tensors are coalesced at creation (nnz < 2) we must force the uncoalesced state 2416*da0073e9SAndroid Build Coastguard Worker input_uncoalesced._coalesced_(False) 2417*da0073e9SAndroid Build Coastguard Worker self._test_log1p_tensor(input_uncoalesced, coalesced) 2418*da0073e9SAndroid Build Coastguard Worker 2419*da0073e9SAndroid Build Coastguard Worker def _test_neg_negative(self, sparse_tensor): 2420*da0073e9SAndroid Build Coastguard Worker dense_tensor = sparse_tensor.to_dense() 2421*da0073e9SAndroid Build Coastguard Worker expected_output = dense_tensor.neg() 2422*da0073e9SAndroid Build Coastguard Worker 2423*da0073e9SAndroid Build Coastguard Worker ops = ( 2424*da0073e9SAndroid Build Coastguard Worker torch.neg, torch.Tensor.neg, torch.Tensor.neg_, 2425*da0073e9SAndroid Build Coastguard Worker torch.negative, torch.Tensor.negative, torch.Tensor.negative_, 2426*da0073e9SAndroid Build Coastguard Worker operator.neg 2427*da0073e9SAndroid Build Coastguard Worker ) 2428*da0073e9SAndroid Build Coastguard Worker for op in ops: 2429*da0073e9SAndroid Build Coastguard Worker sparse_tensor_copy = sparse_tensor.clone() 2430*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_output, op(sparse_tensor_copy).to_dense()) 2431*da0073e9SAndroid Build Coastguard Worker 2432*da0073e9SAndroid Build Coastguard Worker if op in (torch.neg, torch.negative): 2433*da0073e9SAndroid Build Coastguard Worker sparse_tensor_out = torch.zeros_like(sparse_tensor) 2434*da0073e9SAndroid Build Coastguard Worker op(sparse_tensor, out=sparse_tensor_out) 
2435*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_output, sparse_tensor_out.to_dense()) 2436*da0073e9SAndroid Build Coastguard Worker 2437*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 2438*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 2439*da0073e9SAndroid Build Coastguard Worker def test_neg_negative(self, device, dtype, coalesced): 2440*da0073e9SAndroid Build Coastguard Worker 2441*da0073e9SAndroid Build Coastguard Worker if coalesced: 2442*da0073e9SAndroid Build Coastguard Worker input_coalesced = torch.sparse_coo_tensor( 2443*da0073e9SAndroid Build Coastguard Worker indices=torch.tensor([[0, 1, 2]]), 2444*da0073e9SAndroid Build Coastguard Worker values=torch.tensor([3.0, -4.0, 5.0]), 2445*da0073e9SAndroid Build Coastguard Worker size=[3, ], 2446*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 2447*da0073e9SAndroid Build Coastguard Worker device=device 2448*da0073e9SAndroid Build Coastguard Worker ).coalesce() 2449*da0073e9SAndroid Build Coastguard Worker self._test_neg_negative(input_coalesced) 2450*da0073e9SAndroid Build Coastguard Worker 2451*da0073e9SAndroid Build Coastguard Worker # hybrid sparse input 2452*da0073e9SAndroid Build Coastguard Worker input_coalesced = torch.sparse_coo_tensor( 2453*da0073e9SAndroid Build Coastguard Worker indices=torch.tensor([[1, 3], [2, 4]]), 2454*da0073e9SAndroid Build Coastguard Worker values=torch.tensor([[-1.0, 3.0], [-5.0, 7.0]]), 2455*da0073e9SAndroid Build Coastguard Worker size=[4, 5, 2], 2456*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 2457*da0073e9SAndroid Build Coastguard Worker device=device 2458*da0073e9SAndroid Build Coastguard Worker ).coalesce() 2459*da0073e9SAndroid Build Coastguard Worker self._test_neg_negative(input_coalesced) 2460*da0073e9SAndroid Build Coastguard Worker 2461*da0073e9SAndroid Build Coastguard Worker if not coalesced: 2462*da0073e9SAndroid Build Coastguard Worker # test uncoalesced input 2463*da0073e9SAndroid 
Build Coastguard Worker input_uncoalesced = torch.sparse_coo_tensor( 2464*da0073e9SAndroid Build Coastguard Worker indices=torch.tensor([[0], [1], [2], [0], [1], [2]]).transpose(1, 0), 2465*da0073e9SAndroid Build Coastguard Worker values=torch.tensor([2.0, -3.0, -4.0, 1.0, -1.0, 1.5]), 2466*da0073e9SAndroid Build Coastguard Worker size=[3, ], 2467*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 2468*da0073e9SAndroid Build Coastguard Worker device=device 2469*da0073e9SAndroid Build Coastguard Worker ) 2470*da0073e9SAndroid Build Coastguard Worker self._test_neg_negative(input_uncoalesced) 2471*da0073e9SAndroid Build Coastguard Worker 2472*da0073e9SAndroid Build Coastguard Worker # test on empty sparse tensor 2473*da0073e9SAndroid Build Coastguard Worker input_uncoalesced = torch.sparse_coo_tensor( 2474*da0073e9SAndroid Build Coastguard Worker indices=torch.zeros([2, 0]), 2475*da0073e9SAndroid Build Coastguard Worker values=torch.zeros([0, 5, 5, 5, 5, 5, 5, 0]), 2476*da0073e9SAndroid Build Coastguard Worker size=[0, 0, 5, 5, 5, 5, 5, 5, 0], 2477*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 2478*da0073e9SAndroid Build Coastguard Worker device=device 2479*da0073e9SAndroid Build Coastguard Worker ) 2480*da0073e9SAndroid Build Coastguard Worker self._test_neg_negative(input_uncoalesced) 2481*da0073e9SAndroid Build Coastguard Worker 2482*da0073e9SAndroid Build Coastguard Worker def _test_asin_arcsin(self, sparse_tensor, coalesced): 2483*da0073e9SAndroid Build Coastguard Worker def is_integral(dtype): 2484*da0073e9SAndroid Build Coastguard Worker return dtype in integral_types() 2485*da0073e9SAndroid Build Coastguard Worker is_integral_dtype = is_integral(sparse_tensor.dtype) 2486*da0073e9SAndroid Build Coastguard Worker 2487*da0073e9SAndroid Build Coastguard Worker dense_tensor = sparse_tensor.to_dense() 2488*da0073e9SAndroid Build Coastguard Worker expected_output = dense_tensor.asin() 2489*da0073e9SAndroid Build Coastguard Worker 2490*da0073e9SAndroid 
Build Coastguard Worker ops = ( 2491*da0073e9SAndroid Build Coastguard Worker torch.asin, torch.Tensor.asin, 2492*da0073e9SAndroid Build Coastguard Worker torch.arcsin, torch.Tensor.arcsin, 2493*da0073e9SAndroid Build Coastguard Worker ) 2494*da0073e9SAndroid Build Coastguard Worker for op in ops: 2495*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_output, op(sparse_tensor).to_dense()) 2496*da0073e9SAndroid Build Coastguard Worker if op in (torch.asin, torch.arcsin): 2497*da0073e9SAndroid Build Coastguard Worker sparse_tensor_out = torch.zeros_like(sparse_tensor) 2498*da0073e9SAndroid Build Coastguard Worker if not is_integral_dtype: 2499*da0073e9SAndroid Build Coastguard Worker op(sparse_tensor, out=sparse_tensor_out) 2500*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_output, sparse_tensor_out.to_dense()) 2501*da0073e9SAndroid Build Coastguard Worker else: 2502*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "result type .* can't be cast to"): 2503*da0073e9SAndroid Build Coastguard Worker op(sparse_tensor, out=sparse_tensor_out) 2504*da0073e9SAndroid Build Coastguard Worker 2505*da0073e9SAndroid Build Coastguard Worker for op in (torch.Tensor.asin_, torch.Tensor.arcsin_): 2506*da0073e9SAndroid Build Coastguard Worker if is_integral_dtype: 2507*da0073e9SAndroid Build Coastguard Worker # test coalesce on integral dtype tensor 2508*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "result type .* can't be cast to"): 2509*da0073e9SAndroid Build Coastguard Worker op(sparse_tensor.clone().coalesce()).to_dense() 2510*da0073e9SAndroid Build Coastguard Worker else: 2511*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_output, op(sparse_tensor.clone().coalesce()).to_dense()) 2512*da0073e9SAndroid Build Coastguard Worker 2513*da0073e9SAndroid Build Coastguard Worker if not coalesced: 2514*da0073e9SAndroid Build Coastguard Worker # test in-place op on 
uncoalesced input 2515*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "asin_ requires coalesced input"): 2516*da0073e9SAndroid Build Coastguard Worker op(sparse_tensor) 2517*da0073e9SAndroid Build Coastguard Worker 2518*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 2519*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types()) 2520*da0073e9SAndroid Build Coastguard Worker def test_asin_arcsin(self, device, dtype, coalesced): 2521*da0073e9SAndroid Build Coastguard Worker if coalesced: 2522*da0073e9SAndroid Build Coastguard Worker input_coalesced = torch.sparse_coo_tensor( 2523*da0073e9SAndroid Build Coastguard Worker indices=torch.tensor([[0, 1, 2, 3]]), 2524*da0073e9SAndroid Build Coastguard Worker values=torch.tensor([0.5, -0.5, 0.7, -0.7]), 2525*da0073e9SAndroid Build Coastguard Worker size=[4, ], 2526*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 2527*da0073e9SAndroid Build Coastguard Worker device=device 2528*da0073e9SAndroid Build Coastguard Worker ).coalesce() 2529*da0073e9SAndroid Build Coastguard Worker self._test_asin_arcsin(input_coalesced, coalesced) 2530*da0073e9SAndroid Build Coastguard Worker 2531*da0073e9SAndroid Build Coastguard Worker # hybrid sparse input 2532*da0073e9SAndroid Build Coastguard Worker input_coalesced = torch.sparse_coo_tensor( 2533*da0073e9SAndroid Build Coastguard Worker indices=torch.tensor([[1, 3], [2, 4]]), 2534*da0073e9SAndroid Build Coastguard Worker values=torch.tensor([[-0.1, 0.24], [-0.44, 0.1]]), 2535*da0073e9SAndroid Build Coastguard Worker size=[4, 5, 2], 2536*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 2537*da0073e9SAndroid Build Coastguard Worker device=device 2538*da0073e9SAndroid Build Coastguard Worker ).coalesce() 2539*da0073e9SAndroid Build Coastguard Worker self._test_asin_arcsin(input_coalesced, coalesced) 2540*da0073e9SAndroid Build Coastguard Worker 2541*da0073e9SAndroid Build Coastguard Worker if not coalesced: 2542*da0073e9SAndroid Build 
Coastguard Worker # test uncoalesced input 2543*da0073e9SAndroid Build Coastguard Worker input_uncoalesced = torch.sparse_coo_tensor( 2544*da0073e9SAndroid Build Coastguard Worker indices=torch.tensor([[0], [1], [2], [0], [1], [2]]).transpose(1, 0), 2545*da0073e9SAndroid Build Coastguard Worker values=torch.tensor([0.3, -0.3, -0.4, 0.3, -0.5, 0.15]), 2546*da0073e9SAndroid Build Coastguard Worker size=[3, ], 2547*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 2548*da0073e9SAndroid Build Coastguard Worker device=device 2549*da0073e9SAndroid Build Coastguard Worker ) 2550*da0073e9SAndroid Build Coastguard Worker self._test_asin_arcsin(input_uncoalesced, coalesced) 2551*da0073e9SAndroid Build Coastguard Worker 2552*da0073e9SAndroid Build Coastguard Worker # test on empty sparse tensor 2553*da0073e9SAndroid Build Coastguard Worker input_uncoalesced = torch.sparse_coo_tensor( 2554*da0073e9SAndroid Build Coastguard Worker indices=torch.zeros([2, 0]), 2555*da0073e9SAndroid Build Coastguard Worker values=torch.zeros([0, 5, 5, 5, 5, 5, 5, 0]), 2556*da0073e9SAndroid Build Coastguard Worker size=[0, 0, 5, 5, 5, 5, 5, 5, 0], 2557*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 2558*da0073e9SAndroid Build Coastguard Worker device=device 2559*da0073e9SAndroid Build Coastguard Worker ) 2560*da0073e9SAndroid Build Coastguard Worker # empty tensors are coalesced at creation (nnz < 2) we must force the uncoalesced state 2561*da0073e9SAndroid Build Coastguard Worker input_uncoalesced._coalesced_(False) 2562*da0073e9SAndroid Build Coastguard Worker self._test_asin_arcsin(input_uncoalesced, coalesced) 2563*da0073e9SAndroid Build Coastguard Worker 2564*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 2565*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 2566*da0073e9SAndroid Build Coastguard Worker def test_mv(self, device, dtype, coalesced): 2567*da0073e9SAndroid Build Coastguard Worker def test_shape(di, dj, dk, nnz): 2568*da0073e9SAndroid Build 
Coastguard Worker x, _, _ = self._gen_sparse(2, nnz, [di, dj], dtype, device, coalesced) 2569*da0073e9SAndroid Build Coastguard Worker t = torch.randn(dk, dtype=dtype, device=device) 2570*da0073e9SAndroid Build Coastguard Worker 2571*da0073e9SAndroid Build Coastguard Worker res = x.matmul(t) 2572*da0073e9SAndroid Build Coastguard Worker expected = self.safeToDense(x).matmul(t) 2573*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected) 2574*da0073e9SAndroid Build Coastguard Worker 2575*da0073e9SAndroid Build Coastguard Worker test_shape(10, 100, 100, 20) 2576*da0073e9SAndroid Build Coastguard Worker test_shape(100, 1000, 1000, 20) 2577*da0073e9SAndroid Build Coastguard Worker test_shape(64, 10000, 10000, 20) 2578*da0073e9SAndroid Build Coastguard Worker test_shape(0, 100, 100, 0) 2579*da0073e9SAndroid Build Coastguard Worker test_shape(10, 0, 0, 0) 2580*da0073e9SAndroid Build Coastguard Worker test_shape(10, 100, 100, 0) 2581*da0073e9SAndroid Build Coastguard Worker test_shape(10, 100, 100, 20) 2582*da0073e9SAndroid Build Coastguard Worker 2583*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"mv: expected self\.size\(-1\) == vec\.size\(-1\)"): 2584*da0073e9SAndroid Build Coastguard Worker test_shape(10, 100, 10, 20) 2585*da0073e9SAndroid Build Coastguard Worker 2586*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "mv: two tensor dim should be 2 and 1"): 2587*da0073e9SAndroid Build Coastguard Worker x, _, _ = self._gen_sparse(2, 20, [10, 100], dtype, device, coalesced) 2588*da0073e9SAndroid Build Coastguard Worker y, _, _ = self._gen_sparse(2, 20, [10, 100], dtype, device, coalesced) 2589*da0073e9SAndroid Build Coastguard Worker res = x.mv(y) 2590*da0073e9SAndroid Build Coastguard Worker 2591*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 2592*da0073e9SAndroid Build Coastguard Worker def test_sparse_add_coalesce(self, device, dtype): 
2593*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([[1, 2, 1]], device=device) 2594*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([3, 4, 5], dtype=dtype, device=device) 2595*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3])) 2596*da0073e9SAndroid Build Coastguard Worker y = self.sparse_tensor(i, v, torch.Size([3])) 2597*da0073e9SAndroid Build Coastguard Worker z = x + y 2598*da0073e9SAndroid Build Coastguard Worker 2599*da0073e9SAndroid Build Coastguard Worker self.assertFalse(z._indices().numel() != 2 and z.is_coalesced()) 2600*da0073e9SAndroid Build Coastguard Worker 2601*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([[1, 2, 1]], device=device) 2602*da0073e9SAndroid Build Coastguard Worker v = torch.empty([3, 0], dtype=dtype, device=device) 2603*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3, 0])) 2604*da0073e9SAndroid Build Coastguard Worker y = self.sparse_tensor(i, v, torch.Size([3, 0])) 2605*da0073e9SAndroid Build Coastguard Worker z = x + y 2606*da0073e9SAndroid Build Coastguard Worker 2607*da0073e9SAndroid Build Coastguard Worker self.assertFalse(z._indices().numel() != 2 and z.is_coalesced()) 2608*da0073e9SAndroid Build Coastguard Worker 2609*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2610*da0073e9SAndroid Build Coastguard Worker def test_storage_not_null(self, device): 2611*da0073e9SAndroid Build Coastguard Worker x = torch.sparse_coo_tensor((2,), dtype=torch.float32, device=device) 2612*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(x.get_device(), -1) 2613*da0073e9SAndroid Build Coastguard Worker 2614*da0073e9SAndroid Build Coastguard Worker x = torch.sparse_coo_tensor((2, 0), dtype=torch.float32, device=device) 2615*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(x.get_device(), -1) 2616*da0073e9SAndroid Build Coastguard Worker 2617*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 
2618*da0073e9SAndroid Build Coastguard Worker @deviceCountAtLeast(2) 2619*da0073e9SAndroid Build Coastguard Worker def test_same_gpu(self, devices): 2620*da0073e9SAndroid Build Coastguard Worker def check_device(x, device_id): 2621*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.get_device(), device_id) 2622*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x._values().get_device(), device_id) 2623*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x._indices().get_device(), device_id) 2624*da0073e9SAndroid Build Coastguard Worker 2625*da0073e9SAndroid Build Coastguard Worker dev1, dev2 = devices[0], devices[1] 2626*da0073e9SAndroid Build Coastguard Worker 2627*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([[2]], device=dev2) 2628*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([5], device=dev2) 2629*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3]), device=1) 2630*da0073e9SAndroid Build Coastguard Worker check_device(x, 1) 2631*da0073e9SAndroid Build Coastguard Worker 2632*da0073e9SAndroid Build Coastguard Worker i = self.index_tensor([[2]], device=dev2) 2633*da0073e9SAndroid Build Coastguard Worker v = torch.empty(1, 0, device=dev2) 2634*da0073e9SAndroid Build Coastguard Worker x = self.sparse_tensor(i, v, torch.Size([3, 0]), device=1) 2635*da0073e9SAndroid Build Coastguard Worker check_device(x, 1) 2636*da0073e9SAndroid Build Coastguard Worker 2637*da0073e9SAndroid Build Coastguard Worker x = self.sparse_empty(3, device=1) 2638*da0073e9SAndroid Build Coastguard Worker check_device(x, 1) 2639*da0073e9SAndroid Build Coastguard Worker 2640*da0073e9SAndroid Build Coastguard Worker x = self.sparse_empty(3, 0, device=1) 2641*da0073e9SAndroid Build Coastguard Worker check_device(x, 1) 2642*da0073e9SAndroid Build Coastguard Worker 2643*da0073e9SAndroid Build Coastguard Worker def _test_new_device(self, size, device=torch.cuda): 2644*da0073e9SAndroid Build Coastguard Worker with 
torch.cuda.device(device): 2645*da0073e9SAndroid Build Coastguard Worker x = torch.sparse_coo_tensor(size, device='cuda', dtype=torch.float64) 2646*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.get_device(), device) 2647*da0073e9SAndroid Build Coastguard Worker x1 = x.new() 2648*da0073e9SAndroid Build Coastguard Worker x2 = x.new(2, 3) 2649*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x1.get_device(), device) 2650*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x2.get_device(), device) 2651*da0073e9SAndroid Build Coastguard Worker 2652*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2653*da0073e9SAndroid Build Coastguard Worker def test_new_device_single_gpu(self): 2654*da0073e9SAndroid Build Coastguard Worker self._test_new_device((), 0) 2655*da0073e9SAndroid Build Coastguard Worker self._test_new_device((30, 20), 0) 2656*da0073e9SAndroid Build Coastguard Worker self._test_new_device((30, 20, 10), 0) 2657*da0073e9SAndroid Build Coastguard Worker self._test_new_device((30, 20, 10, 0), 0) 2658*da0073e9SAndroid Build Coastguard Worker 2659*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2660*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") 2661*da0073e9SAndroid Build Coastguard Worker def test_new_device_multi_gpu(self): 2662*da0073e9SAndroid Build Coastguard Worker self._test_new_device((), 1) 2663*da0073e9SAndroid Build Coastguard Worker self._test_new_device((30, 20), 1) 2664*da0073e9SAndroid Build Coastguard Worker self._test_new_device((30, 20, 10), 1) 2665*da0073e9SAndroid Build Coastguard Worker self._test_new_device((30, 20, 10, 0), 1) 2666*da0073e9SAndroid Build Coastguard Worker 2667*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 2668*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 2669*da0073e9SAndroid Build Coastguard Worker def test_new(self, device, dtype, coalesced): 2670*da0073e9SAndroid Build Coastguard Worker def 
test_shape(sparse_dims, nnz, with_size): 2671*da0073e9SAndroid Build Coastguard Worker x, indices, values = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced) 2672*da0073e9SAndroid Build Coastguard Worker if not x.is_cuda: 2673*da0073e9SAndroid Build Coastguard Worker # CUDA sparse tensors currently requires the size to be 2674*da0073e9SAndroid Build Coastguard Worker # specified if nDimV > 0 2675*da0073e9SAndroid Build Coastguard Worker out = x.new(indices, values).coalesce() 2676*da0073e9SAndroid Build Coastguard Worker x_c = x.coalesce() 2677*da0073e9SAndroid Build Coastguard Worker self.assertEqual((out.indices(), out.values()), (x_c.indices(), x_c.values())) 2678*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.new(indices, values, x.size()), x) 2679*da0073e9SAndroid Build Coastguard Worker 2680*da0073e9SAndroid Build Coastguard Worker test_shape(3, 10, 100) 2681*da0073e9SAndroid Build Coastguard Worker test_shape(3, 0, [100, 100, 0]) 2682*da0073e9SAndroid Build Coastguard Worker 2683*da0073e9SAndroid Build Coastguard Worker @onlyCPU # not really, but we only really want to run this once 2684*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float64, torch.float32, torch.float16, torch.cfloat, torch.cdouble) 2685*da0073e9SAndroid Build Coastguard Worker def test_factory(self, device, dtype): 2686*da0073e9SAndroid Build Coastguard Worker for test_empty_tensor in [True, False]: 2687*da0073e9SAndroid Build Coastguard Worker if test_empty_tensor: 2688*da0073e9SAndroid Build Coastguard Worker default_size = torch.Size([1, 3, 0]) 2689*da0073e9SAndroid Build Coastguard Worker size = torch.Size([3, 3, 0]) 2690*da0073e9SAndroid Build Coastguard Worker else: 2691*da0073e9SAndroid Build Coastguard Worker default_size = torch.Size([1, 3]) 2692*da0073e9SAndroid Build Coastguard Worker size = torch.Size([3, 3]) 2693*da0073e9SAndroid Build Coastguard Worker for include_size in [True, False]: 2694*da0073e9SAndroid Build Coastguard Worker 
for use_tensor_idx in [True, False]: 2695*da0073e9SAndroid Build Coastguard Worker for use_tensor_val in [True, False]: 2696*da0073e9SAndroid Build Coastguard Worker for use_cuda in ([False] if not torch.cuda.is_available() else [True, False]): 2697*da0073e9SAndroid Build Coastguard Worker # have to include size with cuda sparse tensors 2698*da0073e9SAndroid Build Coastguard Worker include_size = include_size or use_cuda 2699*da0073e9SAndroid Build Coastguard Worker long_dtype = torch.int64 2700*da0073e9SAndroid Build Coastguard Worker device = torch.device('cpu') if not use_cuda else \ 2701*da0073e9SAndroid Build Coastguard Worker torch.device(torch.cuda.device_count() - 1) 2702*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor(([0], [2]), dtype=long_dtype) if use_tensor_idx else ([0], [2]) 2703*da0073e9SAndroid Build Coastguard Worker if test_empty_tensor: 2704*da0073e9SAndroid Build Coastguard Worker values = torch.empty(1, 0).to(dtype) 2705*da0073e9SAndroid Build Coastguard Worker else: 2706*da0073e9SAndroid Build Coastguard Worker if use_tensor_val: 2707*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([1.], dtype=dtype) 2708*da0073e9SAndroid Build Coastguard Worker else: 2709*da0073e9SAndroid Build Coastguard Worker values = 1. 
2710*da0073e9SAndroid Build Coastguard Worker if include_size: 2711*da0073e9SAndroid Build Coastguard Worker sparse_tensor = torch.sparse_coo_tensor(indices, values, size, dtype=dtype, 2712*da0073e9SAndroid Build Coastguard Worker device=device, requires_grad=True) 2713*da0073e9SAndroid Build Coastguard Worker else: 2714*da0073e9SAndroid Build Coastguard Worker sparse_tensor = torch.sparse_coo_tensor(indices, values, dtype=dtype, 2715*da0073e9SAndroid Build Coastguard Worker device=device, requires_grad=True) 2716*da0073e9SAndroid Build Coastguard Worker self.assertEqual(indices, sparse_tensor._indices()) 2717*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values, sparse_tensor._values()) 2718*da0073e9SAndroid Build Coastguard Worker self.assertEqual(size if include_size else default_size, sparse_tensor.size()) 2719*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dtype, sparse_tensor.dtype) 2720*da0073e9SAndroid Build Coastguard Worker if use_cuda: 2721*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device, sparse_tensor._values().device) 2722*da0073e9SAndroid Build Coastguard Worker self.assertEqual(True, sparse_tensor.requires_grad) 2723*da0073e9SAndroid Build Coastguard Worker 2724*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 2725*da0073e9SAndroid Build Coastguard Worker def test_factory_size_check(self, device, dtype): 2726*da0073e9SAndroid Build Coastguard Worker indices = self.index_tensor([[1, 2], 2727*da0073e9SAndroid Build Coastguard Worker [0, 2]], device=device) 2728*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([.5, .5], dtype=dtype, device=device) 2729*da0073e9SAndroid Build Coastguard Worker sizes = torch.Size([2, 3]) 2730*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "size is inconsistent with indices"): 2731*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device) 
2732*da0073e9SAndroid Build Coastguard Worker 2733*da0073e9SAndroid Build Coastguard Worker indices.fill_(-1) 2734*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "found negative index"): 2735*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device) 2736*da0073e9SAndroid Build Coastguard Worker 2737*da0073e9SAndroid Build Coastguard Worker indices = self.index_tensor([[1, 2], 2738*da0073e9SAndroid Build Coastguard Worker [0, 2]], device=device) 2739*da0073e9SAndroid Build Coastguard Worker values = torch.empty([2, 1, 0], dtype=dtype, device=device) 2740*da0073e9SAndroid Build Coastguard Worker sizes = torch.Size([2, 3, 1, 0]) 2741*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "size is inconsistent with indices"): 2742*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device) 2743*da0073e9SAndroid Build Coastguard Worker 2744*da0073e9SAndroid Build Coastguard Worker indices = self.index_tensor([[1, 2], 2745*da0073e9SAndroid Build Coastguard Worker [0, 2]], device=device) 2746*da0073e9SAndroid Build Coastguard Worker values = torch.empty([2, 2, 2], dtype=dtype, device=device) 2747*da0073e9SAndroid Build Coastguard Worker sizes = torch.Size([0, 0, 2, 2]) 2748*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "size is inconsistent with indices"): 2749*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device) 2750*da0073e9SAndroid Build Coastguard Worker 2751*da0073e9SAndroid Build Coastguard Worker indices = self.index_tensor([[1, 2], 2752*da0073e9SAndroid Build Coastguard Worker [0, 2]], device=device) 2753*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=dtype, device=device) 2754*da0073e9SAndroid Build Coastguard Worker sizes = torch.Size([3, 3, 
2]) 2755*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "values has incorrect size"): 2756*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device) 2757*da0073e9SAndroid Build Coastguard Worker 2758*da0073e9SAndroid Build Coastguard Worker indices = self.index_tensor([[1, 2], 2759*da0073e9SAndroid Build Coastguard Worker [0, 2]], device=device) 2760*da0073e9SAndroid Build Coastguard Worker values = torch.empty([2, 1, 0], dtype=dtype, device=device) 2761*da0073e9SAndroid Build Coastguard Worker sizes = torch.Size([3, 3, 2, 0]) 2762*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "values has incorrect size"): 2763*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device) 2764*da0073e9SAndroid Build Coastguard Worker 2765*da0073e9SAndroid Build Coastguard Worker def test_factory_empty_indices(self, device): 2766*da0073e9SAndroid Build Coastguard Worker tensor = torch.sparse_coo_tensor(torch.Size([2, 0]), device=device) 2767*da0073e9SAndroid Build Coastguard Worker expected_indices = torch.empty((2, 0), dtype=torch.long, device=device) 2768*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor._indices(), expected_indices) 2769*da0073e9SAndroid Build Coastguard Worker 2770*da0073e9SAndroid Build Coastguard Worker tensor = torch.sparse_coo_tensor(torch.Size([2, 2, 0]), device=device) 2771*da0073e9SAndroid Build Coastguard Worker expected_indices = torch.empty((3, 0), dtype=torch.long, device=device) 2772*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor._indices(), expected_indices) 2773*da0073e9SAndroid Build Coastguard Worker 2774*da0073e9SAndroid Build Coastguard Worker tensor = torch.sparse_coo_tensor(torch.Size([2, 2, 0, 0]), device=device) 2775*da0073e9SAndroid Build Coastguard Worker expected_indices = torch.empty((4, 0), dtype=torch.long, 
device=device) 2776*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor._indices(), expected_indices) 2777*da0073e9SAndroid Build Coastguard Worker 2778*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 2779*da0073e9SAndroid Build Coastguard Worker def test_factory_nnz(self, device, dtype): 2780*da0073e9SAndroid Build Coastguard Worker indices = self.index_tensor([[0]], device=device) # (sparse_dim, nnz): (1, 1) 2781*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([[1, 1], [1, 1]], dtype=dtype, device=device) # (nnz, ...): (2, 2) 2782*da0073e9SAndroid Build Coastguard Worker sizes = torch.Size([2, 2]) 2783*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "indices and values must have same nnz"): 2784*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device) 2785*da0073e9SAndroid Build Coastguard Worker 2786*da0073e9SAndroid Build Coastguard Worker indices = self.index_tensor([[0]], device=device) # (sparse_dim, nnz): (1, 1) 2787*da0073e9SAndroid Build Coastguard Worker values = torch.empty([2, 0], dtype=dtype, device=device) # (nnz, ...): (2, 0) 2788*da0073e9SAndroid Build Coastguard Worker sizes = torch.Size([2, 0]) 2789*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "indices and values must have same nnz"): 2790*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device) 2791*da0073e9SAndroid Build Coastguard Worker 2792*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 2793*da0073e9SAndroid Build Coastguard Worker def test_factory_nnz_zero(self, device, dtype): 2794*da0073e9SAndroid Build Coastguard Worker def test_shape(i_shape, v_shape, size, expected_size): 2795*da0073e9SAndroid Build Coastguard Worker if size: 2796*da0073e9SAndroid Build Coastguard Worker t = 
torch.sparse_coo_tensor(torch.empty(i_shape), torch.empty(v_shape), torch.Size(size), 2797*da0073e9SAndroid Build Coastguard Worker dtype=dtype, device=device) 2798*da0073e9SAndroid Build Coastguard Worker else: 2799*da0073e9SAndroid Build Coastguard Worker t = torch.sparse_coo_tensor(torch.empty(i_shape), torch.empty(v_shape), dtype=dtype, device=device) 2800*da0073e9SAndroid Build Coastguard Worker expected_indices = torch.empty(i_shape, device=device, dtype=torch.int64) 2801*da0073e9SAndroid Build Coastguard Worker expected_values = torch.empty(v_shape, device=device, dtype=dtype) 2802*da0073e9SAndroid Build Coastguard Worker expected_size = torch.Size(expected_size) 2803*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t._indices(), expected_indices) 2804*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t._values(), expected_values) 2805*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.size(), expected_size) 2806*da0073e9SAndroid Build Coastguard Worker 2807*da0073e9SAndroid Build Coastguard Worker test_shape([1, 0], [0, 2, 4, 0], None, [0, 2, 4, 0]) 2808*da0073e9SAndroid Build Coastguard Worker test_shape([3, 0], [0, 2, 4, 0], None, [0, 0, 0, 2, 4, 0]) 2809*da0073e9SAndroid Build Coastguard Worker test_shape([1, 0], [0, 2, 4, 0], [0, 2, 4, 0], [0, 2, 4, 0]) 2810*da0073e9SAndroid Build Coastguard Worker test_shape([3, 0], [0, 2, 4, 0], [0, 0, 0, 2, 4, 0], [0, 0, 0, 2, 4, 0]) 2811*da0073e9SAndroid Build Coastguard Worker test_shape([3, 0], [0, 2, 4, 0], [1, 2, 3, 2, 4, 0], [1, 2, 3, 2, 4, 0]) 2812*da0073e9SAndroid Build Coastguard Worker 2813*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 2814*da0073e9SAndroid Build Coastguard Worker def test_factory_dense_dim(self, device, dtype): 2815*da0073e9SAndroid Build Coastguard Worker indices = self.index_tensor([[0]], device=device) 2816*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([[[1, 1, 1], [1, 1, 1]]], dtype=dtype, device=device) 
2817*da0073e9SAndroid Build Coastguard Worker sizes = torch.Size([1, 3, 4]) 2818*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "values has incorrect size"): 2819*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo_tensor(indices, values, sizes) 2820*da0073e9SAndroid Build Coastguard Worker 2821*da0073e9SAndroid Build Coastguard Worker indices = self.index_tensor([[0]], device=device) 2822*da0073e9SAndroid Build Coastguard Worker values = torch.empty([1, 2, 3, 0], dtype=dtype, device=device) 2823*da0073e9SAndroid Build Coastguard Worker sizes = torch.Size([1, 3, 4, 0]) 2824*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "values has incorrect size"): 2825*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo_tensor(indices, values, sizes) 2826*da0073e9SAndroid Build Coastguard Worker 2827*da0073e9SAndroid Build Coastguard Worker @onlyCPU 2828*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float16, torch.float32, torch.float64, torch.cfloat, torch.cdouble, torch.int64) 2829*da0073e9SAndroid Build Coastguard Worker def test_factory_type_inference(self, device, dtype): 2830*da0073e9SAndroid Build Coastguard Worker t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.tensor([1.], dtype=dtype)) 2831*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dtype, t.dtype) 2832*da0073e9SAndroid Build Coastguard Worker t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.tensor([1])) 2833*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.int64, t.dtype) 2834*da0073e9SAndroid Build Coastguard Worker 2835*da0073e9SAndroid Build Coastguard Worker t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.HalfTensor(1, 0)) 2836*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.float16, t.dtype) 2837*da0073e9SAndroid Build Coastguard Worker t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.FloatTensor(1, 0)) 2838*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(torch.float32, t.dtype) 2839*da0073e9SAndroid Build Coastguard Worker t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.DoubleTensor(1, 0)) 2840*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.float64, t.dtype) 2841*da0073e9SAndroid Build Coastguard Worker t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.LongTensor(1, 0)) 2842*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.int64, t.dtype) 2843*da0073e9SAndroid Build Coastguard Worker 2844*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2845*da0073e9SAndroid Build Coastguard Worker def test_factory_device_type_inference(self, device): 2846*da0073e9SAndroid Build Coastguard Worker # both indices/values are CUDA 2847*da0073e9SAndroid Build Coastguard Worker 2848*da0073e9SAndroid Build Coastguard Worker cpu_cuda = ('cpu', 'cuda') 2849*da0073e9SAndroid Build Coastguard Worker cpu_cuda_none = cpu_cuda + (None,) 2850*da0073e9SAndroid Build Coastguard Worker for indices_device, values_device, device in itertools.product(cpu_cuda, 2851*da0073e9SAndroid Build Coastguard Worker cpu_cuda, 2852*da0073e9SAndroid Build Coastguard Worker cpu_cuda_none): 2853*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor(([0], [2]), device=indices_device) 2854*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([1.], device=values_device) 2855*da0073e9SAndroid Build Coastguard Worker empty_values = torch.empty(1, 0).to(values_device) 2856*da0073e9SAndroid Build Coastguard Worker shape = (1, 3) 2857*da0073e9SAndroid Build Coastguard Worker empty_shape = (1, 3, 0) 2858*da0073e9SAndroid Build Coastguard Worker if device is None and indices_device != values_device: 2859*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 2860*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo_tensor(indices, values, shape, device=device) 2861*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaises(RuntimeError): 2862*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo_tensor(indices, empty_values, empty_shape, device=device) 2863*da0073e9SAndroid Build Coastguard Worker else: 2864*da0073e9SAndroid Build Coastguard Worker t = torch.sparse_coo_tensor(indices, values, shape, device=device) 2865*da0073e9SAndroid Build Coastguard Worker t_empty = torch.sparse_coo_tensor(indices, empty_values, empty_shape, device=device) 2866*da0073e9SAndroid Build Coastguard Worker should_be_cuda = (device == 'cuda' or (device is None and values_device == 'cuda')) 2867*da0073e9SAndroid Build Coastguard Worker self.assertEqual(should_be_cuda, t.is_cuda) 2868*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.is_cuda, t_empty.is_cuda) 2869*da0073e9SAndroid Build Coastguard Worker 2870*da0073e9SAndroid Build Coastguard Worker @onlyCPU 2871*da0073e9SAndroid Build Coastguard Worker def test_factory_copy(self, device): 2872*da0073e9SAndroid Build Coastguard Worker def test_tensor(indices, values, indices_equal, values_equal): 2873*da0073e9SAndroid Build Coastguard Worker sparse_tensor = torch.sparse_coo_tensor(indices, values, dtype=torch.float64, device=device) 2874*da0073e9SAndroid Build Coastguard Worker if indices_equal: 2875*da0073e9SAndroid Build Coastguard Worker self.assertEqual(indices.data_ptr(), sparse_tensor._indices().data_ptr()) 2876*da0073e9SAndroid Build Coastguard Worker else: 2877*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(indices.data_ptr(), sparse_tensor._indices().data_ptr()) 2878*da0073e9SAndroid Build Coastguard Worker if values_equal: 2879*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values.data_ptr(), sparse_tensor._values().data_ptr()) 2880*da0073e9SAndroid Build Coastguard Worker else: 2881*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(values.data_ptr(), sparse_tensor._values().data_ptr()) 2882*da0073e9SAndroid Build Coastguard Worker 2883*da0073e9SAndroid Build Coastguard Worker # 
both correct 2884*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor(([0], [2]), dtype=torch.int64) 2885*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([1.], dtype=torch.float64) 2886*da0073e9SAndroid Build Coastguard Worker test_tensor(indices, values, True, True) 2887*da0073e9SAndroid Build Coastguard Worker 2888*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor(([0], [2]), dtype=torch.int64) 2889*da0073e9SAndroid Build Coastguard Worker values = torch.DoubleTensor(1, 0) 2890*da0073e9SAndroid Build Coastguard Worker test_tensor(indices, values, True, True) 2891*da0073e9SAndroid Build Coastguard Worker 2892*da0073e9SAndroid Build Coastguard Worker # only indices correct 2893*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor(([0], [2]), dtype=torch.int64) 2894*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([1.], dtype=torch.float32) 2895*da0073e9SAndroid Build Coastguard Worker test_tensor(indices, values, True, False) 2896*da0073e9SAndroid Build Coastguard Worker 2897*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor(([0], [2]), dtype=torch.int64) 2898*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([1.], dtype=torch.float16) 2899*da0073e9SAndroid Build Coastguard Worker test_tensor(indices, values, True, False) 2900*da0073e9SAndroid Build Coastguard Worker 2901*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor(([0], [2]), dtype=torch.int64) 2902*da0073e9SAndroid Build Coastguard Worker values = torch.FloatTensor(1, 0) 2903*da0073e9SAndroid Build Coastguard Worker test_tensor(indices, values, True, True) # An empty tensor's data_ptr is always equal to 0 2904*da0073e9SAndroid Build Coastguard Worker 2905*da0073e9SAndroid Build Coastguard Worker # only values correct 2906*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor(([0], [2]), dtype=torch.int32) 2907*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([1.], 
dtype=torch.float64) 2908*da0073e9SAndroid Build Coastguard Worker test_tensor(indices, values, False, True) 2909*da0073e9SAndroid Build Coastguard Worker 2910*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor(([0], [2]), dtype=torch.int32) 2911*da0073e9SAndroid Build Coastguard Worker values = torch.DoubleTensor(1, 0) 2912*da0073e9SAndroid Build Coastguard Worker test_tensor(indices, values, False, True) 2913*da0073e9SAndroid Build Coastguard Worker 2914*da0073e9SAndroid Build Coastguard Worker # neither correct 2915*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor(([0], [2]), dtype=torch.int32) 2916*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([1.], dtype=torch.float32) 2917*da0073e9SAndroid Build Coastguard Worker test_tensor(indices, values, False, False) 2918*da0073e9SAndroid Build Coastguard Worker 2919*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor(([0], [2]), dtype=torch.int32) 2920*da0073e9SAndroid Build Coastguard Worker values = torch.FloatTensor(1, 0) 2921*da0073e9SAndroid Build Coastguard Worker test_tensor(indices, values, False, True) # An empty tensor's data_ptr is always equal to 0 2922*da0073e9SAndroid Build Coastguard Worker 2923*da0073e9SAndroid Build Coastguard Worker # complex support 2924*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor(([0], [2]), dtype=torch.int64) 2925*da0073e9SAndroid Build Coastguard Worker values = make_tensor([1, ], dtype=torch.cdouble, device=device) 2926*da0073e9SAndroid Build Coastguard Worker test_tensor(indices, values, True, False) 2927*da0073e9SAndroid Build Coastguard Worker 2928*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor(([0], [2]), dtype=torch.int32) 2929*da0073e9SAndroid Build Coastguard Worker values = make_tensor([1, 1], dtype=torch.cdouble, device=device) 2930*da0073e9SAndroid Build Coastguard Worker test_tensor(indices, values, False, False) 2931*da0073e9SAndroid Build Coastguard Worker 
2932*da0073e9SAndroid Build Coastguard Worker @onlyCPU # just run once, we test both cpu and cuda 2933*da0073e9SAndroid Build Coastguard Worker def test_legacy_new_device(self, device): 2934*da0073e9SAndroid Build Coastguard Worker i = torch.tensor([[0, 1, 1], [2, 0, 2]]) 2935*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([3., 4., 5.]) 2936*da0073e9SAndroid Build Coastguard Worker size = torch.Size([2, 3]) 2937*da0073e9SAndroid Build Coastguard Worker 2938*da0073e9SAndroid Build Coastguard Worker x = torch.sparse_coo_tensor(i, v, size, device='cpu') 2939*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: x.new(device='cuda')) 2940*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: x.new(i, v, device='cuda')) 2941*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: x.new(i, v, size, device='cuda')) 2942*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cuda')) 2943*da0073e9SAndroid Build Coastguard Worker 2944*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 2945*da0073e9SAndroid Build Coastguard Worker x = torch.sparse_coo_tensor(i, v, size, device='cuda') 2946*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: x.new(device='cpu')) 2947*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: x.new(i, v, device='cpu')) 2948*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: x.new(i, v, size, device='cpu')) 2949*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cpu')) 2950*da0073e9SAndroid Build Coastguard Worker 2951*da0073e9SAndroid Build Coastguard Worker def test_legacy_new(self, device): 2952*da0073e9SAndroid Build Coastguard Worker i = torch.tensor([[0, 1, 1], [2, 0, 2]]) 2953*da0073e9SAndroid Build Coastguard Worker v = 
torch.tensor([3., 4., 5.]) 2954*da0073e9SAndroid Build Coastguard Worker size = torch.Size([2, 3]) 2955*da0073e9SAndroid Build Coastguard Worker s = torch.sparse_coo_tensor(i, v, size) 2956*da0073e9SAndroid Build Coastguard Worker 2957*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sparse_coo, s.new(device='cpu').layout) 2958*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: s.new(v.untyped_storage())) 2959*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: s.new(v)) 2960*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sparse_coo, s.new(torch.Size([2, 3])).layout) 2961*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: s.new([6])) 2962*da0073e9SAndroid Build Coastguard Worker 2963*da0073e9SAndroid Build Coastguard Worker @onlyCPU # not really, but we only really want to run this once 2964*da0073e9SAndroid Build Coastguard Worker def test_dtypes(self, device): 2965*da0073e9SAndroid Build Coastguard Worker all_sparse_dtypes = all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16) 2966*da0073e9SAndroid Build Coastguard Worker do_test_dtypes(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cpu')) 2967*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 2968*da0073e9SAndroid Build Coastguard Worker do_test_dtypes(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cuda:0')) 2969*da0073e9SAndroid Build Coastguard Worker 2970*da0073e9SAndroid Build Coastguard Worker def _test_empty_full(self, device, dtype, requires_grad): 2971*da0073e9SAndroid Build Coastguard Worker shape = (2, 3) 2972*da0073e9SAndroid Build Coastguard Worker layout = torch.sparse_coo 2973*da0073e9SAndroid Build Coastguard Worker 2974*da0073e9SAndroid Build Coastguard Worker def check_value(tensor, value=None, dtype=dtype, requires_grad=requires_grad): 2975*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, tensor.shape) 
2976*da0073e9SAndroid Build Coastguard Worker self.assertIs(dtype, tensor.dtype) 2977*da0073e9SAndroid Build Coastguard Worker self.assertIs(layout, tensor.layout) 2978*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.requires_grad, requires_grad) 2979*da0073e9SAndroid Build Coastguard Worker if tensor.is_cuda and device is not None: 2980*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device, tensor.device) 2981*da0073e9SAndroid Build Coastguard Worker if value is not None: 2982*da0073e9SAndroid Build Coastguard Worker fill = tensor.empty(shape, dtype=dtype).fill_(value) 2983*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor, fill) 2984*da0073e9SAndroid Build Coastguard Worker 2985*da0073e9SAndroid Build Coastguard Worker v = torch.sparse_coo_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad) 2986*da0073e9SAndroid Build Coastguard Worker check_value(v) 2987*da0073e9SAndroid Build Coastguard Worker 2988*da0073e9SAndroid Build Coastguard Worker out = v.new() 2989*da0073e9SAndroid Build Coastguard Worker check_value(torch.zeros(shape, out=out, device=device, requires_grad=requires_grad)) 2990*da0073e9SAndroid Build Coastguard Worker 2991*da0073e9SAndroid Build Coastguard Worker int64_dtype = torch.int64 2992*da0073e9SAndroid Build Coastguard Worker check_value(v.new_empty(shape), requires_grad=False) 2993*da0073e9SAndroid Build Coastguard Worker check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=False), 2994*da0073e9SAndroid Build Coastguard Worker dtype=int64_dtype, requires_grad=False) 2995*da0073e9SAndroid Build Coastguard Worker check_value(torch.empty_like(v), requires_grad=False) 2996*da0073e9SAndroid Build Coastguard Worker check_value(torch.empty_like(v, dtype=int64_dtype, layout=layout, device=device, requires_grad=False), 2997*da0073e9SAndroid Build Coastguard Worker dtype=int64_dtype, requires_grad=False) 2998*da0073e9SAndroid Build Coastguard Worker 
 @onlyCPU  # not really, but we only really want to run this once
 @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
 @parametrize('requires_grad', (True, False))
 def test_empty_full(self, device, dtype, requires_grad):
     """Drive ``_test_empty_full`` for one dtype / requires_grad combination.

     Runs on ``device`` and, when CUDA is available, also with ``device=None``
     and ``cuda:0``. Integer/bool dtypes are skipped when ``requires_grad``
     is requested, since only float/complex tensors can require grad.
     """
     if requires_grad and not (dtype.is_floating_point or dtype.is_complex):
         self.skipTest(f'requires_grad==True requires float or complex dtype, got {dtype}')

     self._test_empty_full(device, dtype, requires_grad)
     if torch.cuda.is_available():
         self._test_empty_full(None, dtype, requires_grad)
         self._test_empty_full(torch.device('cuda:0'), dtype, requires_grad)

 def test_is_sparse(self, device):
     """``is_sparse`` is False for strided tensors (including zero-sized ones)
     and True for the tensor returned by ``self.sparse_empty``."""
     x = torch.randn(3, 3)
     self.assertFalse(x.is_sparse)

     x = torch.randn(3, 3, 0)
     self.assertFalse(x.is_sparse)

     # NOTE(review): assumes self.sparse_empty builds an empty sparse COO tensor.
     x = self.sparse_empty(1, 0, device=device)
     self.assertTrue(x.is_sparse)

 def test_resize_as(self, device):
     """A fresh tensor resized to ``t`` and zeroed has ``t``'s shape and is
     an additive identity for ``t``."""
     def do_test(t):
         y = t.new().resize_as_(t).zero_()
         self.assertEqual(y.shape, t.shape)
         # Check that y can be added to t. Currently, this requires that
         # sparse_dim and dense_dim match.
         self.assertEqual(t, t + y)

     do_test(self.sparse_empty([3, 0], device=device))
     do_test(self.sparse_empty([3, 3], device=device))

 def _test_resize_shape(self, x_i, x_v, x_size, y_i, y_v, y_size, dtype, device):
     """Resize sparse ``x`` to match sparse ``y`` and compare with dense ``resize_as_``.

     Args:
         x_i, y_i: shapes of the (all-zero) index tensors.
         x_v, y_v: shapes of the value tensors; ``x`` is filled with
             ``arange`` values so data preservation can be checked.
         x_size, y_size: full shapes of the sparse tensors.
         dtype, device: forwarded to ``torch.sparse_coo_tensor``.
     """
     x_v_numel = torch.zeros(x_v).numel()
     x = torch.sparse_coo_tensor(torch.zeros(x_i),
                                 torch.arange(x_v_numel).resize_(x_v).to(torch.float),
                                 torch.Size(x_size), dtype=dtype, device=device)
     x_dense = x.to_dense()
     y = torch.sparse_coo_tensor(torch.zeros(y_i),
                                 torch.ones(y_v).to(torch.float),
                                 torch.Size(y_size), dtype=dtype, device=device)
     y_dense = y.to_dense()
     x.resize_as_(y)
     x_dense.resize_as_(y_dense)
     # The sparse resize must agree with y's shape and sparse/dense split,
     # and with the shape the dense reference ends up with.
     self.assertEqual(x.shape, y.shape)
     self.assertEqual(x.sparse_dim(), y.sparse_dim())
     self.assertEqual(x.dense_dim(), y.dense_dim())
     self.assertEqual(x.shape, x_dense.shape)
     self.assertEqual(y.shape, y_dense.shape)
     # Here we make sure that the original data are preserved after resizing
     self.assertEqual(x.to_dense().view(-1)[0:x_v_numel].view(x_v),
                      x_dense.view(-1)[0:x_v_numel].view(x_v))

 @dtypes(torch.double, torch.cdouble)
 def test_resize(self, device, dtype):
     """Supported and unsupported ``resize_as_`` transitions for sparse COO tensors.

     Each ``_test_resize_shape`` call passes (index shape, value shape, full size)
     for ``x`` followed by the same triple for ``y``. Unsupported cases must
     raise ``RuntimeError`` with the quoted message.
     """
     # 1. Expand the size of some dense dimensions [Supported]
     self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                             [1, 1], [1, 2, 4], [2, 2, 4],
                             dtype=dtype, device=device)

     self._test_resize_shape([1, 1], [1, 2, 0], [2, 2, 0],
                             [1, 1], [1, 2, 4], [2, 2, 4],
                             dtype=dtype, device=device)

     # 2. Expand the size of some sparse dimensions [Supported]
     self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                             [1, 1], [1, 2, 3], [4, 2, 3],
                             dtype=dtype, device=device)

     # 3. Change the shapes of both sparse and dense dimensions when nnz is zero [Supported]
     self._test_resize_shape([1, 0], [0, 2, 3], [2, 2, 3],
                             [2, 0], [0, 2, 4, 5], [1, 1, 2, 4, 5],
                             dtype=dtype, device=device)

     self._test_resize_shape([1, 0], [0, 2, 3], [2, 2, 3],
                             [2, 0], [0, 2, 4, 0], [1, 1, 2, 4, 0],
                             dtype=dtype, device=device)

     # 4. Add dims to dense dimensions [Not Supported]
     with self.assertRaisesRegex(RuntimeError, "changing the number of dense dimensions"):
         self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                                 [1, 1], [1, 2, 3, 4], [2, 2, 3, 4],
                                 dtype=dtype, device=device)

     with self.assertRaisesRegex(RuntimeError, "changing the number of dense dimensions"):
         self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                                 [1, 1], [1, 2, 3, 0], [2, 2, 3, 0],
                                 dtype=dtype, device=device)

     # 5. Remove dims from dense dimensions [Not Supported]
     with self.assertRaisesRegex(RuntimeError, "changing the number of dense dimensions"):
         self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                                 [1, 1], [1, 2], [2, 2],
                                 dtype=dtype, device=device)

     # 6. Change the number of sparse dimensions on a non-empty sparse tensor [Not Supported]
     with self.assertRaisesRegex(RuntimeError, "changing the number of sparse dimensions"):
         self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                                 [2, 1], [1, 2, 3], [1, 2, 2, 3],
                                 dtype=dtype, device=device)

     # 7. Shrink the size of some sparse dimensions on a non-empty sparse tensor [Not Supported]
     with self.assertRaisesRegex(RuntimeError, "shrinking the size of sparse dimensions"):
         self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                                 [1, 1], [1, 2, 3], [1, 2, 3],
                                 dtype=dtype, device=device)

     # 8. Shrink the size of some dense dimensions on a non-empty sparse tensor [Not Supported]
     with self.assertRaisesRegex(RuntimeError, "shrinking the size of dense dimensions"):
         self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                                 [1, 1], [1, 2, 2], [2, 2, 2],
                                 dtype=dtype, device=device)

     with self.assertRaisesRegex(RuntimeError, "shrinking the size of dense dimensions"):
         self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
                                 [1, 1], [1, 2, 0], [2, 2, 0],
                                 dtype=dtype, device=device)

 def test_is_nonzero(self, device):
     """``is_nonzero`` on single-element sparse tensors.

     Covers real and complex values and a sparse scalar. The test expects
     duplicate entries at one index to be combined: ``(-1., 1.)`` at index 0
     is treated as zero. A tensor with no values must raise ``RuntimeError``.
     """
     self.assertTrue(torch.sparse_coo_tensor(([0],), 1., (1,), device=device).is_nonzero())
     self.assertFalse(torch.sparse_coo_tensor(([0],), 0., (1,), device=device).is_nonzero())
     self.assertFalse(torch.sparse_coo_tensor(([0], [0]), 0., (1, 1), device=device).is_nonzero())
     # Duplicate (uncoalesced) entries at the same index.
     self.assertFalse(torch.sparse_coo_tensor(([0, 0],), (0., 0.), (1,), device=device).is_nonzero())
     self.assertFalse(torch.sparse_coo_tensor(([0, 0],), (-1., 1.), (1,), device=device).is_nonzero())

     # scalar sparse tensor
     self.assertTrue(torch.sparse_coo_tensor(torch.zeros(0, 1), 12.3, [], device=device).is_nonzero())
     with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"):
         torch.sparse_coo_tensor(([0, 1],), torch.empty(2, 0), (4, 0), device=device).is_nonzero()
     # Complex values: nonzero iff the (combined) value is not 0 + 0j.
     self.assertTrue(torch.sparse_coo_tensor(([0],), 2.3 - 4.5j, (1,), dtype=torch.cfloat, device=device)
                     .is_nonzero())
     self.assertTrue(torch.sparse_coo_tensor(([0],), 2.3 - 4.5j, (1,), dtype=torch.cdouble, device=device)
                     .is_nonzero())
     self.assertFalse(torch.sparse_coo_tensor(([0],), 0. + 0j, (1,), dtype=torch.cfloat, device=device)
                      .is_nonzero())
     self.assertFalse(torch.sparse_coo_tensor(([0],), 0. + 0j, (1,), dtype=torch.cdouble, device=device)
                      .is_nonzero())

 @dtypes(torch.double, torch.cdouble)
 def test_change_tensor_metadata(self, device, dtype):
     """Changing the metadata of the source tensors must not affect the sparse tensor.

     After ``torch.sparse_coo_tensor(i, v, ...)`` is built, ``i`` and ``v`` are
     mutated in place with ``resize_``, ``resize_as_``, ``as_strided_``, ``set_``
     and ``transpose_``. Each case asserts that the coalesced indices keep size
     ``[2, 1]`` and the values keep size ``[1, 3]``.
     """
     # resize_ on the source tensors
     i = self.index_tensor([[0], [1]], device=device)
     v = torch.tensor([[3, 4, 5]], dtype=dtype, device=device)
     t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]), dtype=dtype, device=device)
     i.resize_(2, 3)
     v.resize_(4, 5)
     self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
     self.assertEqual(list(t.coalesce().values().size()), [1, 3])

     # resize_as_ on the source tensors
     i = self.index_tensor([[0], [1]], device=device)
     v = torch.tensor([[3, 4, 5]], dtype=dtype, device=device)
     t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
     i.resize_as_(self.index_tensor([0, 1], device=device))
     v.resize_as_(torch.tensor([3, 4, 5], dtype=dtype, device=device))
     self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
     self.assertEqual(list(t.coalesce().values().size()), [1, 3])

     # as_strided_ on the source tensors
     i = self.index_tensor([[0], [1]], device=device)
     v = torch.tensor([[3, 4, 5]], dtype=dtype, device=device)
     t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
     i.as_strided_((2, 1), (1, 1))
     v.as_strided_((1, 3), (1, 1))
     self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
     self.assertEqual(list(t.coalesce().values().size()), [1, 3])

     # set_ on the source tensors
     i = self.index_tensor([[0], [1]], device=device)
     v = torch.tensor([[3, 4, 5]], dtype=dtype, device=device)
     t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
     i.set_(self.index_tensor([0, 1], device=device))
     v.set_(torch.tensor([3, 4, 5], dtype=dtype, device=device))
     self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
     self.assertEqual(list(t.coalesce().values().size()), [1, 3])

     # transpose_ on the source tensors
     i = self.index_tensor([[0], [1]], device=device)
     v = torch.tensor([[3, 4, 5]], dtype=dtype, device=device)
     t = torch.sparse_coo_tensor(i, v, torch.Size([1, 2, 3]))
     i.transpose_(0, 1)
     v.transpose_(0, 1)
     self.assertEqual(list(t.coalesce().indices().size()), [2, 1])
     self.assertEqual(list(t.coalesce().values().size()), [1, 3])

 @coalescedonoff
 @dtypes(torch.double)
 def test_pickle(self, device, dtype, coalesced):
     """Sparse COO tensors survive a ``pickle.dumps`` / ``pickle.loads`` round trip.

     Each ``(shape, sparse_dim, nnz)`` case builds indices from ``arange``
     clamped into range. When ``coalesced`` is False, the last column is
     overwritten with the first to create a duplicate index.
     """
     import pickle

     # (full shape, number of sparse dims, nnz)
     shape_sparse_dim_nnz = [
         ((), 0, 2),
         ((0,), 0, 10),
         ((2,), 0, 3),
         ((100, 3), 1, 3),
         ((100, 20, 3), 2, 0),
         ((10, 0, 3), 0, 3),
         ((10, 0, 3), 0, 0),
     ]

     for shape, sparse_dim, nnz in shape_sparse_dim_nnz:
         indices_shape = torch.Size((sparse_dim, nnz))
         values_shape = torch.Size((nnz,) + shape[sparse_dim:])
         indices = torch.arange(indices_shape.numel(), dtype=self.index_tensor(0).dtype,
                                device=device).view(indices_shape)
         for d in range(sparse_dim):
             indices[d].clamp_(max=(shape[d] - 1))  # make it valid index
         if not coalesced and indices.numel() > 0:
             indices[:, -1] = indices[:, 0]  # make it uncoalesced
         values_numel = values_shape.numel()
         values = torch.arange(values_numel, dtype=dtype,
                               device=device).view(values_shape).div_(values_numel / 2.)
         sp_tensor = self.sparse_tensor(indices, values, shape)
         serialized = pickle.dumps(sp_tensor)
         sp_tensor_loaded = pickle.loads(serialized)
         self.assertEqual(sp_tensor, sp_tensor_loaded)

 def test_any(self, device):
     """``torch.any`` of a sparse bool tensor is False when all stored values are
     False and True when any stored value is True."""
     t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [2, 0])), torch.tensor([False, False]), device=device)
     t_any = torch.tensor(False)
     self.assertEqual(torch.any(t), t_any)
     t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [2, 0])), torch.tensor([True, False]), device=device)
     t_any = torch.tensor(True)
     self.assertEqual(torch.any(t), t_any)

 def test_isnan(self, device):
     """``torch.isnan`` on a sparse tensor flags NaN stored values.

     Results are compared after ``.int()`` against a sparse bool tensor
     with the same indices.
     """
     t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [0, 2])), torch.tensor([1, 4]), device=device)
     t_nan = torch.sparse_coo_tensor(torch.tensor(([0, 0], [0, 2])), torch.tensor([False, False]), device=device)
     self.assertEqual(torch.isnan(t).int(), t_nan.int())
     t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [0, 2])), torch.tensor([1, float("nan")]), device=device)
     t_nan = torch.sparse_coo_tensor(torch.tensor(([0, 0], [0, 2])), torch.tensor([False, True]), device=device)
     self.assertEqual(torch.isnan(t).int(), t_nan.int())

 @coalescedonoff
 @dtypes(torch.float32, torch.float64)
 def test_div_rounding_mode(self, device, dtype, coalesced):
     """Sparse ``div`` by the scalar ``-2`` matches dense ``div``.

     Checked for each ``rounding_mode`` (None, 'floor', 'trunc') through the
     out-of-place, in-place and ``out=`` variants.
     """
     sparse, _, _ = self._gen_sparse(2, 10, (10, 10), dtype,
                                     device, coalesced)
     dense = self.safeToDense(sparse)

     for mode in (None, 'floor', 'trunc'):
         actual = sparse.div(-2, rounding_mode=mode)
         expect = dense.div(-2, rounding_mode=mode)
         self.assertEqual(self.safeToDense(actual), expect)

         # Test inplace
         actual = sparse.clone().div_(-2, rounding_mode=mode)
         self.assertEqual(self.safeToDense(actual), expect)

         # Test out argument
         actual.zero_()
         torch.div(sparse, -2, rounding_mode=mode, out=actual)
         self.assertEqual(self.safeToDense(actual), expect)

 def test_div_by_sparse_error(self, device):
     """Dividing by a sparse tensor raises ``RuntimeError`` ('Sparse division requires ...')."""
     self.assertRaisesRegex(RuntimeError, 'Sparse division requires',
                            lambda: torch.tensor(1.,
device=device).to_sparse() 3252*da0073e9SAndroid Build Coastguard Worker / torch.tensor(1., device=device).to_sparse()) 3253*da0073e9SAndroid Build Coastguard Worker 3254*da0073e9SAndroid Build Coastguard Worker def test_floor_divide_by_sparse_error(self, device): 3255*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, 'Sparse floor division requires', 3256*da0073e9SAndroid Build Coastguard Worker lambda: torch.tensor(1., device=device).to_sparse() 3257*da0073e9SAndroid Build Coastguard Worker // torch.tensor(1., device=device).to_sparse()) 3258*da0073e9SAndroid Build Coastguard Worker 3259*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_NUMPY, "Numpy not found") 3260*da0073e9SAndroid Build Coastguard Worker @onlyCPU 3261*da0073e9SAndroid Build Coastguard Worker def test_sparse_to_numpy(self, device): 3262*da0073e9SAndroid Build Coastguard Worker t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [2, 0])), torch.tensor([1, 4])) 3263*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: t.numpy()) 3264*da0073e9SAndroid Build Coastguard Worker 3265*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 3266*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 3267*da0073e9SAndroid Build Coastguard Worker def test_softmax(self, device, dtype, coalesced): 3268*da0073e9SAndroid Build Coastguard Worker import torch.nn.functional as F 3269*da0073e9SAndroid Build Coastguard Worker 3270*da0073e9SAndroid Build Coastguard Worker def to_dense(sparse, fill_value=None): 3271*da0073e9SAndroid Build Coastguard Worker """ 3272*da0073e9SAndroid Build Coastguard Worker Return dense tensor from a sparse tensor using given fill value. 
3273*da0073e9SAndroid Build Coastguard Worker """ 3274*da0073e9SAndroid Build Coastguard Worker if fill_value is None or fill_value == 0: 3275*da0073e9SAndroid Build Coastguard Worker return sparse.to_dense() 3276*da0073e9SAndroid Build Coastguard Worker sparse = sparse.coalesce() 3277*da0073e9SAndroid Build Coastguard Worker dense = torch.full(sparse.shape, fill_value, dtype=sparse.dtype, device=sparse.device) 3278*da0073e9SAndroid Build Coastguard Worker for idx, value in zip(sparse._indices().t(), sparse._values()): 3279*da0073e9SAndroid Build Coastguard Worker dense[tuple(idx)] = value 3280*da0073e9SAndroid Build Coastguard Worker return dense 3281*da0073e9SAndroid Build Coastguard Worker 3282*da0073e9SAndroid Build Coastguard Worker def softmax_to_dense(sparse, dim): 3283*da0073e9SAndroid Build Coastguard Worker """Dense softmax of a sparse tensor. Useful only for testing softmax 3284*da0073e9SAndroid Build Coastguard Worker correctness. 3285*da0073e9SAndroid Build Coastguard Worker 3286*da0073e9SAndroid Build Coastguard Worker When computing softmax of a sparse tensor, the value of 3287*da0073e9SAndroid Build Coastguard Worker unspecified items is negative infinity rather than zero so 3288*da0073e9SAndroid Build Coastguard Worker that 3289*da0073e9SAndroid Build Coastguard Worker 3290*da0073e9SAndroid Build Coastguard Worker softmax(sparse.to_dense(fill_value=-inf), dim) == softmax(sparse, dim).to_dense() 3291*da0073e9SAndroid Build Coastguard Worker 3292*da0073e9SAndroid Build Coastguard Worker holds for non-empty lines. One empty lines, the softmax 3293*da0073e9SAndroid Build Coastguard Worker values are defined as 0 in order to preserve the sparsity 3294*da0073e9SAndroid Build Coastguard Worker of result. 3295*da0073e9SAndroid Build Coastguard Worker 3296*da0073e9SAndroid Build Coastguard Worker Note that in PyTorch, ``to_dense`` method does not 3297*da0073e9SAndroid Build Coastguard Worker implement the ``fill_value`` keyword argument. 
3298*da0073e9SAndroid Build Coastguard Worker """ 3299*da0073e9SAndroid Build Coastguard Worker dtype = sparse.dtype 3300*da0073e9SAndroid Build Coastguard Worker device = sparse.device 3301*da0073e9SAndroid Build Coastguard Worker dense = to_dense(sparse, fill_value=-float('inf')) 3302*da0073e9SAndroid Build Coastguard Worker r = F.softmax(dense, dim) 3303*da0073e9SAndroid Build Coastguard Worker # softmax on empty lines results nan, replace with zeros to match the definition 3304*da0073e9SAndroid Build Coastguard Worker r[r != r] = 0 3305*da0073e9SAndroid Build Coastguard Worker return r 3306*da0073e9SAndroid Build Coastguard Worker 3307*da0073e9SAndroid Build Coastguard Worker def sparse_softmax(sparse, dim): 3308*da0073e9SAndroid Build Coastguard Worker """Pure Python softmax of a sparse tensor. Assuming -inf for 3309*da0073e9SAndroid Build Coastguard Worker unspecified sparse tensor data. This is a prototype of 3310*da0073e9SAndroid Build Coastguard Worker sparse softmax algorithm in Python. 3311*da0073e9SAndroid Build Coastguard Worker """ 3312*da0073e9SAndroid Build Coastguard Worker dtype = sparse.dtype 3313*da0073e9SAndroid Build Coastguard Worker device = sparse.device 3314*da0073e9SAndroid Build Coastguard Worker 3315*da0073e9SAndroid Build Coastguard Worker # softmax is non-linear operation, so sparse tensors must 3316*da0073e9SAndroid Build Coastguard Worker # be coalesced. 
3317*da0073e9SAndroid Build Coastguard Worker sparse = sparse.coalesce() 3318*da0073e9SAndroid Build Coastguard Worker inf = float('inf') 3319*da0073e9SAndroid Build Coastguard Worker indices = sparse._indices() 3320*da0073e9SAndroid Build Coastguard Worker values = sparse._values() 3321*da0073e9SAndroid Build Coastguard Worker 3322*da0073e9SAndroid Build Coastguard Worker if dim < sparse.sparse_dim(): 3323*da0073e9SAndroid Build Coastguard Worker nnz = sparse._nnz() 3324*da0073e9SAndroid Build Coastguard Worker 3325*da0073e9SAndroid Build Coastguard Worker # compute pool indices 3326*da0073e9SAndroid Build Coastguard Worker size = sparse.size() 3327*da0073e9SAndroid Build Coastguard Worker strides = torch.ones((sparse.sparse_dim(), 1), dtype=indices.dtype, device=indices.device) 3328*da0073e9SAndroid Build Coastguard Worker for i in reversed(range(sparse.sparse_dim() - 1)): 3329*da0073e9SAndroid Build Coastguard Worker strides[i, 0] = strides[i + 1, 0] * size[i + 1] 3330*da0073e9SAndroid Build Coastguard Worker strides[dim, 0] = 0 3331*da0073e9SAndroid Build Coastguard Worker 3332*da0073e9SAndroid Build Coastguard Worker pool = (indices * strides).sum(dim=0) 3333*da0073e9SAndroid Build Coastguard Worker i2p = {} 3334*da0073e9SAndroid Build Coastguard Worker for i in range(nnz): 3335*da0073e9SAndroid Build Coastguard Worker c = int(pool[i]) 3336*da0073e9SAndroid Build Coastguard Worker if c not in i2p: 3337*da0073e9SAndroid Build Coastguard Worker i2p[c] = len(i2p) 3338*da0073e9SAndroid Build Coastguard Worker pool[i] = i2p[c] 3339*da0073e9SAndroid Build Coastguard Worker 3340*da0073e9SAndroid Build Coastguard Worker # compute max 3341*da0073e9SAndroid Build Coastguard Worker dense_size = tuple(size[sparse.sparse_dim():]) 3342*da0073e9SAndroid Build Coastguard Worker mx = torch.empty((pool.max() + 1,) + dense_size, dtype=dtype, device=device) 3343*da0073e9SAndroid Build Coastguard Worker mx[:] = -inf 3344*da0073e9SAndroid Build Coastguard Worker for n in 
range(nnz): 3345*da0073e9SAndroid Build Coastguard Worker p = pool[n] 3346*da0073e9SAndroid Build Coastguard Worker mx[p] = torch.max(mx[p], values[n]) 3347*da0073e9SAndroid Build Coastguard Worker 3348*da0073e9SAndroid Build Coastguard Worker # apply exp to (v - mx) and sum the results 3349*da0073e9SAndroid Build Coastguard Worker exp_values = torch.empty_like(values) 3350*da0073e9SAndroid Build Coastguard Worker exp_sums = torch.zeros_like(mx) 3351*da0073e9SAndroid Build Coastguard Worker for n in range(nnz): 3352*da0073e9SAndroid Build Coastguard Worker p = pool[n] 3353*da0073e9SAndroid Build Coastguard Worker v = exp_values[n] = (values[n] - mx[p]).exp() 3354*da0073e9SAndroid Build Coastguard Worker exp_sums[p] = exp_sums[p] + v 3355*da0073e9SAndroid Build Coastguard Worker 3356*da0073e9SAndroid Build Coastguard Worker # normalize with the sum of exponents 3357*da0073e9SAndroid Build Coastguard Worker for n in range(nnz): 3358*da0073e9SAndroid Build Coastguard Worker p = pool[n] 3359*da0073e9SAndroid Build Coastguard Worker exp_values[n] = exp_values[n] / exp_sums[p] 3360*da0073e9SAndroid Build Coastguard Worker 3361*da0073e9SAndroid Build Coastguard Worker return torch.sparse_coo_tensor(indices, 3362*da0073e9SAndroid Build Coastguard Worker exp_values, 3363*da0073e9SAndroid Build Coastguard Worker sparse.size(), 3364*da0073e9SAndroid Build Coastguard Worker dtype=dtype, device=device) 3365*da0073e9SAndroid Build Coastguard Worker 3366*da0073e9SAndroid Build Coastguard Worker elif dim < sparse.sparse_dim() + sparse.dense_dim(): 3367*da0073e9SAndroid Build Coastguard Worker return torch.sparse_coo_tensor(indices, 3368*da0073e9SAndroid Build Coastguard Worker F.softmax(values, dim - sparse.sparse_dim() + 1), 3369*da0073e9SAndroid Build Coastguard Worker sparse.size(), 3370*da0073e9SAndroid Build Coastguard Worker dtype=dtype, device=device) 3371*da0073e9SAndroid Build Coastguard Worker else: 3372*da0073e9SAndroid Build Coastguard Worker raise ValueError( 
3373*da0073e9SAndroid Build Coastguard Worker f'`dim(={dim})` must be smaller than `sparse_dim(={sparse.sparse_dim()}) + dense_dim(={sparse.dense_dim()})`') 3374*da0073e9SAndroid Build Coastguard Worker 3375*da0073e9SAndroid Build Coastguard Worker def softmax_jacobian_analytic(x, dim): 3376*da0073e9SAndroid Build Coastguard Worker """Return Jacobian of softmax using analytic formula 3377*da0073e9SAndroid Build Coastguard Worker 3378*da0073e9SAndroid Build Coastguard Worker D_jS_i = S_i * (1[i==j] - S_j). 3379*da0073e9SAndroid Build Coastguard Worker 3380*da0073e9SAndroid Build Coastguard Worker where S = softmax(x, dim), x is dense tensor, i,j in 3381*da0073e9SAndroid Build Coastguard Worker range(x.shape[dim]). 3382*da0073e9SAndroid Build Coastguard Worker """ 3383*da0073e9SAndroid Build Coastguard Worker y = F.softmax(x, dim) 3384*da0073e9SAndroid Build Coastguard Worker y[y != y] = 0 # replace nan-s with zeros 3385*da0073e9SAndroid Build Coastguard Worker J = torch.zeros((x.shape[dim],) + tuple(x.shape), dtype=x.dtype, device=x.device) 3386*da0073e9SAndroid Build Coastguard Worker si = [slice(None)] * len(y.shape) 3387*da0073e9SAndroid Build Coastguard Worker sj = [slice(None)] * len(y.shape) 3388*da0073e9SAndroid Build Coastguard Worker s = [slice(None)] * len(J.shape) 3389*da0073e9SAndroid Build Coastguard Worker for i in range(y.shape[dim]): 3390*da0073e9SAndroid Build Coastguard Worker si[dim] = i 3391*da0073e9SAndroid Build Coastguard Worker s[dim + 1] = i 3392*da0073e9SAndroid Build Coastguard Worker yi = y[tuple(si)] 3393*da0073e9SAndroid Build Coastguard Worker for j in range(y.shape[dim]): 3394*da0073e9SAndroid Build Coastguard Worker sj[dim] = j 3395*da0073e9SAndroid Build Coastguard Worker s[0] = j 3396*da0073e9SAndroid Build Coastguard Worker if i == j: 3397*da0073e9SAndroid Build Coastguard Worker J[tuple(s)] = yi * (1 - yi) 3398*da0073e9SAndroid Build Coastguard Worker else: 3399*da0073e9SAndroid Build Coastguard Worker yj = y[tuple(sj)] 
3400*da0073e9SAndroid Build Coastguard Worker J[tuple(s)] = - yi * yj 3401*da0073e9SAndroid Build Coastguard Worker sj[dim] = slice(None) 3402*da0073e9SAndroid Build Coastguard Worker si[dim] = slice(None) 3403*da0073e9SAndroid Build Coastguard Worker s[dim + 1] = slice(None) 3404*da0073e9SAndroid Build Coastguard Worker return J 3405*da0073e9SAndroid Build Coastguard Worker 3406*da0073e9SAndroid Build Coastguard Worker def softmax_jacobian_autograd(x, dim, log=False): 3407*da0073e9SAndroid Build Coastguard Worker """Return Jacobian of softmax using PyTorch autograd feature. 3408*da0073e9SAndroid Build Coastguard Worker 3409*da0073e9SAndroid Build Coastguard Worker x can be dense or sparse tensor. 3410*da0073e9SAndroid Build Coastguard Worker """ 3411*da0073e9SAndroid Build Coastguard Worker import itertools 3412*da0073e9SAndroid Build Coastguard Worker 3413*da0073e9SAndroid Build Coastguard Worker if x.is_sparse: 3414*da0073e9SAndroid Build Coastguard Worker x = x.coalesce() 3415*da0073e9SAndroid Build Coastguard Worker 3416*da0073e9SAndroid Build Coastguard Worker dtype = x.dtype 3417*da0073e9SAndroid Build Coastguard Worker device = x.device 3418*da0073e9SAndroid Build Coastguard Worker shape = tuple(x.shape) 3419*da0073e9SAndroid Build Coastguard Worker J = torch.zeros((shape[dim],) + shape, dtype=dtype, device=device) 3420*da0073e9SAndroid Build Coastguard Worker for i in range(shape[dim]): 3421*da0073e9SAndroid Build Coastguard Worker if x.is_sparse: 3422*da0073e9SAndroid Build Coastguard Worker sparse_dim = x.sparse_dim() 3423*da0073e9SAndroid Build Coastguard Worker dense_dim = x.dense_dim() 3424*da0073e9SAndroid Build Coastguard Worker if dim < sparse_dim: 3425*da0073e9SAndroid Build Coastguard Worker ranges = [] 3426*da0073e9SAndroid Build Coastguard Worker for j, sz in enumerate(shape[:sparse_dim]): 3427*da0073e9SAndroid Build Coastguard Worker if dim == j: 3428*da0073e9SAndroid Build Coastguard Worker ranges.append([i]) 3429*da0073e9SAndroid Build 
Coastguard Worker else: 3430*da0073e9SAndroid Build Coastguard Worker ranges.append(list(range(sz))) 3431*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor(list(itertools.product(*ranges)), dtype=torch.long, device=device).t() 3432*da0073e9SAndroid Build Coastguard Worker values = torch.ones((indices.shape[1],) + shape[sparse_dim:], dtype=dtype, device=device) 3433*da0073e9SAndroid Build Coastguard Worker else: 3434*da0073e9SAndroid Build Coastguard Worker ranges = [] 3435*da0073e9SAndroid Build Coastguard Worker for j, sz in enumerate(shape[:sparse_dim]): 3436*da0073e9SAndroid Build Coastguard Worker ranges.append(list(range(sz))) 3437*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor(list(itertools.product(*ranges)), dtype=torch.long, device=device).t() 3438*da0073e9SAndroid Build Coastguard Worker values = torch.zeros((indices.shape[1],) + shape[sparse_dim:], dtype=dtype, device=device) 3439*da0073e9SAndroid Build Coastguard Worker sv = [slice(None)] * (dense_dim + 1) 3440*da0073e9SAndroid Build Coastguard Worker sv[dim - sparse_dim + 1] = i 3441*da0073e9SAndroid Build Coastguard Worker values[tuple(sv)] = 1 3442*da0073e9SAndroid Build Coastguard Worker v = torch.sparse_coo_tensor(indices, values, shape, dtype=dtype, device=device) 3443*da0073e9SAndroid Build Coastguard Worker else: 3444*da0073e9SAndroid Build Coastguard Worker v = torch.zeros_like(x) 3445*da0073e9SAndroid Build Coastguard Worker sv = [slice(None)] * len(v.shape) 3446*da0073e9SAndroid Build Coastguard Worker sv[dim] = i 3447*da0073e9SAndroid Build Coastguard Worker v[tuple(sv)] = 1 3448*da0073e9SAndroid Build Coastguard Worker x_ = x.clone() 3449*da0073e9SAndroid Build Coastguard Worker x_.requires_grad_(True) 3450*da0073e9SAndroid Build Coastguard Worker 3451*da0073e9SAndroid Build Coastguard Worker if log: 3452*da0073e9SAndroid Build Coastguard Worker if x_.is_sparse: 3453*da0073e9SAndroid Build Coastguard Worker y = torch.sparse.log_softmax(x_, dim) 
3454*da0073e9SAndroid Build Coastguard Worker else: 3455*da0073e9SAndroid Build Coastguard Worker y = F.log_softmax(x_, dim) 3456*da0073e9SAndroid Build Coastguard Worker else: 3457*da0073e9SAndroid Build Coastguard Worker if x_.is_sparse: 3458*da0073e9SAndroid Build Coastguard Worker y = torch.sparse.softmax(x_, dim) 3459*da0073e9SAndroid Build Coastguard Worker else: 3460*da0073e9SAndroid Build Coastguard Worker y = F.softmax(x_, dim) 3461*da0073e9SAndroid Build Coastguard Worker # replace nan-s with zeros 3462*da0073e9SAndroid Build Coastguard Worker y.data[y != y] = 0 3463*da0073e9SAndroid Build Coastguard Worker y.backward(v) 3464*da0073e9SAndroid Build Coastguard Worker g = x_.grad 3465*da0073e9SAndroid Build Coastguard Worker if not g.is_sparse: 3466*da0073e9SAndroid Build Coastguard Worker # replace nan-s with zeros 3467*da0073e9SAndroid Build Coastguard Worker g.data[g != g] = 0 3468*da0073e9SAndroid Build Coastguard Worker J[i] = g.to_dense() if g.is_sparse else g 3469*da0073e9SAndroid Build Coastguard Worker return J 3470*da0073e9SAndroid Build Coastguard Worker 3471*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1166") 3472*da0073e9SAndroid Build Coastguard Worker def test_op(sparse_dims, nnz, with_size, coalesced): 3473*da0073e9SAndroid Build Coastguard Worker if isinstance(with_size, Number): 3474*da0073e9SAndroid Build Coastguard Worker with_size = [with_size] * sparse_dims 3475*da0073e9SAndroid Build Coastguard Worker 3476*da0073e9SAndroid Build Coastguard Worker x, i, v = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced) 3477*da0073e9SAndroid Build Coastguard Worker 3478*da0073e9SAndroid Build Coastguard Worker def sparse_log(x): 3479*da0073e9SAndroid Build Coastguard Worker return torch.sparse_coo_tensor(x._indices(), x._values().log(), 3480*da0073e9SAndroid Build Coastguard Worker x.size(), dtype=x.dtype, device=x.device) 3481*da0073e9SAndroid Build Coastguard 
Worker 3482*da0073e9SAndroid Build Coastguard Worker # Check dim out of bounds 3483*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IndexError, r"Dimension out of range"): 3484*da0073e9SAndroid Build Coastguard Worker torch.sparse.softmax(x, x.dim()) 3485*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IndexError, r"Dimension out of range"): 3486*da0073e9SAndroid Build Coastguard Worker torch.sparse.softmax(x, -x.dim() - 1) 3487*da0073e9SAndroid Build Coastguard Worker 3488*da0073e9SAndroid Build Coastguard Worker for dim in range(x.dim()): 3489*da0073e9SAndroid Build Coastguard Worker # Check sparse softmax definition 3490*da0073e9SAndroid Build Coastguard Worker 3491*da0073e9SAndroid Build Coastguard Worker # check Python sparse softmax 3492*da0073e9SAndroid Build Coastguard Worker y = sparse_softmax(x, dim) 3493*da0073e9SAndroid Build Coastguard Worker r1 = softmax_to_dense(x, dim) 3494*da0073e9SAndroid Build Coastguard Worker r2 = y.to_dense() 3495*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1, r2) 3496*da0073e9SAndroid Build Coastguard Worker 3497*da0073e9SAndroid Build Coastguard Worker # check C++ sparse softmax 3498*da0073e9SAndroid Build Coastguard Worker for d in (dim, dim - x.dim()): 3499*da0073e9SAndroid Build Coastguard Worker y1 = torch.sparse.softmax(x, d) 3500*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, y1) 3501*da0073e9SAndroid Build Coastguard Worker 3502*da0073e9SAndroid Build Coastguard Worker # check C++ sparse log_softmax 3503*da0073e9SAndroid Build Coastguard Worker ly1 = torch.sparse.log_softmax(x, d) 3504*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ly1, sparse_log(y1)) 3505*da0073e9SAndroid Build Coastguard Worker 3506*da0073e9SAndroid Build Coastguard Worker # Check autograd support on sparse softmax 3507*da0073e9SAndroid Build Coastguard Worker 3508*da0073e9SAndroid Build Coastguard Worker # check softmax Jacobian definition for dense input 
3509*da0073e9SAndroid Build Coastguard Worker x1 = to_dense(x, fill_value=float('-inf')) 3510*da0073e9SAndroid Build Coastguard Worker J = softmax_jacobian_analytic(x1, dim) 3511*da0073e9SAndroid Build Coastguard Worker assert J.shape[0] == x.shape[dim] 3512*da0073e9SAndroid Build Coastguard Worker assert J.shape[dim + 1] == x.shape[dim] 3513*da0073e9SAndroid Build Coastguard Worker 3514*da0073e9SAndroid Build Coastguard Worker # check softmax Jacobian from autograd, dense input 3515*da0073e9SAndroid Build Coastguard Worker J2 = softmax_jacobian_autograd(x1, dim) 3516*da0073e9SAndroid Build Coastguard Worker self.assertEqual(J, J2) 3517*da0073e9SAndroid Build Coastguard Worker 3518*da0073e9SAndroid Build Coastguard Worker # check softmax Jacobian from autograd, sparse input 3519*da0073e9SAndroid Build Coastguard Worker J3 = softmax_jacobian_autograd(x, dim) 3520*da0073e9SAndroid Build Coastguard Worker self.assertEqual(J, J3) 3521*da0073e9SAndroid Build Coastguard Worker 3522*da0073e9SAndroid Build Coastguard Worker ''' 3523*da0073e9SAndroid Build Coastguard Worker y = softmax(x, dim) 3524*da0073e9SAndroid Build Coastguard Worker z = log(y) = log_softmax(x, dim) 3525*da0073e9SAndroid Build Coastguard Worker Dy/Dx = J 3526*da0073e9SAndroid Build Coastguard Worker Dz/Dx = Dz/Dy Dy/Dx = 1/y * J 3527*da0073e9SAndroid Build Coastguard Worker => J = J_log * y 3528*da0073e9SAndroid Build Coastguard Worker ''' 3529*da0073e9SAndroid Build Coastguard Worker # log_softmax Jacobian from autograd, dense input 3530*da0073e9SAndroid Build Coastguard Worker J2_log = softmax_jacobian_autograd(x1, dim, log=True) 3531*da0073e9SAndroid Build Coastguard Worker 3532*da0073e9SAndroid Build Coastguard Worker # log_softmax Jacobian from autograd, sparse input 3533*da0073e9SAndroid Build Coastguard Worker J3_log = softmax_jacobian_autograd(x, dim, log=True) 3534*da0073e9SAndroid Build Coastguard Worker 3535*da0073e9SAndroid Build Coastguard Worker J = J.transpose(0, dim + 1) 
3536*da0073e9SAndroid Build Coastguard Worker J2_log = J2_log.transpose(0, dim + 1) 3537*da0073e9SAndroid Build Coastguard Worker J3_log = J3_log.transpose(0, dim + 1) 3538*da0073e9SAndroid Build Coastguard Worker self.assertEqual(J, J2_log * r1) 3539*da0073e9SAndroid Build Coastguard Worker self.assertEqual(J, J3_log * r1) 3540*da0073e9SAndroid Build Coastguard Worker 3541*da0073e9SAndroid Build Coastguard Worker if dim == 0: 3542*da0073e9SAndroid Build Coastguard Worker # check dtype argument 3543*da0073e9SAndroid Build Coastguard Worker other_dtype = torch.float32 3544*da0073e9SAndroid Build Coastguard Worker y2 = torch.sparse.softmax(x, dim, dtype=other_dtype) 3545*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y2.dtype, other_dtype) 3546*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y2, y1.type(other_dtype)) 3547*da0073e9SAndroid Build Coastguard Worker 3548*da0073e9SAndroid Build Coastguard Worker ly2 = torch.sparse.log_softmax(x, dim, dtype=other_dtype) 3549*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ly2.dtype, other_dtype) 3550*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ly2, ly1.type(other_dtype)) 3551*da0073e9SAndroid Build Coastguard Worker 3552*da0073e9SAndroid Build Coastguard Worker test_op(1, 10, [3], coalesced) 3553*da0073e9SAndroid Build Coastguard Worker test_op(1, 10, [2, 3], coalesced) 3554*da0073e9SAndroid Build Coastguard Worker test_op(1, 10, [3, 2], coalesced) 3555*da0073e9SAndroid Build Coastguard Worker test_op(2, 10, [2, 3, 4], coalesced) 3556*da0073e9SAndroid Build Coastguard Worker test_op(2, 10, [3, 4], coalesced) 3557*da0073e9SAndroid Build Coastguard Worker test_op(2, 5, [5, 4], coalesced) 3558*da0073e9SAndroid Build Coastguard Worker test_op(2, 10, [3, 4, 2], coalesced) 3559*da0073e9SAndroid Build Coastguard Worker test_op(3, 10, [3, 4, 2], coalesced) 3560*da0073e9SAndroid Build Coastguard Worker test_op(3, 100, [3, 4, 2], coalesced) 3561*da0073e9SAndroid Build Coastguard Worker 
test_op(3, 100, [3, 4, 2, 3], coalesced) 3562*da0073e9SAndroid Build Coastguard Worker test_op(3, 100, [3, 4, 2, 3, 5, 2], coalesced) 3563*da0073e9SAndroid Build Coastguard Worker test_op(4, 100, [3, 4, 2, 3, 5, 2], coalesced) 3564*da0073e9SAndroid Build Coastguard Worker 3565*da0073e9SAndroid Build Coastguard Worker 3566*da0073e9SAndroid Build Coastguard Worker def _check_zero_nnz_softmax_op(self, func, ndim, device, dtype): 3567*da0073e9SAndroid Build Coastguard Worker # create a sparse tensor with shape (0,..., 3) it has no materialize values 3568*da0073e9SAndroid Build Coastguard Worker t = torch.sparse_coo_tensor([[] for _ in range(ndim)], [], (0,) * (ndim - 1) + (3,), device=device, dtype=dtype) 3569*da0073e9SAndroid Build Coastguard Worker out = func(t, 0) 3570*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, torch.zeros_like(t)) 3571*da0073e9SAndroid Build Coastguard Worker 3572*da0073e9SAndroid Build Coastguard Worker # gradient 3573*da0073e9SAndroid Build Coastguard Worker t = t.requires_grad_() 3574*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x: func(x, 0).to_dense(), (t,), masked=True) 3575*da0073e9SAndroid Build Coastguard Worker 3576*da0073e9SAndroid Build Coastguard Worker 3577*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.float) 3578*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error") 3579*da0073e9SAndroid Build Coastguard Worker def test_softmax_zero_nnz(self, device, dtype): 3580*da0073e9SAndroid Build Coastguard Worker self._check_zero_nnz_softmax_op(torch.sparse.softmax, 1, device, dtype) 3581*da0073e9SAndroid Build Coastguard Worker self._check_zero_nnz_softmax_op(torch.sparse.softmax, 10, device, dtype) 3582*da0073e9SAndroid Build Coastguard Worker 3583*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.float) 3584*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_CROSSREF, "generator 
unsupport triggers assertion error") 3585*da0073e9SAndroid Build Coastguard Worker def test_log_softmax_zero_nnz(self, device, dtype): 3586*da0073e9SAndroid Build Coastguard Worker self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 1, device, dtype) 3587*da0073e9SAndroid Build Coastguard Worker self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 10, device, dtype) 3588*da0073e9SAndroid Build Coastguard Worker 3589*da0073e9SAndroid Build Coastguard Worker # TODO: Check after why ROCm's cusparseXcsrgemm2Nnz function doesn't return the same nnz value as CUDA 3590*da0073e9SAndroid Build Coastguard Worker @skipIfRocm 3591*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 3592*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 3593*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(*floating_types_and(*[torch.half] if SM53OrLater else [], 3594*da0073e9SAndroid Build Coastguard Worker *[torch.bfloat16] if SM80OrLater else [], 3595*da0073e9SAndroid Build Coastguard Worker torch.complex64, 3596*da0073e9SAndroid Build Coastguard Worker *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else [])) 3597*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_CROSSREF, "not working with fake tensor") 3598*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2, torch.complex64: 1e-2, torch.float32: 1e-2}) 3599*da0073e9SAndroid Build Coastguard Worker def test_sparse_matmul(self, device, dtype, coalesced): 3600*da0073e9SAndroid Build Coastguard Worker """ 3601*da0073e9SAndroid Build Coastguard Worker This function test `torch.sparse.mm` when both the mat1 and mat2 are sparse tensors. 
3602*da0073e9SAndroid Build Coastguard Worker """ 3603*da0073e9SAndroid Build Coastguard Worker 3604*da0073e9SAndroid Build Coastguard Worker def ref_sparse_mm(a, b): 3605*da0073e9SAndroid Build Coastguard Worker return a.to_dense() @ b.to_dense() 3606*da0073e9SAndroid Build Coastguard Worker 3607*da0073e9SAndroid Build Coastguard Worker def grad_with_custom_sparsity_pattern_test_helper(sparse_dims, nnz, shape_a, shape_b): 3608*da0073e9SAndroid Build Coastguard Worker def test_grad_dense(a_s, b_s, g_s): 3609*da0073e9SAndroid Build Coastguard Worker a = a_s.to_dense().detach() 3610*da0073e9SAndroid Build Coastguard Worker b = b_s.to_dense().detach() 3611*da0073e9SAndroid Build Coastguard Worker g = g_s.to_dense().detach() 3612*da0073e9SAndroid Build Coastguard Worker 3613*da0073e9SAndroid Build Coastguard Worker a.requires_grad_(True) 3614*da0073e9SAndroid Build Coastguard Worker b.requires_grad_(True) 3615*da0073e9SAndroid Build Coastguard Worker c = a @ b 3616*da0073e9SAndroid Build Coastguard Worker c.backward(g) 3617*da0073e9SAndroid Build Coastguard Worker return a.grad.sparse_mask(a_s.coalesce()), b.grad.sparse_mask(b_s.coalesce()) 3618*da0073e9SAndroid Build Coastguard Worker 3619*da0073e9SAndroid Build Coastguard Worker a, _, _ = self._gen_sparse(sparse_dims, nnz, shape_a, dtype, device, coalesced) 3620*da0073e9SAndroid Build Coastguard Worker b, _, _ = self._gen_sparse(sparse_dims, nnz, shape_b, dtype, device, coalesced) 3621*da0073e9SAndroid Build Coastguard Worker a.requires_grad_(True) 3622*da0073e9SAndroid Build Coastguard Worker b.requires_grad_(True) 3623*da0073e9SAndroid Build Coastguard Worker 3624*da0073e9SAndroid Build Coastguard Worker c = torch.sparse.mm(a, b) 3625*da0073e9SAndroid Build Coastguard Worker c2 = c.to_dense().detach() 3626*da0073e9SAndroid Build Coastguard Worker c2 = torch.rand_like(c2) 3627*da0073e9SAndroid Build Coastguard Worker g = c2.sparse_mask(c.coalesce()) 3628*da0073e9SAndroid Build Coastguard Worker 3629*da0073e9SAndroid 
Build Coastguard Worker c.backward(g) 3630*da0073e9SAndroid Build Coastguard Worker 3631*da0073e9SAndroid Build Coastguard Worker a_grad, b_grad = test_grad_dense(a, b, g) 3632*da0073e9SAndroid Build Coastguard Worker 3633*da0073e9SAndroid Build Coastguard Worker # We convert grad to dense since dense and sparse mm 3634*da0073e9SAndroid Build Coastguard Worker # implementations handle materialized zeroes differently. 3635*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.grad.to_dense(), a_grad.to_dense()) 3636*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.grad.to_dense(), b_grad.to_dense()) 3637*da0073e9SAndroid Build Coastguard Worker 3638*da0073e9SAndroid Build Coastguard Worker def test_sparse_matmul(sparse_dims, nnz, shape_a, shape_b): 3639*da0073e9SAndroid Build Coastguard Worker a, i_a, v_a = self._gen_sparse(sparse_dims, nnz, shape_a, dtype, device, coalesced) 3640*da0073e9SAndroid Build Coastguard Worker b, i_b, v_b = self._gen_sparse(sparse_dims, nnz, shape_b, dtype, device, coalesced) 3641*da0073e9SAndroid Build Coastguard Worker 3642*da0073e9SAndroid Build Coastguard Worker # dense implementation 3643*da0073e9SAndroid Build Coastguard Worker r1 = ref_sparse_mm(a, b) 3644*da0073e9SAndroid Build Coastguard Worker 3645*da0073e9SAndroid Build Coastguard Worker # cpp implementation 3646*da0073e9SAndroid Build Coastguard Worker r2 = torch.sparse.mm(a, b) 3647*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1, r2.to_dense()) 3648*da0073e9SAndroid Build Coastguard Worker 3649*da0073e9SAndroid Build Coastguard Worker # Check result is truly coalesced 3650*da0073e9SAndroid Build Coastguard Worker self.assertTrue(r2.is_coalesced() and is_coalesced_indices(r2)) 3651*da0073e9SAndroid Build Coastguard Worker 3652*da0073e9SAndroid Build Coastguard Worker if dtype in [torch.double, torch.cdouble]: 3653*da0073e9SAndroid Build Coastguard Worker a.requires_grad_(True) 3654*da0073e9SAndroid Build Coastguard Worker b.requires_grad_(True) 
3655*da0073e9SAndroid Build Coastguard Worker 3656*da0073e9SAndroid Build Coastguard Worker # check autograd support on sparse matmul 3657*da0073e9SAndroid Build Coastguard Worker def fn(D1, D2): 3658*da0073e9SAndroid Build Coastguard Worker return torch.sparse.mm(D1, D2).to_dense() 3659*da0073e9SAndroid Build Coastguard Worker 3660*da0073e9SAndroid Build Coastguard Worker if a.is_cuda: 3661*da0073e9SAndroid Build Coastguard Worker # For cuda, `nondet_tol` is set with `1e-5` 3662*da0073e9SAndroid Build Coastguard Worker # This is because cuSparse sometimes returns approximate zero values like `~e-323` 3663*da0073e9SAndroid Build Coastguard Worker # TODO: Check this cuSparse issue. 3664*da0073e9SAndroid Build Coastguard Worker # This happens when you do chain multiplication `torch.sparse.mm` operations 3665*da0073e9SAndroid Build Coastguard Worker gradcheck(fn, (a, b), nondet_tol=1e-5, masked=True) 3666*da0073e9SAndroid Build Coastguard Worker else: 3667*da0073e9SAndroid Build Coastguard Worker gradcheck(fn, (a, b), masked=True) 3668*da0073e9SAndroid Build Coastguard Worker grad_with_custom_sparsity_pattern_test_helper(sparse_dims, nnz, shape_a, shape_b) 3669*da0073e9SAndroid Build Coastguard Worker 3670*da0073e9SAndroid Build Coastguard Worker def test_error_cases(): 3671*da0073e9SAndroid Build Coastguard Worker def fn(sparse_dims, nnz, shape_a, shape_b): 3672*da0073e9SAndroid Build Coastguard Worker a, i_a, v_a = self._gen_sparse(sparse_dims, nnz, shape_a, dtype, device, coalesced) 3673*da0073e9SAndroid Build Coastguard Worker b, i_b, v_b = self._gen_sparse(sparse_dims, nnz, shape_b, dtype, device, coalesced) 3674*da0073e9SAndroid Build Coastguard Worker r2 = torch.sparse.mm(a, b) 3675*da0073e9SAndroid Build Coastguard Worker 3676*da0073e9SAndroid Build Coastguard Worker # This is not a matrix 3677*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: fn(3, 4, [2, 2, 2], [2, 2, 2])) 3678*da0073e9SAndroid Build Coastguard Worker 
3679*da0073e9SAndroid Build Coastguard Worker # Shapes does not 3680*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, 3681*da0073e9SAndroid Build Coastguard Worker r"mat1 and mat2 shapes cannot be multiplied \(2x3 and 4x2\)", 3682*da0073e9SAndroid Build Coastguard Worker lambda: fn(2, 10, [2, 3], [4, 2])) 3683*da0073e9SAndroid Build Coastguard Worker 3684*da0073e9SAndroid Build Coastguard Worker def different_dtypes(): 3685*da0073e9SAndroid Build Coastguard Worker a, i_a, v_a = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced) 3686*da0073e9SAndroid Build Coastguard Worker b, i_b, v_b = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced) 3687*da0073e9SAndroid Build Coastguard Worker r2 = torch.sparse.mm(a.to(torch.float64), a.to(torch.float32)) 3688*da0073e9SAndroid Build Coastguard Worker 3689*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, 'mat1 dtype Double does not match mat2 dtype Float', different_dtypes) 3690*da0073e9SAndroid Build Coastguard Worker 3691*da0073e9SAndroid Build Coastguard Worker def test_backward_noncontiguous(): 3692*da0073e9SAndroid Build Coastguard Worker # Sparse.mm backward used to wrong with non-contiguous grads, 3693*da0073e9SAndroid Build Coastguard Worker # see https://github.com/pytorch/pytorch/issues/102493. 
3694*da0073e9SAndroid Build Coastguard Worker n_reps = 7 3695*da0073e9SAndroid Build Coastguard Worker for _ in range(n_reps): 3696*da0073e9SAndroid Build Coastguard Worker A = torch.eye(5).to_sparse().requires_grad_(True) 3697*da0073e9SAndroid Build Coastguard Worker B = torch.eye(5).to_sparse() 3698*da0073e9SAndroid Build Coastguard Worker out = torch.sparse.mm(A, B) 3699*da0073e9SAndroid Build Coastguard Worker out.coalesce().values().sum().backward() 3700*da0073e9SAndroid Build Coastguard Worker self.assertEqual(A.grad, A) 3701*da0073e9SAndroid Build Coastguard Worker 3702*da0073e9SAndroid Build Coastguard Worker for n in range(2, 5): 3703*da0073e9SAndroid Build Coastguard Worker for m in range(2, 8): 3704*da0073e9SAndroid Build Coastguard Worker for p in range(2, 8): 3705*da0073e9SAndroid Build Coastguard Worker test_sparse_matmul(2, 10, [n, m], [m, p]) 3706*da0073e9SAndroid Build Coastguard Worker 3707*da0073e9SAndroid Build Coastguard Worker test_sparse_matmul(2, 0, [0, 0], [0, 0]) 3708*da0073e9SAndroid Build Coastguard Worker test_sparse_matmul(2, 0, [0, 10], [10, 0]) 3709*da0073e9SAndroid Build Coastguard Worker test_error_cases() 3710*da0073e9SAndroid Build Coastguard Worker test_backward_noncontiguous() 3711*da0073e9SAndroid Build Coastguard Worker 3712*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 3713*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 3714*da0073e9SAndroid Build Coastguard Worker def test_assign(self, device, dtype, coalesced): 3715*da0073e9SAndroid Build Coastguard Worker def assign_to(): 3716*da0073e9SAndroid Build Coastguard Worker a, i_a, v_a = self._gen_sparse(2, 5, [2, 3], dtype, device, coalesced) 3717*da0073e9SAndroid Build Coastguard Worker a[0] = 100 3718*da0073e9SAndroid Build Coastguard Worker 3719*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, assign_to) 3720*da0073e9SAndroid Build Coastguard Worker 3721*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, 
torch.cdouble) 3722*da0073e9SAndroid Build Coastguard Worker def test_full_broadcast_to(self, device, dtype): 3723*da0073e9SAndroid Build Coastguard Worker def can_broadcast(s0, s1): 3724*da0073e9SAndroid Build Coastguard Worker s0 = tuple(reversed(s0)) 3725*da0073e9SAndroid Build Coastguard Worker s1 = tuple(reversed(s1)) 3726*da0073e9SAndroid Build Coastguard Worker for i in range(len(s0)): 3727*da0073e9SAndroid Build Coastguard Worker if s0[i] != 1 and s0[i] != s1[i]: 3728*da0073e9SAndroid Build Coastguard Worker return False 3729*da0073e9SAndroid Build Coastguard Worker return True 3730*da0073e9SAndroid Build Coastguard Worker sizes = ( 3731*da0073e9SAndroid Build Coastguard Worker (), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2) 3732*da0073e9SAndroid Build Coastguard Worker ) 3733*da0073e9SAndroid Build Coastguard Worker for s0, s1 in itertools.combinations(sizes, r=2): 3734*da0073e9SAndroid Build Coastguard Worker t = make_tensor(s0, dtype=dtype, device=device, low=-9, high=9) 3735*da0073e9SAndroid Build Coastguard Worker for sparse_dims in range(1, len(s0) + 1): 3736*da0073e9SAndroid Build Coastguard Worker s = t.to_sparse(sparse_dims) 3737*da0073e9SAndroid Build Coastguard Worker if can_broadcast(s0, s1): 3738*da0073e9SAndroid Build Coastguard Worker t_res = torch.broadcast_to(t, s1) 3739*da0073e9SAndroid Build Coastguard Worker s_res = torch._sparse_broadcast_to(s, s1) 3740*da0073e9SAndroid Build Coastguard Worker torch._validate_sparse_coo_tensor_args(s_res._indices(), s_res._values(), s_res.shape) 3741*da0073e9SAndroid Build Coastguard Worker if s_res.is_coalesced(): 3742*da0073e9SAndroid Build Coastguard Worker # ensure that is_coalesced is estimated correctly 3743*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s_res, torch.sparse_coo_tensor(s_res._indices(), s_res._values(), s_res.shape).coalesce()) 3744*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s_res.to_dense(), t_res) 3745*da0073e9SAndroid Build Coastguard 
Worker else: 3746*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 3747*da0073e9SAndroid Build Coastguard Worker r"The expanded size of the tensor \(\d\) " 3748*da0073e9SAndroid Build Coastguard Worker r"must match the existing size \(\d\)"): 3749*da0073e9SAndroid Build Coastguard Worker torch._sparse_broadcast_to(s, s1) 3750*da0073e9SAndroid Build Coastguard Worker 3751*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 3752*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double, torch.cdouble) 3753*da0073e9SAndroid Build Coastguard Worker def test_sparse_broadcast_to(self, device, dtype, coalesced): 3754*da0073e9SAndroid Build Coastguard Worker def test(sparse_dims, nnz, with_size, new_size): 3755*da0073e9SAndroid Build Coastguard Worker x = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0] 3756*da0073e9SAndroid Build Coastguard Worker y = self.safeToDense(x) 3757*da0073e9SAndroid Build Coastguard Worker x1 = torch._sparse_broadcast_to(x, new_size) 3758*da0073e9SAndroid Build Coastguard Worker y1 = y.broadcast_to(new_size) 3759*da0073e9SAndroid Build Coastguard Worker self.assertEqual(self.safeToDense(x1), y1) 3760*da0073e9SAndroid Build Coastguard Worker 3761*da0073e9SAndroid Build Coastguard Worker test(4, 6, [7, 3, 1, 3, 0], [7, 3, 4, 3, 0]) 3762*da0073e9SAndroid Build Coastguard Worker test(4, 6, [7, 3, 1, 3, 0], [2, 7, 3, 1, 3, 0]) 3763*da0073e9SAndroid Build Coastguard Worker test(4, 6, [7, 3, 1, 3, 1, 3], [7, 3, 1, 3, 2, 3]) 3764*da0073e9SAndroid Build Coastguard Worker test(4, 6, [7, 3, 1, 3, 2, 1], [7, 3, 1, 3, 2, 3]) 3765*da0073e9SAndroid Build Coastguard Worker 3766*da0073e9SAndroid Build Coastguard Worker def _test_mul_skips(self, device, dtype, coalesced): 3767*da0073e9SAndroid Build Coastguard Worker skipTestIfUncoalesced = False 3768*da0073e9SAndroid Build Coastguard Worker # This case always coalesce inputs and that could lead to loss of precision, 3769*da0073e9SAndroid Build 
Coastguard Worker # hence it is inhibited for float16/bfloat16 by providing already coalesced tensors. 3770*da0073e9SAndroid Build Coastguard Worker if not coalesced and dtype in {torch.float16, torch.bfloat16}: 3771*da0073e9SAndroid Build Coastguard Worker skipTestIfUncoalesced = True 3772*da0073e9SAndroid Build Coastguard Worker # to_dense is problematic for boolean non-coalesced CUDA tensors 3773*da0073e9SAndroid Build Coastguard Worker # see https://github.com/pytorch/pytorch/issues/81648 3774*da0073e9SAndroid Build Coastguard Worker if not coalesced and dtype == torch.bool and torch.device(device).type == "cuda": 3775*da0073e9SAndroid Build Coastguard Worker skipTestIfUncoalesced = True 3776*da0073e9SAndroid Build Coastguard Worker 3777*da0073e9SAndroid Build Coastguard Worker if skipTestIfUncoalesced: 3778*da0073e9SAndroid Build Coastguard Worker self.skipTest(f"Test with dtype={dtype}, device={device} runs only with coalesced inputs") 3779*da0073e9SAndroid Build Coastguard Worker 3780*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 3781*da0073e9SAndroid Build Coastguard Worker # NOTE: addcmul_out is not implemented for bool. 
3782*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.bfloat16, torch.float16)) 3783*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2}) 3784*da0073e9SAndroid Build Coastguard Worker def test_sparse_sparse_mul(self, device, dtype, coalesced): 3785*da0073e9SAndroid Build Coastguard Worker self._test_mul_skips(device, dtype, coalesced) 3786*da0073e9SAndroid Build Coastguard Worker 3787*da0073e9SAndroid Build Coastguard Worker shape = (2, 3, 4, 10) 3788*da0073e9SAndroid Build Coastguard Worker nnz = 10 3789*da0073e9SAndroid Build Coastguard Worker 3790*da0073e9SAndroid Build Coastguard Worker def check(self, x, y): 3791*da0073e9SAndroid Build Coastguard Worker res_sparse = x * y 3792*da0073e9SAndroid Build Coastguard Worker res_dense = x.to_dense() * y.to_dense() 3793*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_sparse.to_dense(), res_dense) 3794*da0073e9SAndroid Build Coastguard Worker 3795*da0073e9SAndroid Build Coastguard Worker def check_empty(sparse_shape, nnz, dense_shape, coalesce): 3796*da0073e9SAndroid Build Coastguard Worker from itertools import product 3797*da0073e9SAndroid Build Coastguard Worker for nnz_val, shape_suffix in product((nnz, 0), ((), (0,))): 3798*da0073e9SAndroid Build Coastguard Worker empty_sparse_shape = sparse_shape + shape_suffix 3799*da0073e9SAndroid Build Coastguard Worker empty_dense_shape = dense_shape + shape_suffix 3800*da0073e9SAndroid Build Coastguard Worker x = self._gen_sparse(sparse_dim, nnz_val, empty_sparse_shape, dtype, device, coalesce)[0] 3801*da0073e9SAndroid Build Coastguard Worker check(self, x, x) 3802*da0073e9SAndroid Build Coastguard Worker 3803*da0073e9SAndroid Build Coastguard Worker # TODO: uncomment once backward is implemented for sparse tensors that broadcast in dense dims. 
3804*da0073e9SAndroid Build Coastguard Worker # def check_autograd(x, y): 3805*da0073e9SAndroid Build Coastguard Worker # if dtype in {torch.double, torch.cdouble}: 3806*da0073e9SAndroid Build Coastguard Worker # xa = x.detach().clone().requires_grad_(True) 3807*da0073e9SAndroid Build Coastguard Worker # ya = y.detach().clone().requires_grad_(True) 3808*da0073e9SAndroid Build Coastguard Worker # gradcheck(lambda a, b: (a * b).to_dense(), (xa, ya), masked=True) 3809*da0073e9SAndroid Build Coastguard Worker # gradcheck(lambda a, b: (a * b).to_dense(), (ya, xa), masked=True) 3810*da0073e9SAndroid Build Coastguard Worker 3811*da0073e9SAndroid Build Coastguard Worker for dim in range(len(shape) + 1): 3812*da0073e9SAndroid Build Coastguard Worker sub_shape = shape[dim:] 3813*da0073e9SAndroid Build Coastguard Worker sparse_dim = len(sub_shape) // 2 3814*da0073e9SAndroid Build Coastguard Worker 3815*da0073e9SAndroid Build Coastguard Worker check_empty(sub_shape, nnz, shape, coalesced) 3816*da0073e9SAndroid Build Coastguard Worker 3817*da0073e9SAndroid Build Coastguard Worker x = self._gen_sparse(sparse_dim, nnz, sub_shape, dtype, device, coalesced)[0] 3818*da0073e9SAndroid Build Coastguard Worker y = self._gen_sparse(sparse_dim, nnz, sub_shape, dtype, device, coalesced)[0] 3819*da0073e9SAndroid Build Coastguard Worker check(self, x, y) 3820*da0073e9SAndroid Build Coastguard Worker # TODO: uncomment once supported 3821*da0073e9SAndroid Build Coastguard Worker # check_autograd(x, y) 3822*da0073e9SAndroid Build Coastguard Worker 3823*da0073e9SAndroid Build Coastguard Worker # check broadcasting in dense dims 3824*da0073e9SAndroid Build Coastguard Worker for d in range(sparse_dim, len(sub_shape)): 3825*da0073e9SAndroid Build Coastguard Worker new_shape = sub_shape[:d] + (1,) + sub_shape[d + 1:] 3826*da0073e9SAndroid Build Coastguard Worker y = self._gen_sparse(sparse_dim, nnz, new_shape, dtype, device, coalesced)[0] 3827*da0073e9SAndroid Build Coastguard Worker check(self, x, 
y) 3828*da0073e9SAndroid Build Coastguard Worker # TODO: uncomment once supported 3829*da0073e9SAndroid Build Coastguard Worker # check_autograd(x, y) 3830*da0073e9SAndroid Build Coastguard Worker 3831*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 3832*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)) 3833*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2}) 3834*da0073e9SAndroid Build Coastguard Worker def test_sparse_dense_mul(self, device, dtype, coalesced): 3835*da0073e9SAndroid Build Coastguard Worker self._test_mul_skips(device, dtype, coalesced) 3836*da0073e9SAndroid Build Coastguard Worker 3837*da0073e9SAndroid Build Coastguard Worker shape = (2, 3, 4, 10) 3838*da0073e9SAndroid Build Coastguard Worker nnz = 10 3839*da0073e9SAndroid Build Coastguard Worker 3840*da0073e9SAndroid Build Coastguard Worker def check(self, s, d): 3841*da0073e9SAndroid Build Coastguard Worker res = d * s 3842*da0073e9SAndroid Build Coastguard Worker 3843*da0073e9SAndroid Build Coastguard Worker # check commutativity 3844*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, s * d) 3845*da0073e9SAndroid Build Coastguard Worker 3846*da0073e9SAndroid Build Coastguard Worker # check correctness 3847*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.to_dense(), s.to_dense() * d) 3848*da0073e9SAndroid Build Coastguard Worker 3849*da0073e9SAndroid Build Coastguard Worker # check in-placeness for dense 3850*da0073e9SAndroid Build Coastguard Worker if d.dim() >= s.dim(): 3851*da0073e9SAndroid Build Coastguard Worker dc = d.clone() 3852*da0073e9SAndroid Build Coastguard Worker self.assertEqual(d.mul_(s), dc.mul_(s.to_dense())) 3853*da0073e9SAndroid Build Coastguard Worker 3854*da0073e9SAndroid Build Coastguard Worker # check in-placeness for sparse 3855*da0073e9SAndroid Build Coastguard Worker if s.dim() >= d.dim(): 3856*da0073e9SAndroid Build 
Coastguard Worker # for sparse 3857*da0073e9SAndroid Build Coastguard Worker sc = s.clone() 3858*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.mul_(d).to_dense(), sc.to_dense().mul_(d)) 3859*da0073e9SAndroid Build Coastguard Worker 3860*da0073e9SAndroid Build Coastguard Worker for dim in range(len(shape) + 1): 3861*da0073e9SAndroid Build Coastguard Worker sub_shape = shape[dim:] 3862*da0073e9SAndroid Build Coastguard Worker sparse_dim = len(sub_shape) // 2 3863*da0073e9SAndroid Build Coastguard Worker 3864*da0073e9SAndroid Build Coastguard Worker def check_empty(sparse_shape, nnz, dense_shape, coalesce): 3865*da0073e9SAndroid Build Coastguard Worker from itertools import product 3866*da0073e9SAndroid Build Coastguard Worker for nnz_val, shape_suffix in product((nnz, 0), ((), (0,))): 3867*da0073e9SAndroid Build Coastguard Worker empty_sparse_shape = sparse_shape + shape_suffix 3868*da0073e9SAndroid Build Coastguard Worker empty_dense_shape = dense_shape + shape_suffix 3869*da0073e9SAndroid Build Coastguard Worker s = self._gen_sparse(sparse_dim, nnz_val, empty_sparse_shape, dtype, device, coalesce)[0] 3870*da0073e9SAndroid Build Coastguard Worker d = make_tensor(empty_dense_shape, dtype=dtype, device=device) 3871*da0073e9SAndroid Build Coastguard Worker check(self, s, d) 3872*da0073e9SAndroid Build Coastguard Worker 3873*da0073e9SAndroid Build Coastguard Worker # check scalar multiplication 3874*da0073e9SAndroid Build Coastguard Worker s = self._gen_sparse(sparse_dim, nnz, sub_shape, dtype, device, coalesced)[0] 3875*da0073e9SAndroid Build Coastguard Worker for scalar in (True, 1, 1.0): 3876*da0073e9SAndroid Build Coastguard Worker res_sparse_right = s * scalar 3877*da0073e9SAndroid Build Coastguard Worker res_sparse_left = scalar * s 3878*da0073e9SAndroid Build Coastguard Worker res_dense = s.to_dense() * scalar 3879*da0073e9SAndroid Build Coastguard Worker # check correctness and dtype 3880*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(s.to(res_sparse_right.dtype), res_sparse_right) 3881*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_sparse_right, res_sparse_left) 3882*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_sparse_right.dtype, res_dense.dtype) 3883*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_sparse_left.dtype, res_dense.dtype) 3884*da0073e9SAndroid Build Coastguard Worker # check scalar as 0-dim sparse tensor 3885*da0073e9SAndroid Build Coastguard Worker tscalar = torch.tensor(scalar, device=device) 3886*da0073e9SAndroid Build Coastguard Worker sscalar = tscalar.to_sparse() 3887*da0073e9SAndroid Build Coastguard Worker res_sparse_right = s * sscalar 3888*da0073e9SAndroid Build Coastguard Worker res_sparse_left = sscalar * s 3889*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_sparse_right, res_sparse_left) 3890*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.to(res_sparse_right.dtype), res_sparse_right) 3891*da0073e9SAndroid Build Coastguard Worker 3892*da0073e9SAndroid Build Coastguard Worker # check non-coalesced 0-dim scalar 3893*da0073e9SAndroid Build Coastguard Worker # we skip torch.bool because for such tensors 3894*da0073e9SAndroid Build Coastguard Worker # coalesce.to_dense != to_dense 3895*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bool: 3896*da0073e9SAndroid Build Coastguard Worker return 3897*da0073e9SAndroid Build Coastguard Worker 3898*da0073e9SAndroid Build Coastguard Worker for scalar_dtype in (int, float): 3899*da0073e9SAndroid Build Coastguard Worker scalar = scalar_dtype(1) 3900*da0073e9SAndroid Build Coastguard Worker idx = torch.tensor([], device=device).reshape(0, 2) 3901*da0073e9SAndroid Build Coastguard Worker val = torch.tensor([scalar, scalar], device=device) 3902*da0073e9SAndroid Build Coastguard Worker sscalar = torch.sparse_coo_tensor(idx, val, ()) 3903*da0073e9SAndroid Build Coastguard Worker res_dense = s.to_dense() * sscalar.to_dense() 
3904*da0073e9SAndroid Build Coastguard Worker self.assertEqual((s * sscalar).to_dense(), res_dense) 3905*da0073e9SAndroid Build Coastguard Worker self.assertEqual((sscalar * s).to_dense(), res_dense) 3906*da0073e9SAndroid Build Coastguard Worker 3907*da0073e9SAndroid Build Coastguard Worker # Case 1: sparse broadcasts over dense 3908*da0073e9SAndroid Build Coastguard Worker s = self._gen_sparse(sparse_dim, nnz, sub_shape, dtype, device, coalesced)[0] 3909*da0073e9SAndroid Build Coastguard Worker d = make_tensor(shape, dtype=dtype, device=device) 3910*da0073e9SAndroid Build Coastguard Worker check(self, s, d) 3911*da0073e9SAndroid Build Coastguard Worker check_empty(sub_shape, nnz, shape, coalesced) 3912*da0073e9SAndroid Build Coastguard Worker 3913*da0073e9SAndroid Build Coastguard Worker # Case 2: dense broadcasts over sparse 3914*da0073e9SAndroid Build Coastguard Worker s = self._gen_sparse(3, nnz, shape, dtype, device, coalesced)[0] 3915*da0073e9SAndroid Build Coastguard Worker d = make_tensor(sub_shape, dtype=dtype, device=device) 3916*da0073e9SAndroid Build Coastguard Worker check(self, s, d) 3917*da0073e9SAndroid Build Coastguard Worker check_empty(shape, nnz, sub_shape, coalesced) 3918*da0073e9SAndroid Build Coastguard Worker 3919*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_NUMPY, "NumPy is not available") 3920*da0073e9SAndroid Build Coastguard Worker @onlyCPU 3921*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.bool)) 3922*da0073e9SAndroid Build Coastguard Worker def test_sparse_spdiags(self, device, dtype): 3923*da0073e9SAndroid Build Coastguard Worker 3924*da0073e9SAndroid Build Coastguard Worker make_diags = functools.partial(make_tensor, dtype=dtype, device=device) 3925*da0073e9SAndroid Build Coastguard Worker make_offsets = functools.partial(torch.tensor, dtype=torch.long, device=device) 3926*da0073e9SAndroid Build Coastguard Worker 3927*da0073e9SAndroid Build Coastguard Worker if TEST_SCIPY: 
3928*da0073e9SAndroid Build Coastguard Worker def reference(diags, offsets, shape): 3929*da0073e9SAndroid Build Coastguard Worker return scipy.sparse.spdiags(diags, offsets, *shape).toarray() 3930*da0073e9SAndroid Build Coastguard Worker 3931*da0073e9SAndroid Build Coastguard Worker else: 3932*da0073e9SAndroid Build Coastguard Worker def reference(diags, offsets, shape): 3933*da0073e9SAndroid Build Coastguard Worker result = torch.zeros(shape, dtype=dtype, device=device) 3934*da0073e9SAndroid Build Coastguard Worker for i, off in enumerate(offsets): 3935*da0073e9SAndroid Build Coastguard Worker res_view = result.diagonal(off) 3936*da0073e9SAndroid Build Coastguard Worker data = diags[i] 3937*da0073e9SAndroid Build Coastguard Worker if off > 0: 3938*da0073e9SAndroid Build Coastguard Worker data = data[off:] 3939*da0073e9SAndroid Build Coastguard Worker 3940*da0073e9SAndroid Build Coastguard Worker m = min(res_view.shape[0], data.shape[0]) 3941*da0073e9SAndroid Build Coastguard Worker res_view[:m] = data[:m] 3942*da0073e9SAndroid Build Coastguard Worker return result 3943*da0073e9SAndroid Build Coastguard Worker 3944*da0073e9SAndroid Build Coastguard Worker def check_valid(diags, offsets, shape, layout=None): 3945*da0073e9SAndroid Build Coastguard Worker ref_out = reference(diags, offsets, shape) 3946*da0073e9SAndroid Build Coastguard Worker out = torch.sparse.spdiags(diags, offsets, shape, layout=layout) 3947*da0073e9SAndroid Build Coastguard Worker if layout is None: 3948*da0073e9SAndroid Build Coastguard Worker ex_layout = torch.sparse_coo 3949*da0073e9SAndroid Build Coastguard Worker else: 3950*da0073e9SAndroid Build Coastguard Worker ex_layout = layout 3951*da0073e9SAndroid Build Coastguard Worker out_dense = out.to_dense() 3952*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out.layout == ex_layout, f"Output layout {out.layout} expected {ex_layout}") 3953*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_dense, ref_out, 
f"Result:\n{out_dense} does not match reference:\n{ref_out}") 3954*da0073e9SAndroid Build Coastguard Worker 3955*da0073e9SAndroid Build Coastguard Worker def check_invalid(args, error): 3956*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, error): 3957*da0073e9SAndroid Build Coastguard Worker torch.sparse.spdiags(*args) 3958*da0073e9SAndroid Build Coastguard Worker 3959*da0073e9SAndroid Build Coastguard Worker def valid_cases(): 3960*da0073e9SAndroid Build Coastguard Worker # some normal cases 3961*da0073e9SAndroid Build Coastguard Worker yield (make_diags((1, 5)), make_offsets([0]), (5, 5)) 3962*da0073e9SAndroid Build Coastguard Worker yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4)) 3963*da0073e9SAndroid Build Coastguard Worker # noncontigous diags 3964*da0073e9SAndroid Build Coastguard Worker yield (make_diags((5, 4), noncontiguous=True), make_offsets([-1, 1, 0, 2, -2]), (5, 5)) 3965*da0073e9SAndroid Build Coastguard Worker # noncontigous offsets 3966*da0073e9SAndroid Build Coastguard Worker yield (make_diags((3, 4)), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5)) 3967*da0073e9SAndroid Build Coastguard Worker # noncontigous diags + offsets 3968*da0073e9SAndroid Build Coastguard Worker yield (make_diags((3, 4), noncontiguous=True), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5)) 3969*da0073e9SAndroid Build Coastguard Worker # correct dimensionality, 2d, 2d , and shapes match, but the number of diagonals is zero 3970*da0073e9SAndroid Build Coastguard Worker yield (make_diags((0, 3)), make_offsets([]), (3, 3)) 3971*da0073e9SAndroid Build Coastguard Worker # forward rotation of upper diagonals 3972*da0073e9SAndroid Build Coastguard Worker yield (make_diags((3, 8)), make_offsets([1, 2, 3]), (4, 4)) 3973*da0073e9SAndroid Build Coastguard Worker # rotation exausts input space to read from 3974*da0073e9SAndroid Build Coastguard Worker yield (make_diags((2, 3)), make_offsets([2, 1]), (3, 3)) 3975*da0073e9SAndroid Build 
Coastguard Worker # Simple cases repeated with special output format 3976*da0073e9SAndroid Build Coastguard Worker yield (make_diags((1, 5)), make_offsets([0]), (5, 5), torch.sparse_csc) 3977*da0073e9SAndroid Build Coastguard Worker yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4), torch.sparse_csr) 3978*da0073e9SAndroid Build Coastguard Worker # vector diags 3979*da0073e9SAndroid Build Coastguard Worker yield (make_diags((3, )), make_offsets([1]), (4, 4)) 3980*da0073e9SAndroid Build Coastguard Worker # Scalar offset 3981*da0073e9SAndroid Build Coastguard Worker yield (make_diags((1, 3)), make_offsets(2), (4, 4)) 3982*da0073e9SAndroid Build Coastguard Worker # offsets out of range 3983*da0073e9SAndroid Build Coastguard Worker yield (make_diags((1, 3)), make_offsets([3]), (3, 3)) 3984*da0073e9SAndroid Build Coastguard Worker yield (make_diags((1, 3)), make_offsets([-3]), (3, 3)) 3985*da0073e9SAndroid Build Coastguard Worker 3986*da0073e9SAndroid Build Coastguard Worker for case in valid_cases(): 3987*da0073e9SAndroid Build Coastguard Worker check_valid(*case) 3988*da0073e9SAndroid Build Coastguard Worker 3989*da0073e9SAndroid Build Coastguard Worker def invalid_cases(): 3990*da0073e9SAndroid Build Coastguard Worker yield (make_diags((1, 3)), make_offsets([0]), (3, 2, 3)), "Output shape must be 2d" 3991*da0073e9SAndroid Build Coastguard Worker yield (make_diags((2, 3)), make_offsets([[1, 2], [0, 3]]), (3, 3)), "Offsets must be scalar or vector" 3992*da0073e9SAndroid Build Coastguard Worker yield (make_diags((3, 2, 3)), make_offsets([0, 1, 2]), (4, 4)), "Diagonals must be vector or matrix" 3993*da0073e9SAndroid Build Coastguard Worker yield (make_diags((3, 3)), make_offsets([-1, 0]), (3, 3)), \ 3994*da0073e9SAndroid Build Coastguard Worker r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)" 3995*da0073e9SAndroid Build Coastguard Worker yield (make_diags((5,)), make_offsets([0, 1, 2, 3, 4]), (3, 3)), \ 3996*da0073e9SAndroid Build 
Coastguard Worker r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)" 3997*da0073e9SAndroid Build Coastguard Worker yield (make_diags((2, 2)), make_offsets([-1, 0]), (2, 3), torch.strided), \ 3998*da0073e9SAndroid Build Coastguard Worker r"Only output layouts \(\w+, \w+, \w+\) are supported, got \w+" 3999*da0073e9SAndroid Build Coastguard Worker yield (make_diags((2, 5)), make_offsets([0, 0]), (5, 5)), "Offset tensor contains duplicate values" 4000*da0073e9SAndroid Build Coastguard Worker yield (make_diags((1, 5)), make_offsets([0]).to(torch.int32), (5, 5)), r"Offset Tensor must have dtype Long but got \w+" 4001*da0073e9SAndroid Build Coastguard Worker 4002*da0073e9SAndroid Build Coastguard Worker 4003*da0073e9SAndroid Build Coastguard Worker for case, error_regex in invalid_cases(): 4004*da0073e9SAndroid Build Coastguard Worker check_invalid(case, error_regex) 4005*da0073e9SAndroid Build Coastguard Worker 4006*da0073e9SAndroid Build Coastguard Worker def test_small_nnz_coalesced(self): 4007*da0073e9SAndroid Build Coastguard Worker # creating a coo tensor with nnz == 0 is always coalesced 4008*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.sparse_coo_tensor([[], []], [], (2, 2)).is_coalesced()) 4009*da0073e9SAndroid Build Coastguard Worker # same for a coo tensor with only 1 nnz 4010*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.sparse_coo_tensor([[0], [0]], [1], (2, 2)).is_coalesced()) 4011*da0073e9SAndroid Build Coastguard Worker # two or more nnz coalesced is false as it can't be verified without an expensive check 4012*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.sparse_coo_tensor([[0, 0], [0, 0]], [1, 2], (2, 2)).is_coalesced()) 4013*da0073e9SAndroid Build Coastguard Worker # even if there are no duplicates 4014*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.sparse_coo_tensor([[0, 1], [0, 1]], [1, 2], (2, 2)).is_coalesced()) 4015*da0073e9SAndroid Build Coastguard Worker 
4016*da0073e9SAndroid Build Coastguard Worker @coalescedonoff 4017*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.bool)) 4018*da0073e9SAndroid Build Coastguard Worker def test_sum(self, device, dtype, coalesced): 4019*da0073e9SAndroid Build Coastguard Worker def run_test(shape, nnz): 4020*da0073e9SAndroid Build Coastguard Worker a = self._gen_sparse(2, nnz, shape, dtype, device, coalesced)[0] 4021*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.sum(), a._values().sum()) 4022*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point or dtype.is_complex: 4023*da0073e9SAndroid Build Coastguard Worker a.requires_grad_(True) 4024*da0073e9SAndroid Build Coastguard Worker a_inter = a.sum() 4025*da0073e9SAndroid Build Coastguard Worker a_inter.abs().backward() 4026*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 4027*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.grad, torch.ones(shape, dtype=dtype, device=device) * torch.sgn(a_inter)) 4028*da0073e9SAndroid Build Coastguard Worker for shape in [(10, 5), (10, 10)]: 4029*da0073e9SAndroid Build Coastguard Worker run_test(shape, 0) 4030*da0073e9SAndroid Build Coastguard Worker run_test(shape, max(shape)) 4031*da0073e9SAndroid Build Coastguard Worker run_test(shape, shape[0] * shape[1]) 4032*da0073e9SAndroid Build Coastguard Worker 4033*da0073e9SAndroid Build Coastguard Worker 4034*da0073e9SAndroid Build Coastguard Workerclass TestSparseOneOff(TestCase): 4035*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, 'CUDA not available') 4036*da0073e9SAndroid Build Coastguard Worker def test_cuda_from_cpu(self): 4037*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 4038*da0073e9SAndroid Build Coastguard Worker RuntimeError, 4039*da0073e9SAndroid Build Coastguard Worker "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"): 4040*da0073e9SAndroid Build Coastguard 
Worker torch.sparse_coo_tensor(torch.zeros(1, 4).long().cuda(), 4041*da0073e9SAndroid Build Coastguard Worker torch.randn(4, 4, 4), 4042*da0073e9SAndroid Build Coastguard Worker [3, 4, 4]) 4043*da0073e9SAndroid Build Coastguard Worker 4044*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 4045*da0073e9SAndroid Build Coastguard Worker RuntimeError, 4046*da0073e9SAndroid Build Coastguard Worker "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"): 4047*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo_tensor(torch.zeros(1, 4).long().cuda(), 4048*da0073e9SAndroid Build Coastguard Worker torch.randn(4, 4, 4, 0), 4049*da0073e9SAndroid Build Coastguard Worker [3, 4, 4, 0]) 4050*da0073e9SAndroid Build Coastguard Worker 4051*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 4052*da0073e9SAndroid Build Coastguard Worker RuntimeError, 4053*da0073e9SAndroid Build Coastguard Worker "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"): 4054*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo_tensor(torch.empty(1, 0).long().cuda(), 4055*da0073e9SAndroid Build Coastguard Worker torch.randn(0, 4, 4, 0), 4056*da0073e9SAndroid Build Coastguard Worker [0, 4, 4, 0]) 4057*da0073e9SAndroid Build Coastguard Worker 4058*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, 'CUDA not available') 4059*da0073e9SAndroid Build Coastguard Worker def test_cuda_sparse_cpu_dense_add(self): 4060*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(3, 4, 4) 4061*da0073e9SAndroid Build Coastguard Worker sparse_y = torch.sparse_coo_tensor(torch.zeros(1, 4).long().cuda(), 4062*da0073e9SAndroid Build Coastguard Worker torch.randn(4, 4, 4).cuda(), 4063*da0073e9SAndroid Build Coastguard Worker [3, 4, 4]) 4064*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got 
a CPU tensor"): 4065*da0073e9SAndroid Build Coastguard Worker x + sparse_y 4066*da0073e9SAndroid Build Coastguard Worker 4067*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(3, 4, 4, 0) 4068*da0073e9SAndroid Build Coastguard Worker sparse_y = torch.sparse_coo_tensor(torch.zeros(1, 4).long().cuda(), 4069*da0073e9SAndroid Build Coastguard Worker torch.randn(4, 4, 4, 0).cuda(), 4070*da0073e9SAndroid Build Coastguard Worker [3, 4, 4, 0]) 4071*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"): 4072*da0073e9SAndroid Build Coastguard Worker x + sparse_y 4073*da0073e9SAndroid Build Coastguard Worker 4074*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(0, 4, 4, 0) 4075*da0073e9SAndroid Build Coastguard Worker sparse_y = torch.sparse_coo_tensor(torch.empty(1, 0).long().cuda(), 4076*da0073e9SAndroid Build Coastguard Worker torch.randn(0, 4, 4, 0).cuda(), 4077*da0073e9SAndroid Build Coastguard Worker [0, 4, 4, 0]) 4078*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"): 4079*da0073e9SAndroid Build Coastguard Worker x + sparse_y 4080*da0073e9SAndroid Build Coastguard Worker 4081*da0073e9SAndroid Build Coastguard Worker 4082*da0073e9SAndroid Build Coastguard Workerdef _sparse_to_dense(tensor): 4083*da0073e9SAndroid Build Coastguard Worker if tensor.dtype != torch.bool: 4084*da0073e9SAndroid Build Coastguard Worker return tensor.to_dense(masked_grad=True) 4085*da0073e9SAndroid Build Coastguard Worker 4086*da0073e9SAndroid Build Coastguard Worker # to_dense uses coalesce which isn't implemented for bool 4087*da0073e9SAndroid Build Coastguard Worker return tensor.to(torch.int8).to_dense().to(torch.bool) 4088*da0073e9SAndroid Build Coastguard Worker 4089*da0073e9SAndroid Build Coastguard Worker 4090*da0073e9SAndroid Build Coastguard Worker_sparse_unary_ops = 
ops(sparse_unary_ufuncs, dtypes=OpDTypes.supported, 4091*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=all_types_and_complex()) 4092*da0073e9SAndroid Build Coastguard Workerclass TestSparseUnaryUfuncs(TestCase): 4093*da0073e9SAndroid Build Coastguard Worker exact_dtype = True 4094*da0073e9SAndroid Build Coastguard Worker 4095*da0073e9SAndroid Build Coastguard Worker 4096*da0073e9SAndroid Build Coastguard Worker @_sparse_unary_ops 4097*da0073e9SAndroid Build Coastguard Worker def test_sparse_consistency(self, device, dtype, op): 4098*da0073e9SAndroid Build Coastguard Worker sample = first_sample(self, op.sample_inputs(device, dtype)) 4099*da0073e9SAndroid Build Coastguard Worker assert isinstance(sample.input, torch.Tensor) 4100*da0073e9SAndroid Build Coastguard Worker 4101*da0073e9SAndroid Build Coastguard Worker expected = op(sample.input, *sample.args, **sample.kwargs) 4102*da0073e9SAndroid Build Coastguard Worker assert torch.is_tensor(expected) 4103*da0073e9SAndroid Build Coastguard Worker output = op(sample.input.to_sparse(), *sample.args, **sample.kwargs) 4104*da0073e9SAndroid Build Coastguard Worker assert torch.is_tensor(output) 4105*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_sparse_to_dense(output), expected) 4106*da0073e9SAndroid Build Coastguard Worker 4107*da0073e9SAndroid Build Coastguard Worker @_sparse_unary_ops 4108*da0073e9SAndroid Build Coastguard Worker def test_out(self, device, dtype, op): 4109*da0073e9SAndroid Build Coastguard Worker if not op.supports_out: 4110*da0073e9SAndroid Build Coastguard Worker self.skipTest("Skipped! 
Out not supported") 4111*da0073e9SAndroid Build Coastguard Worker 4112*da0073e9SAndroid Build Coastguard Worker sample = first_sample(self, op.sample_inputs(device, dtype)) 4113*da0073e9SAndroid Build Coastguard Worker sample.input = sample.input.to_sparse() 4114*da0073e9SAndroid Build Coastguard Worker expect = op(sample.input, *sample.args, **sample.kwargs) 4115*da0073e9SAndroid Build Coastguard Worker 4116*da0073e9SAndroid Build Coastguard Worker out = torch.sparse_coo_tensor(sample.input.shape, device=device, 4117*da0073e9SAndroid Build Coastguard Worker dtype=expect.dtype) 4118*da0073e9SAndroid Build Coastguard Worker op(sample.input, *sample.args, **sample.kwargs, out=out) 4119*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, expect) 4120*da0073e9SAndroid Build Coastguard Worker 4121*da0073e9SAndroid Build Coastguard Worker @_sparse_unary_ops 4122*da0073e9SAndroid Build Coastguard Worker def test_inplace(self, device, dtype, op): 4123*da0073e9SAndroid Build Coastguard Worker if op.inplace_variant is None: 4124*da0073e9SAndroid Build Coastguard Worker self.skipTest("Skipped! 
Out not supported") 4125*da0073e9SAndroid Build Coastguard Worker 4126*da0073e9SAndroid Build Coastguard Worker sample = first_sample(self, op.sample_inputs(device, dtype)) 4127*da0073e9SAndroid Build Coastguard Worker sample.input = sample.input.to_sparse().coalesce() 4128*da0073e9SAndroid Build Coastguard Worker expect = op(sample.input, *sample.args, **sample.kwargs) 4129*da0073e9SAndroid Build Coastguard Worker 4130*da0073e9SAndroid Build Coastguard Worker if not torch.can_cast(expect.dtype, dtype): 4131*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "result type .* can't be cast to"): 4132*da0073e9SAndroid Build Coastguard Worker op.inplace_variant(sample.input, *sample.args, **sample.kwargs) 4133*da0073e9SAndroid Build Coastguard Worker return 4134*da0073e9SAndroid Build Coastguard Worker 4135*da0073e9SAndroid Build Coastguard Worker actual = op.inplace_variant(sample.input, *sample.args, **sample.kwargs) 4136*da0073e9SAndroid Build Coastguard Worker self.assertIs(actual, sample.input) 4137*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expect) 4138*da0073e9SAndroid Build Coastguard Worker 4139*da0073e9SAndroid Build Coastguard Worker @_sparse_unary_ops 4140*da0073e9SAndroid Build Coastguard Worker def test_sparse_zero_dims(self, device, dtype, op): 4141*da0073e9SAndroid Build Coastguard Worker # test 0x0 sparse_coo_tensor 4142*da0073e9SAndroid Build Coastguard Worker indices = torch.empty(2, 0, dtype=torch.int64) 4143*da0073e9SAndroid Build Coastguard Worker values = torch.empty(0, dtype=dtype) 4144*da0073e9SAndroid Build Coastguard Worker sparse_0x0 = torch.sparse_coo_tensor(indices, values, (0, 0)) 4145*da0073e9SAndroid Build Coastguard Worker expected = torch.sparse_coo_tensor(indices, op(values), (0, 0)) 4146*da0073e9SAndroid Build Coastguard Worker actual = op(sparse_0x0) 4147*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 4148*da0073e9SAndroid Build Coastguard Worker 
4149*da0073e9SAndroid Build Coastguard Worker @_sparse_unary_ops 4150*da0073e9SAndroid Build Coastguard Worker def test_sparse_zeros(self, device, dtype, op): 4151*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype) 4152*da0073e9SAndroid Build Coastguard Worker 4153*da0073e9SAndroid Build Coastguard Worker zero_input = torch.zeros((), device=device, dtype=dtype) 4154*da0073e9SAndroid Build Coastguard Worker sparse_input = torch.sparse_coo_tensor((), dtype=dtype, device=device) 4155*da0073e9SAndroid Build Coastguard Worker 4156*da0073e9SAndroid Build Coastguard Worker expect = op(zero_input) 4157*da0073e9SAndroid Build Coastguard Worker actual = op(sparse_input) 4158*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expect, _sparse_to_dense(actual)) 4159*da0073e9SAndroid Build Coastguard Worker 4160*da0073e9SAndroid Build Coastguard Worker @ops(sparse_unary_ufuncs, dtypes=OpDTypes.supported, 4161*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=[torch.double, torch.cdouble]) 4162*da0073e9SAndroid Build Coastguard Worker def test_sparse_fn_grad(self, device, dtype, op): 4163*da0073e9SAndroid Build Coastguard Worker if not op.supports_autograd: 4164*da0073e9SAndroid Build Coastguard Worker self.skipTest("Skipped! 
Op doesn't support autograd") 4165*da0073e9SAndroid Build Coastguard Worker 4166*da0073e9SAndroid Build Coastguard Worker for sample in op.sample_inputs(device, dtype): 4167*da0073e9SAndroid Build Coastguard Worker sparse_input = sample.input.to_sparse().detach().requires_grad_(True) 4168*da0073e9SAndroid Build Coastguard Worker 4169*da0073e9SAndroid Build Coastguard Worker def fn(x): 4170*da0073e9SAndroid Build Coastguard Worker return _sparse_to_dense( 4171*da0073e9SAndroid Build Coastguard Worker op(x, *sample.args, **sample.kwargs)) 4172*da0073e9SAndroid Build Coastguard Worker 4173*da0073e9SAndroid Build Coastguard Worker self.assertTrue(gradcheck( 4174*da0073e9SAndroid Build Coastguard Worker fn, 4175*da0073e9SAndroid Build Coastguard Worker (sparse_input,), 4176*da0073e9SAndroid Build Coastguard Worker check_batched_grad=False, 4177*da0073e9SAndroid Build Coastguard Worker check_grad_dtypes=True, 4178*da0073e9SAndroid Build Coastguard Worker nondet_tol=op.gradcheck_nondet_tol, 4179*da0073e9SAndroid Build Coastguard Worker fast_mode=op.gradcheck_fast_mode, 4180*da0073e9SAndroid Build Coastguard Worker masked=True)) 4181*da0073e9SAndroid Build Coastguard Worker 4182*da0073e9SAndroid Build Coastguard Worker 4183*da0073e9SAndroid Build Coastguard Workerclass TestSparseMaskedReductions(TestCase): 4184*da0073e9SAndroid Build Coastguard Worker exact_dtype = True 4185*da0073e9SAndroid Build Coastguard Worker 4186*da0073e9SAndroid Build Coastguard Worker fp16_low_precision_list = { 4187*da0073e9SAndroid Build Coastguard Worker 'masked.prod', 4188*da0073e9SAndroid Build Coastguard Worker } 4189*da0073e9SAndroid Build Coastguard Worker 4190*da0073e9SAndroid Build Coastguard Worker @ops(sparse_masked_reduction_ops) 4191*da0073e9SAndroid Build Coastguard Worker def test_future_empty_dim(self, device, dtype, op): 4192*da0073e9SAndroid Build Coastguard Worker """Currently, `dim=()` in reductions operations means "reduce over 4193*da0073e9SAndroid Build Coastguard Worker 
all dimensions" while in future, it will read "no reduce". See 4194*da0073e9SAndroid Build Coastguard Worker https://github.com/pytorch/pytorch/issues/29137 4195*da0073e9SAndroid Build Coastguard Worker 4196*da0073e9SAndroid Build Coastguard Worker For sparse masked reductions, we'll implement the current behavior. 4197*da0073e9SAndroid Build Coastguard Worker 4198*da0073e9SAndroid Build Coastguard Worker For testing, we'll use samples with `dim=0` and map it to 4199*da0073e9SAndroid Build Coastguard Worker `dim=()` until 4200*da0073e9SAndroid Build Coastguard Worker torch.testing._internal.common_methods_invocations._generate_reduction_kwargs 4201*da0073e9SAndroid Build Coastguard Worker is made to generate samples with `dim=()` for non-scalar 4202*da0073e9SAndroid Build Coastguard Worker inputs. With this and after gh-29137 is resolved, this test 4203*da0073e9SAndroid Build Coastguard Worker can be deleted. See also `torch.masked._canonical_dim` 4204*da0073e9SAndroid Build Coastguard Worker implementation about changing the `dim=()` behavior. 
4205*da0073e9SAndroid Build Coastguard Worker """ 4206*da0073e9SAndroid Build Coastguard Worker 4207*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs_func(op, device, dtype, requires_grad=False) 4208*da0073e9SAndroid Build Coastguard Worker op_name = op.name.replace('masked.', '') 4209*da0073e9SAndroid Build Coastguard Worker for sample_input in samples: 4210*da0073e9SAndroid Build Coastguard Worker if sample_input.kwargs.get('dim') != 0: 4211*da0073e9SAndroid Build Coastguard Worker continue 4212*da0073e9SAndroid Build Coastguard Worker sample_input_kwargs = dict(sample_input.kwargs) 4213*da0073e9SAndroid Build Coastguard Worker sample_input_kwargs['dim'] = () # reduce over all dimensions 4214*da0073e9SAndroid Build Coastguard Worker 4215*da0073e9SAndroid Build Coastguard Worker t = sample_input.input 4216*da0073e9SAndroid Build Coastguard Worker mask = sample_input_kwargs.get('mask') 4217*da0073e9SAndroid Build Coastguard Worker if mask is None and op_name in {'prod', 'amax', 'amin'}: 4218*da0073e9SAndroid Build Coastguard Worker # FIXME: for now reductions with non-zero reduction identity and 4219*da0073e9SAndroid Build Coastguard Worker # unspecified mask are not supported for sparse COO 4220*da0073e9SAndroid Build Coastguard Worker # tensors, see torch.masked.prod implementation 4221*da0073e9SAndroid Build Coastguard Worker # for details. 
4222*da0073e9SAndroid Build Coastguard Worker continue 4223*da0073e9SAndroid Build Coastguard Worker sparse_op_kwargs = dict(sample_input_kwargs) 4224*da0073e9SAndroid Build Coastguard Worker actual = op(t.to_sparse(), *sample_input.args, **sample_input_kwargs) 4225*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual.layout, torch.sparse_coo) 4226*da0073e9SAndroid Build Coastguard Worker 4227*da0073e9SAndroid Build Coastguard Worker expected = op(t, *sample_input.args, **sample_input_kwargs).to_sparse() 4228*da0073e9SAndroid Build Coastguard Worker atol = None 4229*da0073e9SAndroid Build Coastguard Worker rtol = None 4230*da0073e9SAndroid Build Coastguard Worker if op.name in self.fp16_low_precision_list and dtype == torch.half: 4231*da0073e9SAndroid Build Coastguard Worker atol = 1e-5 4232*da0073e9SAndroid Build Coastguard Worker rtol = 2e-3 4233*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, atol=atol, rtol=rtol) 4234*da0073e9SAndroid Build Coastguard Worker 4235*da0073e9SAndroid Build Coastguard Worker 4236*da0073e9SAndroid Build Coastguard Workerclass TestSparseMeta(TestCase): 4237*da0073e9SAndroid Build Coastguard Worker exact_dtype = True 4238*da0073e9SAndroid Build Coastguard Worker 4239*da0073e9SAndroid Build Coastguard Worker def _test_meta_sparse_coo(self, dtype): 4240*da0073e9SAndroid Build Coastguard Worker r = torch.empty(4, 4, layout=torch.sparse_coo, device='meta', dtype=dtype) 4241*da0073e9SAndroid Build Coastguard Worker self.assertTrue(r.is_meta) 4242*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.device.type, "meta") 4243*da0073e9SAndroid Build Coastguard Worker r2 = torch.empty_like(r) 4244*da0073e9SAndroid Build Coastguard Worker self.assertTrue(r2.is_meta) 4245*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, r2) 4246*da0073e9SAndroid Build Coastguard Worker r3 = torch.sparse_coo_tensor(size=(4, 4), device='meta', dtype=dtype) 4247*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(r3.is_meta) 4248*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, r3) 4249*da0073e9SAndroid Build Coastguard Worker r.sparse_resize_((4, 4), 1, 1) 4250*da0073e9SAndroid Build Coastguard Worker r.sparse_resize_and_clear_((4, 4, 4), 2, 1) 4251*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.sparse_dim(), 2) 4252*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.dense_dim(), 1) 4253*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r._dimV(), 1) 4254*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r._nnz(), 0) 4255*da0073e9SAndroid Build Coastguard Worker # nnz zero sparse tensors should always be coalesced at creation 4256*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.is_coalesced(), True) 4257*da0073e9SAndroid Build Coastguard Worker # but we can force them into the uncoalesed state 4258*da0073e9SAndroid Build Coastguard Worker r._coalesced_(False) 4259*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.is_coalesced(), False) 4260*da0073e9SAndroid Build Coastguard Worker # return the coalesced state for indices/values access 4261*da0073e9SAndroid Build Coastguard Worker r._coalesced_(True) 4262*da0073e9SAndroid Build Coastguard Worker # TODO: this sort of aliasing will need to be handled by 4263*da0073e9SAndroid Build Coastguard Worker # functionalization 4264*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r._indices(), torch.empty(2, 0, device='meta', dtype=torch.int64)) 4265*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r._values(), torch.empty(0, 4, device='meta', dtype=dtype)) 4266*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.indices(), torch.empty(2, 0, device='meta', dtype=torch.int64)) 4267*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.values(), torch.empty(0, 4, device='meta', dtype=dtype)) 4268*da0073e9SAndroid Build Coastguard Worker 4269*da0073e9SAndroid Build Coastguard Worker def 
_test_meta_sparse_compressed(self, dtype, layout, batchsize, densesize): 4270*da0073e9SAndroid Build Coastguard Worker index_dtype = torch.int64 4271*da0073e9SAndroid Build Coastguard Worker blocksize = (2, 3) if layout in {torch.sparse_bsr, torch.sparse_bsc} else () 4272*da0073e9SAndroid Build Coastguard Worker sparsesize = (4, 6) 4273*da0073e9SAndroid Build Coastguard Worker nnz = 0 4274*da0073e9SAndroid Build Coastguard Worker 4275*da0073e9SAndroid Build Coastguard Worker shape = (*batchsize, *sparsesize, *densesize) 4276*da0073e9SAndroid Build Coastguard Worker compressed_dim = 0 if layout in {torch.sparse_csr, torch.sparse_bsr} else 1 4277*da0073e9SAndroid Build Coastguard Worker nof_compressed_indices = (sparsesize[compressed_dim] // blocksize[compressed_dim] + 1 if blocksize 4278*da0073e9SAndroid Build Coastguard Worker else sparsesize[compressed_dim] + 1) 4279*da0073e9SAndroid Build Coastguard Worker compressed_indices = torch.empty((*batchsize, nof_compressed_indices), device='meta', dtype=index_dtype) 4280*da0073e9SAndroid Build Coastguard Worker plain_indices = torch.empty((*batchsize, nnz), device='meta', dtype=index_dtype) 4281*da0073e9SAndroid Build Coastguard Worker 4282*da0073e9SAndroid Build Coastguard Worker values = torch.empty((*batchsize, nnz, *blocksize, *densesize), device='meta', dtype=dtype) 4283*da0073e9SAndroid Build Coastguard Worker r = torch.sparse_compressed_tensor( 4284*da0073e9SAndroid Build Coastguard Worker compressed_indices, 4285*da0073e9SAndroid Build Coastguard Worker plain_indices, 4286*da0073e9SAndroid Build Coastguard Worker values, 4287*da0073e9SAndroid Build Coastguard Worker shape, 4288*da0073e9SAndroid Build Coastguard Worker layout=layout 4289*da0073e9SAndroid Build Coastguard Worker ) 4290*da0073e9SAndroid Build Coastguard Worker self.assertTrue(r.is_meta) 4291*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.device.type, "meta") 4292*da0073e9SAndroid Build Coastguard Worker 4293*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(r.sparse_dim(), 2) 4294*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.dense_dim(), len(densesize)) 4295*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r._nnz(), nnz) 4296*da0073e9SAndroid Build Coastguard Worker batch_dims = r.ndim - r.sparse_dim() - r.dense_dim() 4297*da0073e9SAndroid Build Coastguard Worker r_blocksize = r.values().shape[batch_dims + 1: batch_dims + 1 + len(blocksize)] 4298*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r_blocksize, blocksize) 4299*da0073e9SAndroid Build Coastguard Worker 4300*da0073e9SAndroid Build Coastguard Worker r_compressed_indices = r.crow_indices() if layout in {torch.sparse_csr, torch.sparse_bsr} else r.ccol_indices() 4301*da0073e9SAndroid Build Coastguard Worker r_plain_indices = r.col_indices() if layout in {torch.sparse_csr, torch.sparse_bsr} else r.row_indices() 4302*da0073e9SAndroid Build Coastguard Worker 4303*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r_compressed_indices, 4304*da0073e9SAndroid Build Coastguard Worker torch.empty((*batchsize, nof_compressed_indices), device='meta', dtype=index_dtype)) 4305*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r_plain_indices, torch.empty((*batchsize, nnz), device='meta', dtype=index_dtype)) 4306*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.values(), torch.empty((*batchsize, nnz, *blocksize, *densesize), device='meta', dtype=dtype)) 4307*da0073e9SAndroid Build Coastguard Worker 4308*da0073e9SAndroid Build Coastguard Worker r2 = torch.empty_like(r) 4309*da0073e9SAndroid Build Coastguard Worker self.assertTrue(r2.is_meta) 4310*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r2, r) 4311*da0073e9SAndroid Build Coastguard Worker 4312*da0073e9SAndroid Build Coastguard Worker if layout in {torch.sparse_csr, torch.sparse_csc}: 4313*da0073e9SAndroid Build Coastguard Worker r3 = torch.empty((*batchsize, *sparsesize), dtype=dtype, layout=layout, device="meta") 
4314*da0073e9SAndroid Build Coastguard Worker self.assertTrue(r3.is_meta) 4315*da0073e9SAndroid Build Coastguard Worker if not densesize: 4316*da0073e9SAndroid Build Coastguard Worker # dense dimensions cannot be specified for torch.empty 4317*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r3, r) 4318*da0073e9SAndroid Build Coastguard Worker 4319*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 4320*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", [torch.float64]) 4321*da0073e9SAndroid Build Coastguard Worker def test_meta(self, dtype, layout): 4322*da0073e9SAndroid Build Coastguard Worker if layout is torch.sparse_coo: 4323*da0073e9SAndroid Build Coastguard Worker self._test_meta_sparse_coo(dtype) 4324*da0073e9SAndroid Build Coastguard Worker else: 4325*da0073e9SAndroid Build Coastguard Worker for batchsize, densesize in itertools.product([(), (2,)], [(), (3,)]): 4326*da0073e9SAndroid Build Coastguard Worker self._test_meta_sparse_compressed(dtype, layout, batchsize, densesize) 4327*da0073e9SAndroid Build Coastguard Worker 4328*da0073e9SAndroid Build Coastguard Worker def _test_print_meta_data(self, dtype, layout, batchsize, sparsesize, densesize): 4329*da0073e9SAndroid Build Coastguard Worker index_dtype = torch.int64 4330*da0073e9SAndroid Build Coastguard Worker nnz = 0 4331*da0073e9SAndroid Build Coastguard Worker blocksize = (2, 3) if layout in {torch.sparse_bsr, torch.sparse_bsc} else () 4332*da0073e9SAndroid Build Coastguard Worker shape = (*batchsize, *sparsesize, *densesize) 4333*da0073e9SAndroid Build Coastguard Worker values = torch.empty((*batchsize, nnz, *blocksize, *densesize), device='meta', dtype=dtype) 4334*da0073e9SAndroid Build Coastguard Worker if layout is torch.sparse_coo: 4335*da0073e9SAndroid Build Coastguard Worker indices = torch.empty((len(sparsesize), nnz), device='meta', dtype=index_dtype) 4336*da0073e9SAndroid Build Coastguard Worker x = 
torch.sparse_coo_tensor(indices, values, shape) 4337*da0073e9SAndroid Build Coastguard Worker else: 4338*da0073e9SAndroid Build Coastguard Worker compressed_dim = 0 if layout in {torch.sparse_csr, torch.sparse_bsr} else 1 4339*da0073e9SAndroid Build Coastguard Worker nof_compressed_indices = (sparsesize[compressed_dim] // blocksize[compressed_dim] + 1 if blocksize 4340*da0073e9SAndroid Build Coastguard Worker else sparsesize[compressed_dim] + 1) 4341*da0073e9SAndroid Build Coastguard Worker compressed_indices = torch.empty((*batchsize, nof_compressed_indices), device='meta', dtype=index_dtype) 4342*da0073e9SAndroid Build Coastguard Worker plain_indices = torch.empty((*batchsize, nnz), device='meta', dtype=index_dtype) 4343*da0073e9SAndroid Build Coastguard Worker x = torch.sparse_compressed_tensor( 4344*da0073e9SAndroid Build Coastguard Worker compressed_indices, 4345*da0073e9SAndroid Build Coastguard Worker plain_indices, 4346*da0073e9SAndroid Build Coastguard Worker values, 4347*da0073e9SAndroid Build Coastguard Worker shape, 4348*da0073e9SAndroid Build Coastguard Worker layout=layout 4349*da0073e9SAndroid Build Coastguard Worker ) 4350*da0073e9SAndroid Build Coastguard Worker 4351*da0073e9SAndroid Build Coastguard Worker printed = [] 4352*da0073e9SAndroid Build Coastguard Worker printed.append(f"########## {dtype}/{index_dtype}/size={batchsize}+{sparsesize}+{blocksize}+{densesize} ##########") 4353*da0073e9SAndroid Build Coastguard Worker printed.append("# sparse meta tensor") 4354*da0073e9SAndroid Build Coastguard Worker printed.append(str(x)) 4355*da0073e9SAndroid Build Coastguard Worker 4356*da0073e9SAndroid Build Coastguard Worker return printed 4357*da0073e9SAndroid Build Coastguard Worker 4358*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 4359*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", [torch.float64]) 4360*da0073e9SAndroid Build Coastguard Worker def test_print_meta(self, dtype, layout): 
4361*da0073e9SAndroid Build Coastguard Worker printed = [] 4362*da0073e9SAndroid Build Coastguard Worker for batchsize, sparsesize, densesize in itertools.product( 4363*da0073e9SAndroid Build Coastguard Worker [(), (2,)], [(4, 6), (3, 5, 7)], [(), (3,)] 4364*da0073e9SAndroid Build Coastguard Worker ): 4365*da0073e9SAndroid Build Coastguard Worker if layout is torch.sparse_coo and batchsize: 4366*da0073e9SAndroid Build Coastguard Worker # COO tensors don't have batch dimensions 4367*da0073e9SAndroid Build Coastguard Worker continue 4368*da0073e9SAndroid Build Coastguard Worker if layout is not torch.sparse_coo and len(sparsesize) != 2: 4369*da0073e9SAndroid Build Coastguard Worker # CSR/CSC/BSR/BSC tensors must have 2 sparse dimensions 4370*da0073e9SAndroid Build Coastguard Worker continue 4371*da0073e9SAndroid Build Coastguard Worker printed += self._test_print_meta_data(dtype, layout, batchsize, sparsesize, densesize) 4372*da0073e9SAndroid Build Coastguard Worker 4373*da0073e9SAndroid Build Coastguard Worker orig_maxDiff = self.maxDiff 4374*da0073e9SAndroid Build Coastguard Worker self.maxDiff = None 4375*da0073e9SAndroid Build Coastguard Worker try: 4376*da0073e9SAndroid Build Coastguard Worker self.assertExpected('\n'.join(printed)) 4377*da0073e9SAndroid Build Coastguard Worker self.maxDiff = orig_maxDiff 4378*da0073e9SAndroid Build Coastguard Worker except Exception: 4379*da0073e9SAndroid Build Coastguard Worker self.maxDiff = orig_maxDiff 4380*da0073e9SAndroid Build Coastguard Worker raise 4381*da0073e9SAndroid Build Coastguard Worker 4382*da0073e9SAndroid Build Coastguard Worker def assertEqualMeta(self, x, y, expected_nnz): 4383*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.layout, y.layout) 4384*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.shape, y.shape) 4385*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.dtype, y.dtype) 4386*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.sparse_dim(), y.sparse_dim()) 
4387*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.dense_dim(), y.dense_dim()) 4388*da0073e9SAndroid Build Coastguard Worker 4389*da0073e9SAndroid Build Coastguard Worker def assertEqualAttrs(x, y, expected_shape): 4390*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.shape, expected_shape) 4391*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.dtype, y.dtype) 4392*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.layout, y.layout) 4393*da0073e9SAndroid Build Coastguard Worker if not x.is_meta: 4394*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.device, y.device) 4395*da0073e9SAndroid Build Coastguard Worker 4396*da0073e9SAndroid Build Coastguard Worker if x.layout is torch.sparse_coo: 4397*da0073e9SAndroid Build Coastguard Worker assertEqualAttrs(x._indices(), y._indices(), (*y._indices().shape[:-1], expected_nnz)) 4398*da0073e9SAndroid Build Coastguard Worker assertEqualAttrs(x._values(), y._values(), (expected_nnz, *y._values().shape[1:])) 4399*da0073e9SAndroid Build Coastguard Worker elif x.layout in {torch.sparse_csr, torch.sparse_bsr}: 4400*da0073e9SAndroid Build Coastguard Worker assertEqualAttrs(x.crow_indices(), y.crow_indices(), y.crow_indices().shape) 4401*da0073e9SAndroid Build Coastguard Worker assertEqualAttrs(x.col_indices(), y.col_indices(), (*y.col_indices().shape[:-1], expected_nnz)) 4402*da0073e9SAndroid Build Coastguard Worker batch_dim = x.col_indices().ndim - 1 4403*da0073e9SAndroid Build Coastguard Worker values_shape = (*y.values().shape[:batch_dim], expected_nnz, *y.values().shape[batch_dim + 1:]) 4404*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.values().layout, y.values().layout) 4405*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.values().dtype, y.values().dtype) 4406*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.values().shape, values_shape) 4407*da0073e9SAndroid Build Coastguard Worker elif x.layout in {torch.sparse_csc, 
torch.sparse_bsc}: 4408*da0073e9SAndroid Build Coastguard Worker assertEqualAttrs(x.ccol_indices(), y.ccol_indices(), y.ccol_indices().shape) 4409*da0073e9SAndroid Build Coastguard Worker assertEqualAttrs(x.row_indices(), y.row_indices(), (*y.row_indices().shape[:-1], expected_nnz)) 4410*da0073e9SAndroid Build Coastguard Worker batch_dim = x.row_indices().ndim - 1 4411*da0073e9SAndroid Build Coastguard Worker values_shape = (*y.values().shape[:batch_dim], expected_nnz, *y.values().shape[batch_dim + 1:]) 4412*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.values().layout, y.values().layout) 4413*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.values().dtype, y.values().dtype) 4414*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.values().shape, values_shape) 4415*da0073e9SAndroid Build Coastguard Worker 4416*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 4417*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", [torch.float64]) 4418*da0073e9SAndroid Build Coastguard Worker def test_to_meta(self, dtype, layout): 4419*da0073e9SAndroid Build Coastguard Worker index_dtype = torch.int64 4420*da0073e9SAndroid Build Coastguard Worker device = 'cpu' 4421*da0073e9SAndroid Build Coastguard Worker for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype): 4422*da0073e9SAndroid Build Coastguard Worker m = t.to(device="meta") 4423*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.device.type, "meta") 4424*da0073e9SAndroid Build Coastguard Worker self.assertEqualMeta(m, t, 0) 4425*da0073e9SAndroid Build Coastguard Worker 4426*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 4427*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", [torch.float64]) 4428*da0073e9SAndroid Build Coastguard Worker def test_zeros_like_meta(self, dtype, layout): 4429*da0073e9SAndroid Build Coastguard Worker 
index_dtype = torch.int64 4430*da0073e9SAndroid Build Coastguard Worker device = 'cpu' 4431*da0073e9SAndroid Build Coastguard Worker for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype): 4432*da0073e9SAndroid Build Coastguard Worker m = torch.zeros_like(t, device="meta") 4433*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.device.type, "meta") 4434*da0073e9SAndroid Build Coastguard Worker self.assertEqualMeta(m, t, 0) 4435*da0073e9SAndroid Build Coastguard Worker 4436*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 4437*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", [torch.float64]) 4438*da0073e9SAndroid Build Coastguard Worker def test_fake(self, dtype, layout): 4439*da0073e9SAndroid Build Coastguard Worker from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor 4440*da0073e9SAndroid Build Coastguard Worker fake_mode = FakeTensorMode() 4441*da0073e9SAndroid Build Coastguard Worker index_dtype = torch.int64 4442*da0073e9SAndroid Build Coastguard Worker device = 'cpu' 4443*da0073e9SAndroid Build Coastguard Worker for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype): 4444*da0073e9SAndroid Build Coastguard Worker f = FakeTensor.from_tensor(t, fake_mode) 4445*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(f, FakeTensor) 4446*da0073e9SAndroid Build Coastguard Worker self.assertEqualMeta(f, t, 0) 4447*da0073e9SAndroid Build Coastguard Worker 4448*da0073e9SAndroid Build Coastguard Worker d = f.detach() 4449*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(d, FakeTensor) 4450*da0073e9SAndroid Build Coastguard Worker self.assertEqualMeta(d, t, 0) 4451*da0073e9SAndroid Build Coastguard Worker 4452*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 4453*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", 
[torch.float64]) 4454*da0073e9SAndroid Build Coastguard Worker def test_zeros_like_fake(self, dtype, layout): 4455*da0073e9SAndroid Build Coastguard Worker from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor 4456*da0073e9SAndroid Build Coastguard Worker from torch.utils._mode_utils import no_dispatch 4457*da0073e9SAndroid Build Coastguard Worker fake_mode = FakeTensorMode() 4458*da0073e9SAndroid Build Coastguard Worker index_dtype = torch.int64 4459*da0073e9SAndroid Build Coastguard Worker device = 'cpu' 4460*da0073e9SAndroid Build Coastguard Worker for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype): 4461*da0073e9SAndroid Build Coastguard Worker f = FakeTensor.from_tensor(t, fake_mode) 4462*da0073e9SAndroid Build Coastguard Worker expected = torch.zeros_like(t) 4463*da0073e9SAndroid Build Coastguard Worker with no_dispatch(): 4464*da0073e9SAndroid Build Coastguard Worker result = torch.zeros_like(f, device=f.fake_device) 4465*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expected) 4466*da0073e9SAndroid Build Coastguard Worker self.assertEqualMeta(result, expected, 0) 4467*da0073e9SAndroid Build Coastguard Worker 4468*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 4469*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", [torch.float64]) 4470*da0073e9SAndroid Build Coastguard Worker def test_sum_meta(self, dtype, layout): 4471*da0073e9SAndroid Build Coastguard Worker device = 'cpu' 4472*da0073e9SAndroid Build Coastguard Worker index_dtype = torch.int64 4473*da0073e9SAndroid Build Coastguard Worker for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype): 4474*da0073e9SAndroid Build Coastguard Worker m = t.to(device='meta') 4475*da0073e9SAndroid Build Coastguard Worker r = torch.sum(m) 4476*da0073e9SAndroid Build Coastguard Worker expected = torch.sum(t).to(device="meta") 
4477*da0073e9SAndroid Build Coastguard Worker self.assertTrue(r.is_meta) 4478*da0073e9SAndroid Build Coastguard Worker self.assertEqualMeta(r, expected, 0) 4479*da0073e9SAndroid Build Coastguard Worker 4480*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 4481*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", [torch.float64]) 4482*da0073e9SAndroid Build Coastguard Worker def test_add_meta(self, dtype, layout): 4483*da0073e9SAndroid Build Coastguard Worker device = 'cpu' 4484*da0073e9SAndroid Build Coastguard Worker index_dtype = torch.int64 4485*da0073e9SAndroid Build Coastguard Worker for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype): 4486*da0073e9SAndroid Build Coastguard Worker expected = torch.add(t, t).to(device='meta') 4487*da0073e9SAndroid Build Coastguard Worker m = t.to(device='meta') 4488*da0073e9SAndroid Build Coastguard Worker r = torch.add(m, m) 4489*da0073e9SAndroid Build Coastguard Worker self.assertEqualMeta(r, expected, 0) 4490*da0073e9SAndroid Build Coastguard Worker 4491*da0073e9SAndroid Build Coastguard Worker 4492*da0073e9SAndroid Build Coastguard Workerclass _SparseDataset(torch.utils.data.Dataset): 4493*da0073e9SAndroid Build Coastguard Worker # An utility class used in TestSparseAny.test_dataloader method. 
4494*da0073e9SAndroid Build Coastguard Worker 4495*da0073e9SAndroid Build Coastguard Worker def __init__(self, sparse_tensors): 4496*da0073e9SAndroid Build Coastguard Worker self.sparse_tensors = sparse_tensors 4497*da0073e9SAndroid Build Coastguard Worker 4498*da0073e9SAndroid Build Coastguard Worker def __len__(self): 4499*da0073e9SAndroid Build Coastguard Worker return len(self.sparse_tensors) 4500*da0073e9SAndroid Build Coastguard Worker 4501*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, index): 4502*da0073e9SAndroid Build Coastguard Worker return self.sparse_tensors[index] 4503*da0073e9SAndroid Build Coastguard Worker 4504*da0073e9SAndroid Build Coastguard Worker 4505*da0073e9SAndroid Build Coastguard Workerclass TestSparseAny(TestCase): 4506*da0073e9SAndroid Build Coastguard Worker 4507*da0073e9SAndroid Build Coastguard Worker @onlyCPU 4508*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 4509*da0073e9SAndroid Build Coastguard Worker @torch.sparse.check_sparse_tensor_invariants(enable=False) 4510*da0073e9SAndroid Build Coastguard Worker def test_check_sparse_tensor_invariants(self, layout): 4511*da0073e9SAndroid Build Coastguard Worker 4512*da0073e9SAndroid Build Coastguard Worker if layout is torch.sparse_coo: 4513*da0073e9SAndroid Build Coastguard Worker 4514*da0073e9SAndroid Build Coastguard Worker def create_invalid_tensor(check_invariants=None): 4515*da0073e9SAndroid Build Coastguard Worker shape = (2, 2) 4516*da0073e9SAndroid Build Coastguard Worker invalid_indices = torch.tensor([[0], [3]]) # column index is out of range 4517*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([1]) 4518*da0073e9SAndroid Build Coastguard Worker if check_invariants is None: 4519*da0073e9SAndroid Build Coastguard Worker return torch.sparse_coo_tensor(invalid_indices, values, shape) 4520*da0073e9SAndroid Build Coastguard Worker else: 4521*da0073e9SAndroid Build Coastguard Worker return 
torch.sparse_coo_tensor(invalid_indices, values, shape, check_invariants=check_invariants) 4522*da0073e9SAndroid Build Coastguard Worker 4523*da0073e9SAndroid Build Coastguard Worker expected_exception_message = 'size is inconsistent with indices: for dim 1, size is 2 but found index 3' 4524*da0073e9SAndroid Build Coastguard Worker 4525*da0073e9SAndroid Build Coastguard Worker elif layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}: 4526*da0073e9SAndroid Build Coastguard Worker 4527*da0073e9SAndroid Build Coastguard Worker def create_invalid_tensor(check_invariants=None): 4528*da0073e9SAndroid Build Coastguard Worker shape = (2, 2) 4529*da0073e9SAndroid Build Coastguard Worker compressed_indices = torch.tensor([0, 0, 1]) 4530*da0073e9SAndroid Build Coastguard Worker invalid_plain_indices = torch.tensor([3]) # index is out of range 4531*da0073e9SAndroid Build Coastguard Worker if layout in {torch.sparse_bsr, torch.sparse_bsc}: 4532*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([[[1]]]) 4533*da0073e9SAndroid Build Coastguard Worker else: 4534*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([1]) 4535*da0073e9SAndroid Build Coastguard Worker if check_invariants is None: 4536*da0073e9SAndroid Build Coastguard Worker return torch.sparse_compressed_tensor(compressed_indices, invalid_plain_indices, values, shape, layout=layout) 4537*da0073e9SAndroid Build Coastguard Worker else: 4538*da0073e9SAndroid Build Coastguard Worker return torch.sparse_compressed_tensor(compressed_indices, invalid_plain_indices, values, shape, layout=layout, 4539*da0073e9SAndroid Build Coastguard Worker check_invariants=check_invariants) 4540*da0073e9SAndroid Build Coastguard Worker 4541*da0073e9SAndroid Build Coastguard Worker if layout in {torch.sparse_csr, torch.sparse_bsr}: 4542*da0073e9SAndroid Build Coastguard Worker expected_exception_message = r'`0 <= col_indices < ncols` is not satisfied.' 
4543*da0073e9SAndroid Build Coastguard Worker else: 4544*da0073e9SAndroid Build Coastguard Worker expected_exception_message = r'`0 <= row_indices < nrows` is not satisfied.' 4545*da0073e9SAndroid Build Coastguard Worker 4546*da0073e9SAndroid Build Coastguard Worker else: 4547*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError(layout) 4548*da0073e9SAndroid Build Coastguard Worker 4549*da0073e9SAndroid Build Coastguard Worker # First, consider the case where invariant checks are disabled 4550*da0073e9SAndroid Build Coastguard Worker # "globally" (read: within the context of this test method 4551*da0073e9SAndroid Build Coastguard Worker # caller) as defined by check_sparse_tensor_invariants(False) 4552*da0073e9SAndroid Build Coastguard Worker # decorator: 4553*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4554*da0073e9SAndroid Build Coastguard Worker 4555*da0073e9SAndroid Build Coastguard Worker # Enable the invariant checks in a local context: 4556*da0073e9SAndroid Build Coastguard Worker with torch.sparse.check_sparse_tensor_invariants(): 4557*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4558*da0073e9SAndroid Build Coastguard Worker 4559*da0073e9SAndroid Build Coastguard Worker # Leaving the local context must restore the "global" state of 4560*da0073e9SAndroid Build Coastguard Worker # the invariant check feature: 4561*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4562*da0073e9SAndroid Build Coastguard Worker 4563*da0073e9SAndroid Build Coastguard Worker # Since invariant checks are disabled by default, we can 4564*da0073e9SAndroid Build Coastguard Worker # create an invalid sparse tensor without raising an 4565*da0073e9SAndroid Build Coastguard Worker # exception: 4566*da0073e9SAndroid Build Coastguard Worker r = create_invalid_tensor() 
4567*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.layout, layout) 4568*da0073e9SAndroid Build Coastguard Worker 4569*da0073e9SAndroid Build Coastguard Worker # Or, when disabling the invariants check explicitly: 4570*da0073e9SAndroid Build Coastguard Worker r = create_invalid_tensor(check_invariants=False) 4571*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.layout, layout) 4572*da0073e9SAndroid Build Coastguard Worker 4573*da0073e9SAndroid Build Coastguard Worker # Enabling invariant check via constructor's optional argument 4574*da0073e9SAndroid Build Coastguard Worker # will raise an exception when sparse tensor invariants are 4575*da0073e9SAndroid Build Coastguard Worker # violated: 4576*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, expected_exception_message): 4577*da0073e9SAndroid Build Coastguard Worker create_invalid_tensor(check_invariants=True) 4578*da0073e9SAndroid Build Coastguard Worker 4579*da0073e9SAndroid Build Coastguard Worker # Check that the global invariant check flag has been restored 4580*da0073e9SAndroid Build Coastguard Worker # after raising the exception above: 4581*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4582*da0073e9SAndroid Build Coastguard Worker 4583*da0073e9SAndroid Build Coastguard Worker # Next, consider the case where invariant checks are enabled 4584*da0073e9SAndroid Build Coastguard Worker # within a local context: 4585*da0073e9SAndroid Build Coastguard Worker with torch.sparse.check_sparse_tensor_invariants(): 4586*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4587*da0073e9SAndroid Build Coastguard Worker 4588*da0073e9SAndroid Build Coastguard Worker # Since invariant checks are now enabled by default, an 4589*da0073e9SAndroid Build Coastguard Worker # attempt to create an invalid sparse tensor will lead to 
4590*da0073e9SAndroid Build Coastguard Worker # an exception: 4591*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, expected_exception_message): 4592*da0073e9SAndroid Build Coastguard Worker create_invalid_tensor() 4593*da0073e9SAndroid Build Coastguard Worker 4594*da0073e9SAndroid Build Coastguard Worker # Similarly, when enabling the invariant checks 4595*da0073e9SAndroid Build Coastguard Worker # explicitly, invalid sparse tensor construction will lead 4596*da0073e9SAndroid Build Coastguard Worker # to an exception: 4597*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, expected_exception_message): 4598*da0073e9SAndroid Build Coastguard Worker create_invalid_tensor(check_invariants=True) 4599*da0073e9SAndroid Build Coastguard Worker 4600*da0073e9SAndroid Build Coastguard Worker # However, invariants check can be disabled via 4601*da0073e9SAndroid Build Coastguard Worker # constructor's optional argument so that the invalid 4602*da0073e9SAndroid Build Coastguard Worker # tensor is succesfully constructed: 4603*da0073e9SAndroid Build Coastguard Worker r = create_invalid_tensor(check_invariants=False) 4604*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.layout, layout) 4605*da0073e9SAndroid Build Coastguard Worker 4606*da0073e9SAndroid Build Coastguard Worker # Check that the invariant check flag has been restored 4607*da0073e9SAndroid Build Coastguard Worker # when leaving the constructor: 4608*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4609*da0073e9SAndroid Build Coastguard Worker 4610*da0073e9SAndroid Build Coastguard Worker # Double-check restoring the global state when leaving the 4611*da0073e9SAndroid Build Coastguard Worker # local context: 4612*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4613*da0073e9SAndroid Build Coastguard Worker 
4614*da0073e9SAndroid Build Coastguard Worker # Test nesting of pre-defined context managers 4615*da0073e9SAndroid Build Coastguard Worker check_ctx = torch.sparse.check_sparse_tensor_invariants(True) 4616*da0073e9SAndroid Build Coastguard Worker no_check_ctx = torch.sparse.check_sparse_tensor_invariants(False) 4617*da0073e9SAndroid Build Coastguard Worker with check_ctx: 4618*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4619*da0073e9SAndroid Build Coastguard Worker with no_check_ctx: 4620*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4621*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4622*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4623*da0073e9SAndroid Build Coastguard Worker 4624*da0073e9SAndroid Build Coastguard Worker # Test an attempt to re-use an activate context manager instance 4625*da0073e9SAndroid Build Coastguard Worker check_ctx2 = torch.sparse.check_sparse_tensor_invariants(True) 4626*da0073e9SAndroid Build Coastguard Worker with check_ctx: 4627*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4628*da0073e9SAndroid Build Coastguard Worker with no_check_ctx: 4629*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4630*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "This context manager instance is already activated." 
4631*da0073e9SAndroid Build Coastguard Worker " Use a different context manager instance for context nesting"): 4632*da0073e9SAndroid Build Coastguard Worker with check_ctx: 4633*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4634*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4635*da0073e9SAndroid Build Coastguard Worker with check_ctx2: 4636*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4637*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4638*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4639*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled()) 4640*da0073e9SAndroid Build Coastguard Worker 4641*da0073e9SAndroid Build Coastguard Worker def test_generate_simple_inputs(self): 4642*da0073e9SAndroid Build Coastguard Worker layouts = [torch.strided, torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc] 4643*da0073e9SAndroid Build Coastguard Worker 4644*da0073e9SAndroid Build Coastguard Worker tested_combinations = set() 4645*da0073e9SAndroid Build Coastguard Worker for tensors in zip(*map(self.generate_simple_inputs, layouts)): 4646*da0073e9SAndroid Build Coastguard Worker for i, t in enumerate(tensors): 4647*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.layout, layouts[i]) 4648*da0073e9SAndroid Build Coastguard Worker 4649*da0073e9SAndroid Build Coastguard Worker # all layouts must produce semantically the same tensors 4650*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, tensors[0]) 4651*da0073e9SAndroid Build Coastguard Worker 4652*da0073e9SAndroid Build Coastguard Worker if t.layout is torch.strided: 
4653*da0073e9SAndroid Build Coastguard Worker is_hybrid = None 4654*da0073e9SAndroid Build Coastguard Worker else: 4655*da0073e9SAndroid Build Coastguard Worker is_hybrid = t.dense_dim() > 0 4656*da0073e9SAndroid Build Coastguard Worker if t.layout in {torch.sparse_csr, torch.sparse_bsr}: 4657*da0073e9SAndroid Build Coastguard Worker is_batch = t.crow_indices().ndim > 1 4658*da0073e9SAndroid Build Coastguard Worker elif t.layout in {torch.sparse_csc, torch.sparse_bsc}: 4659*da0073e9SAndroid Build Coastguard Worker is_batch = t.ccol_indices().ndim > 1 4660*da0073e9SAndroid Build Coastguard Worker else: 4661*da0073e9SAndroid Build Coastguard Worker is_batch = None 4662*da0073e9SAndroid Build Coastguard Worker if t.layout in {torch.sparse_bsr, torch.sparse_bsc}: 4663*da0073e9SAndroid Build Coastguard Worker blocksize = t.values().shape[1:3] 4664*da0073e9SAndroid Build Coastguard Worker nontrivial_blocksize = 1 not in blocksize 4665*da0073e9SAndroid Build Coastguard Worker else: 4666*da0073e9SAndroid Build Coastguard Worker nontrivial_blocksize = None 4667*da0073e9SAndroid Build Coastguard Worker if t.layout in {torch.sparse_csr, torch.sparse_bsr}: 4668*da0073e9SAndroid Build Coastguard Worker contiguous_indices = t.crow_indices().is_contiguous() and t.col_indices().is_contiguous() 4669*da0073e9SAndroid Build Coastguard Worker contiguous_values = t.values().is_contiguous() 4670*da0073e9SAndroid Build Coastguard Worker elif t.layout in {torch.sparse_csc, torch.sparse_bsc}: 4671*da0073e9SAndroid Build Coastguard Worker contiguous_indices = t.ccol_indices().is_contiguous() and t.row_indices().is_contiguous() 4672*da0073e9SAndroid Build Coastguard Worker contiguous_values = t.values().is_contiguous() 4673*da0073e9SAndroid Build Coastguard Worker elif t.layout is torch.sparse_coo: 4674*da0073e9SAndroid Build Coastguard Worker contiguous_indices = t._indices().is_contiguous() 4675*da0073e9SAndroid Build Coastguard Worker contiguous_values = t._values().is_contiguous() 
4676*da0073e9SAndroid Build Coastguard Worker else: 4677*da0073e9SAndroid Build Coastguard Worker contiguous_indices = None 4678*da0073e9SAndroid Build Coastguard Worker contiguous_values = t.is_contiguous() 4679*da0073e9SAndroid Build Coastguard Worker 4680*da0073e9SAndroid Build Coastguard Worker tested_combinations.add((t.layout, is_hybrid, is_batch, nontrivial_blocksize, 4681*da0073e9SAndroid Build Coastguard Worker contiguous_indices, contiguous_values)) 4682*da0073e9SAndroid Build Coastguard Worker 4683*da0073e9SAndroid Build Coastguard Worker # Ensure that the inputs generation covers all layout, 4684*da0073e9SAndroid Build Coastguard Worker # non-hybrid/hybrid, non-batch/batch, and contiguity 4685*da0073e9SAndroid Build Coastguard Worker # combinations: 4686*da0073e9SAndroid Build Coastguard Worker untested_combinations = set() 4687*da0073e9SAndroid Build Coastguard Worker for layout in layouts: 4688*da0073e9SAndroid Build Coastguard Worker for is_hybrid in [False, True]: 4689*da0073e9SAndroid Build Coastguard Worker if layout is torch.strided: 4690*da0073e9SAndroid Build Coastguard Worker is_hybrid = None 4691*da0073e9SAndroid Build Coastguard Worker for is_batch in [False, True]: 4692*da0073e9SAndroid Build Coastguard Worker if layout in {torch.sparse_coo, torch.strided}: 4693*da0073e9SAndroid Build Coastguard Worker is_batch = None 4694*da0073e9SAndroid Build Coastguard Worker for nontrivial_blocksize in [False, True]: 4695*da0073e9SAndroid Build Coastguard Worker if layout not in {torch.sparse_bsr, torch.sparse_bsc}: 4696*da0073e9SAndroid Build Coastguard Worker nontrivial_blocksize = None 4697*da0073e9SAndroid Build Coastguard Worker for contiguous_indices in [False, True]: 4698*da0073e9SAndroid Build Coastguard Worker if layout is torch.strided: 4699*da0073e9SAndroid Build Coastguard Worker contiguous_indices = None 4700*da0073e9SAndroid Build Coastguard Worker elif not is_batch: 4701*da0073e9SAndroid Build Coastguard Worker # indices are contiguous 
per-patch 4702*da0073e9SAndroid Build Coastguard Worker contiguous_indices = True 4703*da0073e9SAndroid Build Coastguard Worker for contiguous_values in [False, True]: 4704*da0073e9SAndroid Build Coastguard Worker key = (layout, is_hybrid, is_batch, nontrivial_blocksize, 4705*da0073e9SAndroid Build Coastguard Worker contiguous_indices, contiguous_values) 4706*da0073e9SAndroid Build Coastguard Worker if key not in tested_combinations: 4707*da0073e9SAndroid Build Coastguard Worker untested_combinations.add( 4708*da0073e9SAndroid Build Coastguard Worker f'layout={layout}, is_hybrid={is_hybrid}, is_batch={is_batch},' 4709*da0073e9SAndroid Build Coastguard Worker f' nontrivial_blocksize={nontrivial_blocksize},' 4710*da0073e9SAndroid Build Coastguard Worker f' contiguous_indices{contiguous_indices}, contiguous_values={contiguous_values}') 4711*da0073e9SAndroid Build Coastguard Worker assert not untested_combinations, untested_combinations 4712*da0073e9SAndroid Build Coastguard Worker 4713*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 4714*da0073e9SAndroid Build Coastguard Worker def test_constructor_autograd(self, device, layout): 4715*da0073e9SAndroid Build Coastguard Worker 4716*da0073e9SAndroid Build Coastguard Worker def specific_constructor(*args, **kwargs): 4717*da0073e9SAndroid Build Coastguard Worker if layout is torch.sparse_csr: 4718*da0073e9SAndroid Build Coastguard Worker return torch.sparse_csr_tensor(*args, **kwargs) 4719*da0073e9SAndroid Build Coastguard Worker elif layout is torch.sparse_csc: 4720*da0073e9SAndroid Build Coastguard Worker return torch.sparse_csc_tensor(*args, **kwargs) 4721*da0073e9SAndroid Build Coastguard Worker elif layout is torch.sparse_bsc: 4722*da0073e9SAndroid Build Coastguard Worker return torch.sparse_bsc_tensor(*args, **kwargs) 4723*da0073e9SAndroid Build Coastguard Worker elif layout is torch.sparse_bsr: 4724*da0073e9SAndroid Build Coastguard Worker return 
torch.sparse_bsr_tensor(*args, **kwargs) 4725*da0073e9SAndroid Build Coastguard Worker elif layout is torch.sparse_coo: 4726*da0073e9SAndroid Build Coastguard Worker return torch.sparse_coo_tensor(*args, **kwargs) 4727*da0073e9SAndroid Build Coastguard Worker else: 4728*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError(layout) 4729*da0073e9SAndroid Build Coastguard Worker 4730*da0073e9SAndroid Build Coastguard Worker def generic_constructor(*args, **kwargs): 4731*da0073e9SAndroid Build Coastguard Worker if layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}: 4732*da0073e9SAndroid Build Coastguard Worker kwargs.update(layout=layout) 4733*da0073e9SAndroid Build Coastguard Worker return torch.sparse_compressed_tensor(*args, **kwargs) 4734*da0073e9SAndroid Build Coastguard Worker elif layout is torch.sparse_coo: 4735*da0073e9SAndroid Build Coastguard Worker return torch.sparse_coo_tensor(*args, **kwargs) 4736*da0073e9SAndroid Build Coastguard Worker else: 4737*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError(layout) 4738*da0073e9SAndroid Build Coastguard Worker 4739*da0073e9SAndroid Build Coastguard Worker if layout is torch.sparse_coo: 4740*da0073e9SAndroid Build Coastguard Worker constructors = (specific_constructor,) 4741*da0073e9SAndroid Build Coastguard Worker else: 4742*da0073e9SAndroid Build Coastguard Worker constructors = (specific_constructor, generic_constructor) 4743*da0073e9SAndroid Build Coastguard Worker 4744*da0073e9SAndroid Build Coastguard Worker for args, kwargs in self.generate_simple_inputs( 4745*da0073e9SAndroid Build Coastguard Worker layout, device=device, dtype=torch.float64, 4746*da0073e9SAndroid Build Coastguard Worker enable_batch=False, # TODO: remove after gh-104868 is resolved 4747*da0073e9SAndroid Build Coastguard Worker output_tensor=False): 4748*da0073e9SAndroid Build Coastguard Worker values_offset = 1 if layout is torch.sparse_coo else 2 4749*da0073e9SAndroid Build 
Coastguard Worker 4750*da0073e9SAndroid Build Coastguard Worker for cnstr in constructors: 4751*da0073e9SAndroid Build Coastguard Worker for requires_grad in (False, True): 4752*da0073e9SAndroid Build Coastguard Worker values = args[values_offset].detach().requires_grad_(requires_grad) 4753*da0073e9SAndroid Build Coastguard Worker args = (*args[:values_offset], values, *args[values_offset + 1:]) 4754*da0073e9SAndroid Build Coastguard Worker kwargs_ = dict(kwargs) 4755*da0073e9SAndroid Build Coastguard Worker args_ = args + (kwargs_.pop('size'),) 4756*da0073e9SAndroid Build Coastguard Worker 4757*da0073e9SAndroid Build Coastguard Worker sparse = cnstr(*args, **kwargs) 4758*da0073e9SAndroid Build Coastguard Worker 4759*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sparse.requires_grad, requires_grad) 4760*da0073e9SAndroid Build Coastguard Worker 4761*da0073e9SAndroid Build Coastguard Worker if requires_grad: 4762*da0073e9SAndroid Build Coastguard Worker for masked in (False, True): 4763*da0073e9SAndroid Build Coastguard Worker if layout is torch.sparse_coo: 4764*da0073e9SAndroid Build Coastguard Worker torch.autograd.gradcheck( 4765*da0073e9SAndroid Build Coastguard Worker lambda i, v: cnstr(i, v, **kwargs).to_dense(masked_grad=masked), 4766*da0073e9SAndroid Build Coastguard Worker args, masked=masked) 4767*da0073e9SAndroid Build Coastguard Worker torch.autograd.gradcheck( 4768*da0073e9SAndroid Build Coastguard Worker lambda i, v, sz: cnstr(i, v, sz, **kwargs_).to_dense(masked_grad=masked), 4769*da0073e9SAndroid Build Coastguard Worker args_, masked=masked) 4770*da0073e9SAndroid Build Coastguard Worker else: 4771*da0073e9SAndroid Build Coastguard Worker if layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} and 0: 4772*da0073e9SAndroid Build Coastguard Worker # TODO: remove this if-block after gh-107370 is resolved 4773*da0073e9SAndroid Build Coastguard Worker continue 4774*da0073e9SAndroid Build Coastguard Worker torch.autograd.gradcheck( 
4775*da0073e9SAndroid Build Coastguard Worker lambda ci, pi, v: cnstr(ci, pi, v, **kwargs).to_dense(masked_grad=masked), 4776*da0073e9SAndroid Build Coastguard Worker args, masked=masked) 4777*da0073e9SAndroid Build Coastguard Worker torch.autograd.gradcheck( 4778*da0073e9SAndroid Build Coastguard Worker lambda ci, pi, v, sz: cnstr(ci, pi, v, sz, **kwargs_).to_dense(masked_grad=masked), 4779*da0073e9SAndroid Build Coastguard Worker args_, masked=masked) 4780*da0073e9SAndroid Build Coastguard Worker 4781*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('from_layout', include_strided=False) 4782*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 4783*da0073e9SAndroid Build Coastguard Worker @parametrize("index_dtype", [torch.int32, torch.int64]) 4784*da0073e9SAndroid Build Coastguard Worker def test_to_dense(self, from_layout, device, dtype, index_dtype): 4785*da0073e9SAndroid Build Coastguard Worker """ 4786*da0073e9SAndroid Build Coastguard Worker This test tests conversion from any layout to strided layout. 
4787*da0073e9SAndroid Build Coastguard Worker """ 4788*da0073e9SAndroid Build Coastguard Worker for t in self.generate_simple_inputs( 4789*da0073e9SAndroid Build Coastguard Worker from_layout, device=device, dtype=dtype, index_dtype=index_dtype): 4790*da0073e9SAndroid Build Coastguard Worker r = t.to_dense() 4791*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.layout, torch.strided) 4792*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, t) 4793*da0073e9SAndroid Build Coastguard Worker 4794*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('from_layout', include_strided=False) 4795*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float64, torch.complex128) 4796*da0073e9SAndroid Build Coastguard Worker @parametrize("index_dtype", [torch.int64]) 4797*da0073e9SAndroid Build Coastguard Worker @gradcheck_semantics() 4798*da0073e9SAndroid Build Coastguard Worker def test_gradcheck_to_dense(self, from_layout, device, dtype, index_dtype, gradcheck): 4799*da0073e9SAndroid Build Coastguard Worker for t in self.generate_simple_inputs( 4800*da0073e9SAndroid Build Coastguard Worker from_layout, device=device, dtype=dtype, index_dtype=index_dtype): 4801*da0073e9SAndroid Build Coastguard Worker batch_dim = t.dim() - t.dense_dim() - t.sparse_dim() 4802*da0073e9SAndroid Build Coastguard Worker if batch_dim > 0: 4803*da0073e9SAndroid Build Coastguard Worker # TODO: implement batch support in _convert_indices_from_csr_to_coo 4804*da0073e9SAndroid Build Coastguard Worker continue 4805*da0073e9SAndroid Build Coastguard Worker t = t.clone().detach().requires_grad_(True) 4806*da0073e9SAndroid Build Coastguard Worker r = gradcheck(lambda x: torch.Tensor.to_dense(x, masked_grad=gradcheck.masked), t) 4807*da0073e9SAndroid Build Coastguard Worker self.assertTrue(r) 4808*da0073e9SAndroid Build Coastguard Worker 4809*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('from_layout', include_strided=True) 4810*da0073e9SAndroid Build Coastguard 
Worker @all_sparse_layouts('to_layout', include_strided=False) 4811*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 4812*da0073e9SAndroid Build Coastguard Worker @parametrize("index_dtype", [torch.int32, torch.int64]) 4813*da0073e9SAndroid Build Coastguard Worker def test_to_sparse(self, from_layout, to_layout, device, dtype, index_dtype): 4814*da0073e9SAndroid Build Coastguard Worker """ 4815*da0073e9SAndroid Build Coastguard Worker This test tests conversion from any layout to any sparse layout. 4816*da0073e9SAndroid Build Coastguard Worker """ 4817*da0073e9SAndroid Build Coastguard Worker for t in self.generate_simple_inputs( 4818*da0073e9SAndroid Build Coastguard Worker from_layout, device=device, dtype=dtype, index_dtype=index_dtype, 4819*da0073e9SAndroid Build Coastguard Worker enable_hybrid=( 4820*da0073e9SAndroid Build Coastguard Worker # TODO: to support conversion strided->hybrid 4821*da0073e9SAndroid Build Coastguard Worker # CSR/CSC/BSR/BSC, to_sparse() requires extra keyword 4822*da0073e9SAndroid Build Coastguard Worker # argument, either nof_batch_dims or 4823*da0073e9SAndroid Build Coastguard Worker # nof_dense_dims 4824*da0073e9SAndroid Build Coastguard Worker not (from_layout is torch.strided and to_layout in 4825*da0073e9SAndroid Build Coastguard Worker {torch.sparse_bsr, torch.sparse_bsc, torch.sparse_csr, torch.sparse_csc}))): 4826*da0073e9SAndroid Build Coastguard Worker 4827*da0073e9SAndroid Build Coastguard Worker if to_layout in {torch.sparse_bsr, torch.sparse_bsc}: 4828*da0073e9SAndroid Build Coastguard Worker if from_layout == torch.sparse_bsr: 4829*da0073e9SAndroid Build Coastguard Worker batch_ndim = t.crow_indices().dim() - 1 4830*da0073e9SAndroid Build Coastguard Worker blocksize = t.values().shape[batch_ndim + 1:batch_ndim + 3] 4831*da0073e9SAndroid Build Coastguard Worker elif from_layout == torch.sparse_bsc: 4832*da0073e9SAndroid Build Coastguard Worker batch_ndim 
= t.ccol_indices().dim() - 1 4833*da0073e9SAndroid Build Coastguard Worker blocksize = t.values().shape[batch_ndim + 1:batch_ndim + 3] 4834*da0073e9SAndroid Build Coastguard Worker else: 4835*da0073e9SAndroid Build Coastguard Worker blocksize = (1, 1) 4836*da0073e9SAndroid Build Coastguard Worker else: 4837*da0073e9SAndroid Build Coastguard Worker blocksize = None 4838*da0073e9SAndroid Build Coastguard Worker 4839*da0073e9SAndroid Build Coastguard Worker if from_layout is torch.strided: 4840*da0073e9SAndroid Build Coastguard Worker is_batch = None 4841*da0073e9SAndroid Build Coastguard Worker is_hybrid = None 4842*da0073e9SAndroid Build Coastguard Worker else: 4843*da0073e9SAndroid Build Coastguard Worker is_batch = t.dim() > (t.sparse_dim() + t.dense_dim()) 4844*da0073e9SAndroid Build Coastguard Worker is_hybrid = t.dense_dim() > 0 4845*da0073e9SAndroid Build Coastguard Worker 4846*da0073e9SAndroid Build Coastguard Worker def explicit_to_sparse(x): 4847*da0073e9SAndroid Build Coastguard Worker # Used to check that the explicit conversion methods 4848*da0073e9SAndroid Build Coastguard Worker # are consistent with the `to_sparse(*, layout, 4849*da0073e9SAndroid Build Coastguard Worker # blocksize)` method. 
4850*da0073e9SAndroid Build Coastguard Worker if to_layout is torch.sparse_coo: 4851*da0073e9SAndroid Build Coastguard Worker return x.to_sparse_coo() 4852*da0073e9SAndroid Build Coastguard Worker elif to_layout is torch.sparse_csr: 4853*da0073e9SAndroid Build Coastguard Worker return x.to_sparse_csr() 4854*da0073e9SAndroid Build Coastguard Worker elif to_layout is torch.sparse_csc: 4855*da0073e9SAndroid Build Coastguard Worker return x.to_sparse_csc() 4856*da0073e9SAndroid Build Coastguard Worker elif to_layout is torch.sparse_bsr: 4857*da0073e9SAndroid Build Coastguard Worker return x.to_sparse_bsr(blocksize) 4858*da0073e9SAndroid Build Coastguard Worker elif to_layout is torch.sparse_bsc: 4859*da0073e9SAndroid Build Coastguard Worker return x.to_sparse_bsc(blocksize) 4860*da0073e9SAndroid Build Coastguard Worker else: 4861*da0073e9SAndroid Build Coastguard Worker assert 0 # unreachable 4862*da0073e9SAndroid Build Coastguard Worker 4863*da0073e9SAndroid Build Coastguard Worker # TODO: The following exception cases all correspond to 4864*da0073e9SAndroid Build Coastguard Worker # not implemented conversions 4865*da0073e9SAndroid Build Coastguard Worker if from_layout in { 4866*da0073e9SAndroid Build Coastguard Worker torch.sparse_csr, torch.sparse_csc} and to_layout in {torch.sparse_bsr, torch.sparse_bsc} and is_batch: 4867*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 4868*da0073e9SAndroid Build Coastguard Worker RuntimeError, 4869*da0073e9SAndroid Build Coastguard Worker r"conversion from Sparse(Csr|Csc) to Sparse(Bsr|Bsc) for batched inputs is not supported"): 4870*da0073e9SAndroid Build Coastguard Worker t.to_sparse(layout=to_layout, blocksize=blocksize) 4871*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 4872*da0073e9SAndroid Build Coastguard Worker RuntimeError, 4873*da0073e9SAndroid Build Coastguard Worker r"conversion from Sparse(Csr|Csc) to Sparse(Bsr|Bsc) for batched inputs is not supported"): 
4874*da0073e9SAndroid Build Coastguard Worker explicit_to_sparse(t) 4875*da0073e9SAndroid Build Coastguard Worker continue 4876*da0073e9SAndroid Build Coastguard Worker elif from_layout is torch.sparse_coo and to_layout in { 4877*da0073e9SAndroid Build Coastguard Worker torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} and t.sparse_dim() != 2: 4878*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 4879*da0073e9SAndroid Build Coastguard Worker RuntimeError, 4880*da0073e9SAndroid Build Coastguard Worker r"conversion from Sparse to .* for input tensors with sparse_dim\(\)!=2 is not supported"): 4881*da0073e9SAndroid Build Coastguard Worker t.to_sparse(layout=to_layout, blocksize=blocksize) 4882*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 4883*da0073e9SAndroid Build Coastguard Worker RuntimeError, 4884*da0073e9SAndroid Build Coastguard Worker r"conversion from Sparse to .* for input tensors with sparse_dim\(\)!=2 is not supported"): 4885*da0073e9SAndroid Build Coastguard Worker explicit_to_sparse(t) 4886*da0073e9SAndroid Build Coastguard Worker continue 4887*da0073e9SAndroid Build Coastguard Worker elif (from_layout, to_layout) in {(torch.sparse_bsc, torch.sparse_csr), (torch.sparse_bsc, torch.sparse_csc), 4888*da0073e9SAndroid Build Coastguard Worker (torch.sparse_bsr, torch.sparse_csr), (torch.sparse_bsr, torch.sparse_csc)}: 4889*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 4890*da0073e9SAndroid Build Coastguard Worker RuntimeError, 4891*da0073e9SAndroid Build Coastguard Worker r"sparse_compressed_to_sparse_(csr|csc|bsr|bsc): expected\s*(Sparse(Csc|Csr)[,]|)\s*Sparse(Csr|Bsr)" 4892*da0073e9SAndroid Build Coastguard Worker " or Sparse(Csc|Bsc) layout but got Sparse(Csr|Csc|Bsr|Bsc)"): 4893*da0073e9SAndroid Build Coastguard Worker t.to_sparse(layout=to_layout, blocksize=blocksize) 4894*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 4895*da0073e9SAndroid 
Build Coastguard Worker RuntimeError, 4896*da0073e9SAndroid Build Coastguard Worker r"sparse_compressed_to_sparse_(csr|csc|bsr|bsc): expected\s*(Sparse(Csc|Csr)[,]|)\s*Sparse(Csr|Bsr)" 4897*da0073e9SAndroid Build Coastguard Worker " or Sparse(Csc|Bsc) layout but got Sparse(Csr|Csc|Bsr|Bsc)"): 4898*da0073e9SAndroid Build Coastguard Worker explicit_to_sparse(t) 4899*da0073e9SAndroid Build Coastguard Worker self.skipTest('NOT IMPL') 4900*da0073e9SAndroid Build Coastguard Worker else: 4901*da0073e9SAndroid Build Coastguard Worker r = t.to_sparse(layout=to_layout, blocksize=blocksize) 4902*da0073e9SAndroid Build Coastguard Worker 4903*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.layout, to_layout) 4904*da0073e9SAndroid Build Coastguard Worker 4905*da0073e9SAndroid Build Coastguard Worker # to_sparse method uses unsafe construction of sparse 4906*da0073e9SAndroid Build Coastguard Worker # tensors. Here we explicitly validate the results to 4907*da0073e9SAndroid Build Coastguard Worker # make sure that the sparse tensors are consistent 4908*da0073e9SAndroid Build Coastguard Worker # with the corresponding sparse tensor invariants. 
4909*da0073e9SAndroid Build Coastguard Worker if r.layout in {torch.sparse_csr, torch.sparse_bsr, torch.sparse_csc, torch.sparse_bsc}: 4910*da0073e9SAndroid Build Coastguard Worker if r.layout in {torch.sparse_csr, torch.sparse_bsr}: 4911*da0073e9SAndroid Build Coastguard Worker compressed_indices, plain_indices = r.crow_indices(), r.col_indices() 4912*da0073e9SAndroid Build Coastguard Worker else: 4913*da0073e9SAndroid Build Coastguard Worker compressed_indices, plain_indices = r.ccol_indices(), r.row_indices() 4914*da0073e9SAndroid Build Coastguard Worker torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, r.values(), 4915*da0073e9SAndroid Build Coastguard Worker r.shape, r.layout) 4916*da0073e9SAndroid Build Coastguard Worker if from_layout in {torch.strided, torch.sparse_coo}: 4917*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compressed_indices.dtype, torch.int64) 4918*da0073e9SAndroid Build Coastguard Worker self.assertEqual(plain_indices.dtype, torch.int64) 4919*da0073e9SAndroid Build Coastguard Worker else: 4920*da0073e9SAndroid Build Coastguard Worker self.assertEqual(compressed_indices.dtype, index_dtype) 4921*da0073e9SAndroid Build Coastguard Worker self.assertEqual(plain_indices.dtype, index_dtype) 4922*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.values().dtype, dtype) 4923*da0073e9SAndroid Build Coastguard Worker elif r.layout is torch.sparse_coo: 4924*da0073e9SAndroid Build Coastguard Worker if t.layout is torch.sparse_coo: 4925*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.is_coalesced(), r.is_coalesced()) 4926*da0073e9SAndroid Build Coastguard Worker 4927*da0073e9SAndroid Build Coastguard Worker # Check r is truly coalesced when r.is_coalesced == True 4928*da0073e9SAndroid Build Coastguard Worker if r.is_coalesced(): 4929*da0073e9SAndroid Build Coastguard Worker self.assertTrue(is_coalesced_indices(r)) 4930*da0073e9SAndroid Build Coastguard Worker 4931*da0073e9SAndroid Build 
Coastguard Worker torch._validate_sparse_coo_tensor_args(r._indices(), r._values(), r.shape) 4932*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r._indices().dtype, torch.int64) 4933*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r._values().dtype, dtype) 4934*da0073e9SAndroid Build Coastguard Worker else: 4935*da0073e9SAndroid Build Coastguard Worker assert 0 # unreachable 4936*da0073e9SAndroid Build Coastguard Worker 4937*da0073e9SAndroid Build Coastguard Worker # Finally, we'll test tensor equality: 4938*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, t) 4939*da0073e9SAndroid Build Coastguard Worker 4940*da0073e9SAndroid Build Coastguard Worker # Also, check consistency with explicit conversion methods: 4941*da0073e9SAndroid Build Coastguard Worker r2 = explicit_to_sparse(t) 4942*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r2, r) 4943*da0073e9SAndroid Build Coastguard Worker 4944*da0073e9SAndroid Build Coastguard Worker # Check inverse conversion from sparse compressed block tensors 4945*da0073e9SAndroid Build Coastguard Worker if from_layout == torch.sparse_bsr: 4946*da0073e9SAndroid Build Coastguard Worker batch_ndim = t.crow_indices().dim() - 1 4947*da0073e9SAndroid Build Coastguard Worker from_blocksize = t.values().shape[batch_ndim + 1:batch_ndim + 3] 4948*da0073e9SAndroid Build Coastguard Worker elif from_layout == torch.sparse_bsc: 4949*da0073e9SAndroid Build Coastguard Worker batch_ndim = t.ccol_indices().dim() - 1 4950*da0073e9SAndroid Build Coastguard Worker from_blocksize = t.values().shape[batch_ndim + 1:batch_ndim + 3] 4951*da0073e9SAndroid Build Coastguard Worker else: 4952*da0073e9SAndroid Build Coastguard Worker continue 4953*da0073e9SAndroid Build Coastguard Worker if r.ndim != 2: 4954*da0073e9SAndroid Build Coastguard Worker continue 4955*da0073e9SAndroid Build Coastguard Worker 4956*da0073e9SAndroid Build Coastguard Worker t2 = r.to_sparse(layout=from_layout, blocksize=from_blocksize) 
4957*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t2, t) 4958*da0073e9SAndroid Build Coastguard Worker 4959*da0073e9SAndroid Build Coastguard Worker # extra tests 4960*da0073e9SAndroid Build Coastguard Worker if (from_layout, to_layout) == (torch.sparse_csr, torch.sparse_bsr): 4961*da0073e9SAndroid Build Coastguard Worker # See gh-90910 4962*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0]], dtype=dtype, device=device).to_sparse_csr() 4963*da0073e9SAndroid Build Coastguard Worker r = t.to_sparse_bsr((2, 2)) 4964*da0073e9SAndroid Build Coastguard Worker torch._validate_sparse_compressed_tensor_args(r.crow_indices(), r.col_indices(), r.values(), r.shape, r.layout) 4965*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, t) 4966*da0073e9SAndroid Build Coastguard Worker 4967*da0073e9SAndroid Build Coastguard Worker if (from_layout, to_layout) in {(torch.sparse_csr, torch.sparse_csc), 4968*da0073e9SAndroid Build Coastguard Worker (torch.sparse_csc, torch.sparse_csr)}: 4969*da0073e9SAndroid Build Coastguard Worker # See gh-91007 4970*da0073e9SAndroid Build Coastguard Worker compressed_indices = torch.tensor([0, 4, 8, 8, 12, 16, 20], dtype=index_dtype, device=device) 4971*da0073e9SAndroid Build Coastguard Worker plain_indices = torch.tensor([0, 1, 2, 3] * 5, dtype=index_dtype, device=device) 4972*da0073e9SAndroid Build Coastguard Worker t = torch.sparse_compressed_tensor(compressed_indices, plain_indices, range(20), 4973*da0073e9SAndroid Build Coastguard Worker dtype=dtype, device=device, layout=from_layout) 4974*da0073e9SAndroid Build Coastguard Worker r = t.to_sparse(layout=to_layout) 4975*da0073e9SAndroid Build Coastguard Worker if r.layout in {torch.sparse_csr, torch.sparse_bsr}: 4976*da0073e9SAndroid Build Coastguard Worker compressed_indices, plain_indices = r.crow_indices(), r.col_indices() 4977*da0073e9SAndroid Build Coastguard Worker else: 4978*da0073e9SAndroid Build Coastguard Worker compressed_indices, 
plain_indices = r.ccol_indices(), r.row_indices() 4979*da0073e9SAndroid Build Coastguard Worker torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, r.values(), r.shape, r.layout) 4980*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, t) 4981*da0073e9SAndroid Build Coastguard Worker 4982*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 4983*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 4984*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops_with_sparse_support) 4985*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.bfloat16: 5e-4, torch.float16: 5e-3}) 4986*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 4987*da0073e9SAndroid Build Coastguard Worker def test_reductions(self, layout, device, dtype, op): 4988*da0073e9SAndroid Build Coastguard Worker count = 0 4989*da0073e9SAndroid Build Coastguard Worker for sample in op.sample_inputs_sparse(layout, device, dtype): 4990*da0073e9SAndroid Build Coastguard Worker count += 1 4991*da0073e9SAndroid Build Coastguard Worker 4992*da0073e9SAndroid Build Coastguard Worker t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs 4993*da0073e9SAndroid Build Coastguard Worker result = op.op(t_inp, *t_args, **t_kwargs) 4994*da0073e9SAndroid Build Coastguard Worker 4995*da0073e9SAndroid Build Coastguard Worker # Checking invariant rop(inp, ...).to_dense() == rop(inp.to_dense(), ...) 
4996*da0073e9SAndroid Build Coastguard Worker dense = op.op(t_inp.to_dense(), *t_args, **t_kwargs) 4997*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, dense) 4998*da0073e9SAndroid Build Coastguard Worker 4999*da0073e9SAndroid Build Coastguard Worker if count == 0: 5000*da0073e9SAndroid Build Coastguard Worker # we count samples to avoid false-positive test reports 5001*da0073e9SAndroid Build Coastguard Worker self.skipTest('no sample inputs') 5002*da0073e9SAndroid Build Coastguard Worker 5003*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 5004*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 5005*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops_with_sparse_support, allowed_dtypes=(torch.float32, torch.float64, torch.complex64, torch.complex128)) 5006*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 5007*da0073e9SAndroid Build Coastguard Worker def test_reductions_backward(self, layout, device, dtype, op): 5008*da0073e9SAndroid Build Coastguard Worker count = 0 5009*da0073e9SAndroid Build Coastguard Worker for sample in op.sample_inputs_sparse(layout, device, dtype, requires_grad=True): 5010*da0073e9SAndroid Build Coastguard Worker t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs 5011*da0073e9SAndroid Build Coastguard Worker r = op.op(t_inp, *t_args, **t_kwargs) 5012*da0073e9SAndroid Build Coastguard Worker if r.numel() != 0: 5013*da0073e9SAndroid Build Coastguard Worker r = r.sum() 5014*da0073e9SAndroid Build Coastguard Worker 5015*da0073e9SAndroid Build Coastguard Worker if op.name == 'sum': 5016*da0073e9SAndroid Build Coastguard Worker count += 1 5017*da0073e9SAndroid Build Coastguard Worker r.abs().backward() 5018*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t_inp.grad, torch.ones(t_inp.shape, dtype=dtype, device=device) * torch.sgn(r)) 5019*da0073e9SAndroid Build Coastguard Worker else: 5020*da0073e9SAndroid Build Coastguard 
Worker self.skipTest('NOT IMPL') 5021*da0073e9SAndroid Build Coastguard Worker 5022*da0073e9SAndroid Build Coastguard Worker if count == 0: 5023*da0073e9SAndroid Build Coastguard Worker # we count samples to avoid false-positive test reports 5024*da0073e9SAndroid Build Coastguard Worker self.skipTest('no sample inputs') 5025*da0073e9SAndroid Build Coastguard Worker 5026*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 5027*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 5028*da0073e9SAndroid Build Coastguard Worker @parametrize("mth", [subtest(mth, name=mth.__name__) 5029*da0073e9SAndroid Build Coastguard Worker for mth in [torch.Tensor.is_coalesced, 5030*da0073e9SAndroid Build Coastguard Worker torch.Tensor.coalesce, 5031*da0073e9SAndroid Build Coastguard Worker torch.Tensor.indices, 5032*da0073e9SAndroid Build Coastguard Worker torch.Tensor.values, 5033*da0073e9SAndroid Build Coastguard Worker torch.Tensor.crow_indices, 5034*da0073e9SAndroid Build Coastguard Worker torch.Tensor.col_indices, 5035*da0073e9SAndroid Build Coastguard Worker torch.Tensor.ccol_indices, 5036*da0073e9SAndroid Build Coastguard Worker torch.Tensor.row_indices, 5037*da0073e9SAndroid Build Coastguard Worker ]]) 5038*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=True) 5039*da0073e9SAndroid Build Coastguard Worker def test_unsupported_backend_error_message(self, mth, layout, device): 5040*da0073e9SAndroid Build Coastguard Worker inp = torch.tensor([[1, 2], [3, 4]], device=device).to_sparse( 5041*da0073e9SAndroid Build Coastguard Worker layout=layout, 5042*da0073e9SAndroid Build Coastguard Worker blocksize=(1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None) 5043*da0073e9SAndroid Build Coastguard Worker assert inp.layout is layout 5044*da0073e9SAndroid Build Coastguard Worker 5045*da0073e9SAndroid Build Coastguard Worker expected_behaviour = dict( 5046*da0073e9SAndroid Build Coastguard Worker # <mth name> = 
(<supported layouts>, <exception message on other layouts>) 5047*da0073e9SAndroid Build Coastguard Worker is_coalesced=({torch.sparse_coo}, 5048*da0073e9SAndroid Build Coastguard Worker "is_coalesced expected sparse coordinate tensor layout but got (Sparse(Csr|Csc|Bsr|Bsc)|Strided)"), 5049*da0073e9SAndroid Build Coastguard Worker coalesce=({torch.sparse_coo}, 5050*da0073e9SAndroid Build Coastguard Worker "coalesce expected sparse coordinate tensor layout but got (Sparse(Csr|Csc|Bsr|Bsc)|Strided)"), 5051*da0073e9SAndroid Build Coastguard Worker indices=({torch.sparse_coo}, 5052*da0073e9SAndroid Build Coastguard Worker "indices expected sparse coordinate tensor layout but got (Sparse(Csr|Csc|Bsr|Bsc)|Strided)"), 5053*da0073e9SAndroid Build Coastguard Worker values=({torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}, 5054*da0073e9SAndroid Build Coastguard Worker "values expected sparse tensor layout but got Strided"), 5055*da0073e9SAndroid Build Coastguard Worker crow_indices=({torch.sparse_csr, torch.sparse_bsr}, 5056*da0073e9SAndroid Build Coastguard Worker "crow_indices expected sparse row compressed tensor layout but got (Sparse(Csc|Bsc|)|Strided)"), 5057*da0073e9SAndroid Build Coastguard Worker col_indices=({torch.sparse_csr, torch.sparse_bsr}, 5058*da0073e9SAndroid Build Coastguard Worker "col_indices expected sparse row compressed tensor layout but got (Sparse(Csc|Bsc|)|Strided)"), 5059*da0073e9SAndroid Build Coastguard Worker ccol_indices=({torch.sparse_csc, torch.sparse_bsc}, 5060*da0073e9SAndroid Build Coastguard Worker "ccol_indices expected sparse column compressed tensor layout but got (Sparse(Csr|Bsr|)|Strided)"), 5061*da0073e9SAndroid Build Coastguard Worker row_indices=({torch.sparse_csc, torch.sparse_bsc}, 5062*da0073e9SAndroid Build Coastguard Worker "row_indices expected sparse column compressed tensor layout but got (Sparse(Csr|Bsr|)|Strided)"), 5063*da0073e9SAndroid Build Coastguard Worker )[mth.__name__] 
5064*da0073e9SAndroid Build Coastguard Worker 5065*da0073e9SAndroid Build Coastguard Worker if layout in expected_behaviour[0]: 5066*da0073e9SAndroid Build Coastguard Worker mth(inp) 5067*da0073e9SAndroid Build Coastguard Worker else: 5068*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, expected_behaviour[1]): 5069*da0073e9SAndroid Build Coastguard Worker mth(inp) 5070*da0073e9SAndroid Build Coastguard Worker 5071*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 5072*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=not True) 5073*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float64, torch.cdouble) 5074*da0073e9SAndroid Build Coastguard Worker @parametrize("masked", [subtest(False, name='sparse'), subtest(True, name='masked')]) 5075*da0073e9SAndroid Build Coastguard Worker @parametrize("fast_mode", [subtest(False, name='slow'), subtest(True, name='fast')]) 5076*da0073e9SAndroid Build Coastguard Worker def test_gradcheck_mm(self, layout, dtype, device, masked, fast_mode): 5077*da0073e9SAndroid Build Coastguard Worker # This function does not check the following cases: 5078*da0073e9SAndroid Build Coastguard Worker # - batch or hybrid tensors because addmm does not support 5079*da0073e9SAndroid Build Coastguard Worker # such inputs yet 5080*da0073e9SAndroid Build Coastguard Worker # - check_forward_ad=True because of the lack of sparse tensor 5081*da0073e9SAndroid Build Coastguard Worker # support in aten::view_as_real, torch._VF._make_dual, etc. 
5082*da0073e9SAndroid Build Coastguard Worker 5083*da0073e9SAndroid Build Coastguard Worker ref_x = torch.tensor([[1, 2, 0, 0], 5084*da0073e9SAndroid Build Coastguard Worker [0, 6, 0, 0], 5085*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0], 5086*da0073e9SAndroid Build Coastguard Worker [13, 14, 0, 15]], dtype=dtype, device=device) 5087*da0073e9SAndroid Build Coastguard Worker ref_y = torch.tensor([[11, 12, 13, 14], 5088*da0073e9SAndroid Build Coastguard Worker [21, 22, 23, 24], 5089*da0073e9SAndroid Build Coastguard Worker [31, 32, 33, 34], 5090*da0073e9SAndroid Build Coastguard Worker [41, 42, 43, 44]], 5091*da0073e9SAndroid Build Coastguard Worker dtype=dtype, device=device) 5092*da0073e9SAndroid Build Coastguard Worker 5093*da0073e9SAndroid Build Coastguard Worker mm = torch.sparse.mm if masked else torch.mm 5094*da0073e9SAndroid Build Coastguard Worker 5095*da0073e9SAndroid Build Coastguard Worker blocksize = (2, 2) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None 5096*da0073e9SAndroid Build Coastguard Worker x = ref_x.to_sparse(layout=layout, blocksize=blocksize).requires_grad_(True) 5097*da0073e9SAndroid Build Coastguard Worker y = ref_y.requires_grad_(True) 5098*da0073e9SAndroid Build Coastguard Worker 5099*da0073e9SAndroid Build Coastguard Worker if layout is torch.sparse_bsr and not masked or layout is torch.sparse_bsc: 5100*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 5101*da0073e9SAndroid Build Coastguard Worker RuntimeError, 5102*da0073e9SAndroid Build Coastguard Worker r"addmm: computation on (CPU|CUDA) is not implemented for Strided \+ Sparse(Bsr|Bsc) @ Strided"): 5103*da0073e9SAndroid Build Coastguard Worker torch.autograd.gradcheck(mm, (x, y), fast_mode=fast_mode, masked=masked) 5104*da0073e9SAndroid Build Coastguard Worker self.skipTest('NOT IMPL') 5105*da0073e9SAndroid Build Coastguard Worker elif layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} and masked: 5106*da0073e9SAndroid Build 
Coastguard Worker with self.assertRaisesRegex( 5107*da0073e9SAndroid Build Coastguard Worker RuntimeError, 5108*da0073e9SAndroid Build Coastguard Worker r"(sparse_addmm_sparse_backward: unsupported combination of layouts," 5109*da0073e9SAndroid Build Coastguard Worker r" grad: Strided, mat1: Sparse(Csc|Bsr|Bsc), mat2: Strided" 5110*da0073e9SAndroid Build Coastguard Worker r"|addmm: computation on (CPU|CUDA) is not implemented for " 5111*da0073e9SAndroid Build Coastguard Worker r"Strided \+ Sparse(Csc|Bsr|Bsc) @ Strided without MKL)"): 5112*da0073e9SAndroid Build Coastguard Worker torch.autograd.gradcheck(mm, (x, y), fast_mode=fast_mode, masked=masked) 5113*da0073e9SAndroid Build Coastguard Worker self.skipTest('NOT IMPL') 5114*da0073e9SAndroid Build Coastguard Worker else: 5115*da0073e9SAndroid Build Coastguard Worker torch.autograd.gradcheck(mm, (x, y), fast_mode=fast_mode, masked=masked) 5116*da0073e9SAndroid Build Coastguard Worker 5117*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 5118*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 5119*da0073e9SAndroid Build Coastguard Worker @ops(binary_ufuncs_with_sparse_support) 5120*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 5121*da0073e9SAndroid Build Coastguard Worker def test_binary_operation(self, layout, device, dtype, op): 5122*da0073e9SAndroid Build Coastguard Worker if not op.supports_sparse_layout(layout): 5123*da0073e9SAndroid Build Coastguard Worker self.skipTest(f'{layout} is not supported in `{op.name}` OpInfo definition. 
Skipping!') 5124*da0073e9SAndroid Build Coastguard Worker 5125*da0073e9SAndroid Build Coastguard Worker for sample in op.sample_inputs_sparse(layout, device, dtype): 5126*da0073e9SAndroid Build Coastguard Worker if validate_sample_input_sparse(op, sample, check_validate=False) is not sample: 5127*da0073e9SAndroid Build Coastguard Worker # that is, the validation returns the sparse sample 5128*da0073e9SAndroid Build Coastguard Worker # wrapped within ErrorInput instance 5129*da0073e9SAndroid Build Coastguard Worker continue 5130*da0073e9SAndroid Build Coastguard Worker t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs 5131*da0073e9SAndroid Build Coastguard Worker batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim() 5132*da0073e9SAndroid Build Coastguard Worker result = op.op(t_inp, *t_args, **t_kwargs) 5133*da0073e9SAndroid Build Coastguard Worker 5134*da0073e9SAndroid Build Coastguard Worker # Check rop(inp, ...).shape == inp.shape 5135*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, t_inp.shape) 5136*da0073e9SAndroid Build Coastguard Worker 5137*da0073e9SAndroid Build Coastguard Worker # Check rop(inp, ...).sparse_dim() == inp.sparse_dim() 5138*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.sparse_dim(), t_inp.sparse_dim()) 5139*da0073e9SAndroid Build Coastguard Worker 5140*da0073e9SAndroid Build Coastguard Worker # Check rop(inp, ...).dense_dim() == inp.dense_dim() 5141*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.dense_dim(), t_inp.dense_dim()) 5142*da0073e9SAndroid Build Coastguard Worker 5143*da0073e9SAndroid Build Coastguard Worker # Check invariant rop(inp, ...).to_dense() == rop(inp.to_dense(), ...) 
5144*da0073e9SAndroid Build Coastguard Worker try: 5145*da0073e9SAndroid Build Coastguard Worker dense = op.op(t_inp.to_dense(), *(t_args[0].to_dense(), *t_args[1:]), **t_kwargs) 5146*da0073e9SAndroid Build Coastguard Worker except Exception as msg: 5147*da0073e9SAndroid Build Coastguard Worker # this is strided op issue, so skipping the sample silently here 5148*da0073e9SAndroid Build Coastguard Worker if "\"cpublas_axpy_impl\" not implemented for 'ComplexHalf'" in str(msg): 5149*da0073e9SAndroid Build Coastguard Worker continue 5150*da0073e9SAndroid Build Coastguard Worker raise 5151*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, dense) 5152*da0073e9SAndroid Build Coastguard Worker 5153*da0073e9SAndroid Build Coastguard Worker @onlyCPU 5154*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=True) 5155*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 5156*da0073e9SAndroid Build Coastguard Worker def test_to_sparse_identity(self, device, layout, dtype): 5157*da0073e9SAndroid Build Coastguard Worker for dense_dim in range(4): 5158*da0073e9SAndroid Build Coastguard Worker x_dense = torch.eye(dense_dim, dtype=dtype, device=device) 5159*da0073e9SAndroid Build Coastguard Worker for sparse_dim_in in range(1, dense_dim): 5160*da0073e9SAndroid Build Coastguard Worker x_sparse = x_dense.to_sparse(sparse_dim_in) 5161*da0073e9SAndroid Build Coastguard Worker for sparse_dim_out in range(0, dense_dim): 5162*da0073e9SAndroid Build Coastguard Worker if sparse_dim_out == sparse_dim_in: 5163*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x_sparse.to_sparse(sparse_dim_out).sparse_dim() == sparse_dim_out) 5164*da0073e9SAndroid Build Coastguard Worker else: 5165*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 5166*da0073e9SAndroid Build Coastguard Worker RuntimeError, 5167*da0073e9SAndroid Build Coastguard Worker r"to_sparse: conversion from Sparse to Sparse with sparse_dim argument 
!=self.sparse_dim\(\)" 5168*da0073e9SAndroid Build Coastguard Worker " is not supported"): 5169*da0073e9SAndroid Build Coastguard Worker x_sparse.to_sparse(sparse_dim_out) 5170*da0073e9SAndroid Build Coastguard Worker 5171*da0073e9SAndroid Build Coastguard Worker 5172*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 5173*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 5174*da0073e9SAndroid Build Coastguard Worker @ops(like_fns_with_sparse_support) 5175*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 5176*da0073e9SAndroid Build Coastguard Worker def test_like_fns(self, layout, device, dtype, op): 5177*da0073e9SAndroid Build Coastguard Worker 5178*da0073e9SAndroid Build Coastguard Worker for sample in op.sample_inputs_sparse(layout, device, dtype): 5179*da0073e9SAndroid Build Coastguard Worker t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs 5180*da0073e9SAndroid Build Coastguard Worker batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim() 5181*da0073e9SAndroid Build Coastguard Worker if t_inp.layout in {torch.sparse_bsr, torch.sparse_bsc}: 5182*da0073e9SAndroid Build Coastguard Worker expected_blocksize = t_inp.values().shape[batch_dim + 1:batch_dim + 3] 5183*da0073e9SAndroid Build Coastguard Worker else: 5184*da0073e9SAndroid Build Coastguard Worker expected_blocksize = None 5185*da0073e9SAndroid Build Coastguard Worker expected_dtype = t_kwargs.get('dtype', dtype) 5186*da0073e9SAndroid Build Coastguard Worker expected_device = torch.device(t_kwargs.get('device', device)) 5187*da0073e9SAndroid Build Coastguard Worker expected_layout = t_kwargs.get('layout', layout) 5188*da0073e9SAndroid Build Coastguard Worker 5189*da0073e9SAndroid Build Coastguard Worker result = op.op(t_inp, *t_args, **t_kwargs) 5190*da0073e9SAndroid Build Coastguard Worker 5191*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.dtype, expected_dtype) 5192*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(result.device.type, expected_device.type) 5193*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.layout, expected_layout) 5194*da0073e9SAndroid Build Coastguard Worker 5195*da0073e9SAndroid Build Coastguard Worker if result.layout in {torch.sparse_bsr, torch.sparse_bsc}: 5196*da0073e9SAndroid Build Coastguard Worker result_batch_dim = result.dim() - result.dense_dim() - result.sparse_dim() 5197*da0073e9SAndroid Build Coastguard Worker blocksize = result.values().shape[result_batch_dim + 1:result_batch_dim + 3] 5198*da0073e9SAndroid Build Coastguard Worker self.assertEqual(blocksize, expected_blocksize) 5199*da0073e9SAndroid Build Coastguard Worker 5200*da0073e9SAndroid Build Coastguard Worker # Check op(inp).shape == inp.shape 5201*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, t_inp.shape) 5202*da0073e9SAndroid Build Coastguard Worker 5203*da0073e9SAndroid Build Coastguard Worker if expected_layout is torch.strided: 5204*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.sparse_dim(), 0) 5205*da0073e9SAndroid Build Coastguard Worker # Check op(inp, layout=torch.strided).dense_dim() == inp.dim() 5206*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.dense_dim(), t_inp.dim()) 5207*da0073e9SAndroid Build Coastguard Worker elif expected_layout is torch.sparse_coo: 5208*da0073e9SAndroid Build Coastguard Worker # Check op(inp, layout=torch.sparse_coo).sparse_dim() == batch_dim + inp.sparse_dim() 5209*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.sparse_dim(), batch_dim + t_inp.sparse_dim()) 5210*da0073e9SAndroid Build Coastguard Worker # Check op(inp, layout=torch.sparse_coo).dense_dim() == inp.dense_dim() 5211*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.dense_dim(), t_inp.dense_dim()) 5212*da0073e9SAndroid Build Coastguard Worker 5213*da0073e9SAndroid Build Coastguard Worker 
torch._validate_sparse_coo_tensor_args(result._indices(), result._values(), result.shape) 5214*da0073e9SAndroid Build Coastguard Worker else: 5215*da0073e9SAndroid Build Coastguard Worker # Check op(inp).sparse_dim() == inp.sparse_dim() 5216*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.sparse_dim(), t_inp.sparse_dim()) 5217*da0073e9SAndroid Build Coastguard Worker # Check op(inp).dense_dim() == inp.dense_dim() 5218*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.dense_dim(), t_inp.dense_dim()) 5219*da0073e9SAndroid Build Coastguard Worker 5220*da0073e9SAndroid Build Coastguard Worker if result.layout in {torch.sparse_csr, torch.sparse_bsr}: 5221*da0073e9SAndroid Build Coastguard Worker compressed_indices, plain_indices = result.crow_indices(), result.col_indices() 5222*da0073e9SAndroid Build Coastguard Worker else: 5223*da0073e9SAndroid Build Coastguard Worker compressed_indices, plain_indices = result.ccol_indices(), result.row_indices() 5224*da0073e9SAndroid Build Coastguard Worker 5225*da0073e9SAndroid Build Coastguard Worker torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, result.values(), 5226*da0073e9SAndroid Build Coastguard Worker result.shape, result.layout) 5227*da0073e9SAndroid Build Coastguard Worker 5228*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('mask_layout', include_strided=False) 5229*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 5230*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 5231*da0073e9SAndroid Build Coastguard Worker def test_sparse_mask(self, mask_layout, device, dtype): 5232*da0073e9SAndroid Build Coastguard Worker input_layout = torch.strided 5233*da0073e9SAndroid Build Coastguard Worker mask_dtype = torch.bool 5234*da0073e9SAndroid Build Coastguard Worker for mask in self.generate_simple_inputs(mask_layout, dtype=mask_dtype, device=device, 5235*da0073e9SAndroid 
Build Coastguard Worker enable_hybrid=False, enable_batch=False): 5236*da0073e9SAndroid Build Coastguard Worker 5237*da0073e9SAndroid Build Coastguard Worker x = make_tensor(mask.shape, dtype=dtype, device=device).to_sparse(layout=input_layout) 5238*da0073e9SAndroid Build Coastguard Worker 5239*da0073e9SAndroid Build Coastguard Worker result = x.sparse_mask(mask) 5240*da0073e9SAndroid Build Coastguard Worker 5241*da0073e9SAndroid Build Coastguard Worker # Check invariant `x.sparse_mask(mask).<indices> == mask.<indices>` 5242*da0073e9SAndroid Build Coastguard Worker if mask_layout is torch.sparse_coo: 5243*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result._indices(), mask._indices()) 5244*da0073e9SAndroid Build Coastguard Worker ones = torch.sparse_coo_tensor(mask._indices(), 5245*da0073e9SAndroid Build Coastguard Worker torch.ones_like(mask._values(), dtype=x.dtype), 5246*da0073e9SAndroid Build Coastguard Worker mask.shape, 5247*da0073e9SAndroid Build Coastguard Worker is_coalesced=mask.is_coalesced()) 5248*da0073e9SAndroid Build Coastguard Worker elif mask_layout in {torch.sparse_csr, torch.sparse_bsr}: 5249*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.crow_indices(), mask.crow_indices()) 5250*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.col_indices(), mask.col_indices()) 5251*da0073e9SAndroid Build Coastguard Worker ones = torch.sparse_compressed_tensor(mask.crow_indices(), mask.col_indices(), 5252*da0073e9SAndroid Build Coastguard Worker torch.ones_like(mask.values(), dtype=x.dtype), 5253*da0073e9SAndroid Build Coastguard Worker mask.shape, layout=mask.layout) 5254*da0073e9SAndroid Build Coastguard Worker else: 5255*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.ccol_indices(), mask.ccol_indices()) 5256*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.row_indices(), mask.row_indices()) 5257*da0073e9SAndroid Build Coastguard Worker ones = 
torch.sparse_compressed_tensor(mask.ccol_indices(), mask.row_indices(), 5258*da0073e9SAndroid Build Coastguard Worker torch.ones_like(mask.values(), dtype=x.dtype), 5259*da0073e9SAndroid Build Coastguard Worker mask.shape, layout=mask.layout) 5260*da0073e9SAndroid Build Coastguard Worker 5261*da0073e9SAndroid Build Coastguard Worker # Check invariant: 5262*da0073e9SAndroid Build Coastguard Worker # x.sparse_mask(mask).to_dense() == x.mul(sparse_xyz_tensor(<mask indices>, 5263*da0073e9SAndroid Build Coastguard Worker # ones_like(<mask values>)).to_dense()) 5264*da0073e9SAndroid Build Coastguard Worker expected = x.mul(ones.to_dense()) 5265*da0073e9SAndroid Build Coastguard Worker 5266*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.to_dense(), expected) 5267*da0073e9SAndroid Build Coastguard Worker 5268*da0073e9SAndroid Build Coastguard Worker # Check invariant `mask.to_dense().sparse_mask(mask) == mask` 5269*da0073e9SAndroid Build Coastguard Worker result = mask.to_dense().sparse_mask(mask) 5270*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, mask) 5271*da0073e9SAndroid Build Coastguard Worker 5272*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 5273*da0073e9SAndroid Build Coastguard Worker @parametrize("masked", [subtest(False, name='nonmasked'), subtest(True, name='masked')]) 5274*da0073e9SAndroid Build Coastguard Worker @parametrize("fast_mode", [subtest(False, name='slow'), subtest(True, name='fast')]) 5275*da0073e9SAndroid Build Coastguard Worker def test_as_sparse_gradcheck(self, layout, device, masked, fast_mode): 5276*da0073e9SAndroid Build Coastguard Worker gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck) 5277*da0073e9SAndroid Build Coastguard Worker sparse_compressed_layouts = {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} 5278*da0073e9SAndroid Build Coastguard Worker 5279*da0073e9SAndroid Build Coastguard Worker def 
identity(x): 5280*da0073e9SAndroid Build Coastguard Worker return x 5281*da0073e9SAndroid Build Coastguard Worker 5282*da0073e9SAndroid Build Coastguard Worker for func in (torch.Tensor.to_dense, 5283*da0073e9SAndroid Build Coastguard Worker torch.Tensor.sum, 5284*da0073e9SAndroid Build Coastguard Worker identity, 5285*da0073e9SAndroid Build Coastguard Worker torch.Tensor.to_sparse, 5286*da0073e9SAndroid Build Coastguard Worker torch.Tensor.values, 5287*da0073e9SAndroid Build Coastguard Worker ): 5288*da0073e9SAndroid Build Coastguard Worker for x in self.generate_simple_inputs( 5289*da0073e9SAndroid Build Coastguard Worker layout, 5290*da0073e9SAndroid Build Coastguard Worker device=device, 5291*da0073e9SAndroid Build Coastguard Worker dtype=torch.float64, 5292*da0073e9SAndroid Build Coastguard Worker # TODO: fix gh-104868 to enable batched samples: 5293*da0073e9SAndroid Build Coastguard Worker enable_batch=layout not in sparse_compressed_layouts, 5294*da0073e9SAndroid Build Coastguard Worker enable_hybrid=not ( 5295*da0073e9SAndroid Build Coastguard Worker layout in sparse_compressed_layouts and ( 5296*da0073e9SAndroid Build Coastguard Worker # FIXME: RuntimeError: sparse_mask(): the 5297*da0073e9SAndroid Build Coastguard Worker # number of sparse dimensions in `self` 5298*da0073e9SAndroid Build Coastguard Worker # should match that of the `mask`. 
Got 5299*da0073e9SAndroid Build Coastguard Worker # `self.sparse_dim() == 3` != 5300*da0073e9SAndroid Build Coastguard Worker # `mask.sparse_dim() == 2 5301*da0073e9SAndroid Build Coastguard Worker func.__name__ == 'sum' 5302*da0073e9SAndroid Build Coastguard Worker # FIXME: RuntimeError: expected 5303*da0073e9SAndroid Build Coastguard Worker # col_indices to be a contiguous tensor 5304*da0073e9SAndroid Build Coastguard Worker # per batch 5305*da0073e9SAndroid Build Coastguard Worker or func.__name__ == 'to_sparse' 5306*da0073e9SAndroid Build Coastguard Worker ))): 5307*da0073e9SAndroid Build Coastguard Worker if layout is torch.sparse_coo and func.__name__ == 'values': 5308*da0073e9SAndroid Build Coastguard Worker x = x.coalesce() 5309*da0073e9SAndroid Build Coastguard Worker 5310*da0073e9SAndroid Build Coastguard Worker gradcheck(func, x.requires_grad_(True), masked=masked, fast_mode=fast_mode) 5311*da0073e9SAndroid Build Coastguard Worker 5312*da0073e9SAndroid Build Coastguard Worker @onlyCPU 5313*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 5314*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 5315*da0073e9SAndroid Build Coastguard Worker def test_dataloader(self, device, layout, dtype): 5316*da0073e9SAndroid Build Coastguard Worker 5317*da0073e9SAndroid Build Coastguard Worker data = list(self.generate_simple_inputs(layout, device=device, dtype=dtype)) 5318*da0073e9SAndroid Build Coastguard Worker 5319*da0073e9SAndroid Build Coastguard Worker dataset = _SparseDataset(data) 5320*da0073e9SAndroid Build Coastguard Worker loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=2) 5321*da0073e9SAndroid Build Coastguard Worker 5322*da0073e9SAndroid Build Coastguard Worker loaded_data = list(loader) 5323*da0073e9SAndroid Build Coastguard Worker self.assertEqual(data, loaded_data) 5324*da0073e9SAndroid Build Coastguard Worker 5325*da0073e9SAndroid Build Coastguard Worker @onlyCPU 
5326*da0073e9SAndroid Build Coastguard Worker def test_invalid_blocksize(self): 5327*da0073e9SAndroid Build Coastguard Worker # Blocksize should be a tuple/list/torch.Size containing two values 5328*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 1"): 5329*da0073e9SAndroid Build Coastguard Worker torch.randn(1).to_sparse(blocksize=(1,)) 5330*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 1"): 5331*da0073e9SAndroid Build Coastguard Worker torch.randn(1).to_sparse(blocksize=[1]) 5332*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 1"): 5333*da0073e9SAndroid Build Coastguard Worker torch.randn(1).to_sparse(blocksize=torch.Size((1,))) 5334*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 3"): 5335*da0073e9SAndroid Build Coastguard Worker torch.randn(1).to_sparse(blocksize=(1, 1, 1)) 5336*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 3"): 5337*da0073e9SAndroid Build Coastguard Worker torch.randn(1).to_sparse(blocksize=[1, 1, 1]) 5338*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, ".*blocksize.*, but got 3"): 5339*da0073e9SAndroid Build Coastguard Worker torch.randn(1).to_sparse(blocksize=torch.Size((1, 1, 1))) 5340*da0073e9SAndroid Build Coastguard Worker 5341*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), 'requires cuda') 5342*da0073e9SAndroid Build Coastguard Worker @onlyCPU 5343*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=True) 5344*da0073e9SAndroid Build Coastguard Worker def test_constructor_pin_memory(self, device, layout): 5345*da0073e9SAndroid Build Coastguard Worker """Tests sparse_xyz_tensor(indices, values, pin_memory=True) 5346*da0073e9SAndroid 
Build Coastguard Worker """ 5347*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device, "cpu") 5348*da0073e9SAndroid Build Coastguard Worker for t in self.generate_simple_inputs( 5349*da0073e9SAndroid Build Coastguard Worker layout, device=device, dtype=torch.float64, 5350*da0073e9SAndroid Build Coastguard Worker enable_zero_sized=False, # pinning zero-sized tensors is a no-op 5351*da0073e9SAndroid Build Coastguard Worker pin_memory=True, 5352*da0073e9SAndroid Build Coastguard Worker enable_batch=False, # TODO: remove after gh-104868 is resolved 5353*da0073e9SAndroid Build Coastguard Worker ): 5354*da0073e9SAndroid Build Coastguard Worker if layout is torch.sparse_coo: 5355*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t._indices().is_pinned()) 5356*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t._values().is_pinned()) 5357*da0073e9SAndroid Build Coastguard Worker elif layout in {torch.sparse_csr, torch.sparse_bsr}: 5358*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.crow_indices().is_pinned()) 5359*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.col_indices().is_pinned()) 5360*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.values().is_pinned()) 5361*da0073e9SAndroid Build Coastguard Worker elif layout in {torch.sparse_csc, torch.sparse_bsc}: 5362*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.ccol_indices().is_pinned()) 5363*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.row_indices().is_pinned()) 5364*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.values().is_pinned()) 5365*da0073e9SAndroid Build Coastguard Worker elif layout is torch.strided: 5366*da0073e9SAndroid Build Coastguard Worker pass 5367*da0073e9SAndroid Build Coastguard Worker else: 5368*da0073e9SAndroid Build Coastguard Worker assert 0 # unreachable 5369*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.is_pinned()) 5370*da0073e9SAndroid Build Coastguard Worker 5371*da0073e9SAndroid Build 
Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), 'requires cuda') 5372*da0073e9SAndroid Build Coastguard Worker @onlyCPU 5373*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=True) 5374*da0073e9SAndroid Build Coastguard Worker def test_method_pin_memory(self, device, layout): 5375*da0073e9SAndroid Build Coastguard Worker """Tests sparse_xyz_tensor(indices, values, pin_memory=False).pin_memory() 5376*da0073e9SAndroid Build Coastguard Worker """ 5377*da0073e9SAndroid Build Coastguard Worker 5378*da0073e9SAndroid Build Coastguard Worker for t_ in self.generate_simple_inputs( 5379*da0073e9SAndroid Build Coastguard Worker layout, device=device, dtype=torch.float64, 5380*da0073e9SAndroid Build Coastguard Worker enable_zero_sized=False, # pinning zero-sized tensors is a no-op 5381*da0073e9SAndroid Build Coastguard Worker pin_memory=False, # no pinning 5382*da0073e9SAndroid Build Coastguard Worker enable_batch=False, # TODO: remove after gh-104868 is resolved 5383*da0073e9SAndroid Build Coastguard Worker ): 5384*da0073e9SAndroid Build Coastguard Worker t = t_.pin_memory() 5385*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.is_pinned()) 5386*da0073e9SAndroid Build Coastguard Worker 5387*da0073e9SAndroid Build Coastguard Worker # registering a non-pinned tensor with CUDA memory is a 5388*da0073e9SAndroid Build Coastguard Worker # clone operation 5389*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t_.is_pinned()) 5390*da0073e9SAndroid Build Coastguard Worker 5391*da0073e9SAndroid Build Coastguard Worker # registering already pinned tensor with CUDA memory is an 5392*da0073e9SAndroid Build Coastguard Worker # identity operation: 5393*da0073e9SAndroid Build Coastguard Worker t2 = t.pin_memory() 5394*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t2 is t) 5395*da0073e9SAndroid Build Coastguard Worker 5396*da0073e9SAndroid Build Coastguard Worker if layout is torch.sparse_coo: 
5397*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t._indices().is_pinned()) 5398*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t._values().is_pinned()) 5399*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t_._indices().is_pinned()) 5400*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t_._values().is_pinned()) 5401*da0073e9SAndroid Build Coastguard Worker elif layout in {torch.sparse_csr, torch.sparse_bsr}: 5402*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.crow_indices().is_pinned()) 5403*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.col_indices().is_pinned()) 5404*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.values().is_pinned()) 5405*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t_.crow_indices().is_pinned()) 5406*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t_.col_indices().is_pinned()) 5407*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t_.values().is_pinned()) 5408*da0073e9SAndroid Build Coastguard Worker elif layout in {torch.sparse_csc, torch.sparse_bsc}: 5409*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.ccol_indices().is_pinned()) 5410*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.row_indices().is_pinned()) 5411*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.values().is_pinned()) 5412*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t_.ccol_indices().is_pinned()) 5413*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t_.row_indices().is_pinned()) 5414*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t_.values().is_pinned()) 5415*da0073e9SAndroid Build Coastguard Worker elif layout is torch.strided: 5416*da0073e9SAndroid Build Coastguard Worker pass 5417*da0073e9SAndroid Build Coastguard Worker else: 5418*da0073e9SAndroid Build Coastguard Worker assert 0 # unreachable 5419*da0073e9SAndroid Build Coastguard Worker 5420*da0073e9SAndroid Build Coastguard Worker 
5421*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), 'requires cuda') 5422*da0073e9SAndroid Build Coastguard Worker @onlyCPU 5423*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=True) 5424*da0073e9SAndroid Build Coastguard Worker def test_constructor_pinned_memory(self, device, layout): 5425*da0073e9SAndroid Build Coastguard Worker """Tests sparse_xyz_tensor(indices.pin_memory(device), values.pin_memory(device)) 5426*da0073e9SAndroid Build Coastguard Worker """ 5427*da0073e9SAndroid Build Coastguard Worker pin_memory_device = "cuda" 5428*da0073e9SAndroid Build Coastguard Worker for t in self.generate_simple_inputs( 5429*da0073e9SAndroid Build Coastguard Worker layout, device=device, dtype=torch.float64, 5430*da0073e9SAndroid Build Coastguard Worker enable_zero_sized=False, # pinning zero-sized tensors is a no-op 5431*da0073e9SAndroid Build Coastguard Worker pin_memory=None, # constructor does not specify pin_memory=... 
5432*da0073e9SAndroid Build Coastguard Worker members_pin_memory=True, # indices and values are pinned 5433*da0073e9SAndroid Build Coastguard Worker enable_batch=False, # TODO: remove after gh-104868 is resolved 5434*da0073e9SAndroid Build Coastguard Worker ): 5435*da0073e9SAndroid Build Coastguard Worker if layout is torch.sparse_coo: 5436*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t._indices().is_pinned()) 5437*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t._values().is_pinned()) 5438*da0073e9SAndroid Build Coastguard Worker elif layout in {torch.sparse_csr, torch.sparse_bsr}: 5439*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.crow_indices().is_pinned()) 5440*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.col_indices().is_pinned()) 5441*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.values().is_pinned()) 5442*da0073e9SAndroid Build Coastguard Worker elif layout in {torch.sparse_csc, torch.sparse_bsc}: 5443*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.ccol_indices().is_pinned()) 5444*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.row_indices().is_pinned()) 5445*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.values().is_pinned()) 5446*da0073e9SAndroid Build Coastguard Worker elif layout is torch.strided: 5447*da0073e9SAndroid Build Coastguard Worker pass 5448*da0073e9SAndroid Build Coastguard Worker else: 5449*da0073e9SAndroid Build Coastguard Worker assert 0 # unreachable 5450*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.is_pinned()) 5451*da0073e9SAndroid Build Coastguard Worker 5452*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), 'requires cuda') 5453*da0073e9SAndroid Build Coastguard Worker @onlyCPU 5454*da0073e9SAndroid Build Coastguard Worker @all_sparse_layouts('layout', include_strided=False) 5455*da0073e9SAndroid Build Coastguard Worker def test_constructor_mismatched_pinned_memory(self, device, layout): 
5456*da0073e9SAndroid Build Coastguard Worker """Test the failure to construct sparse tensor from indices and values 5457*da0073e9SAndroid Build Coastguard Worker that have different pinning states. 5458*da0073e9SAndroid Build Coastguard Worker """ 5459*da0073e9SAndroid Build Coastguard Worker def generic_constructor(*args, **kwargs): 5460*da0073e9SAndroid Build Coastguard Worker if layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}: 5461*da0073e9SAndroid Build Coastguard Worker kwargs.update(layout=layout) 5462*da0073e9SAndroid Build Coastguard Worker return torch.sparse_compressed_tensor(*args, **kwargs) 5463*da0073e9SAndroid Build Coastguard Worker elif layout is torch.sparse_coo: 5464*da0073e9SAndroid Build Coastguard Worker return torch.sparse_coo_tensor(*args, **kwargs) 5465*da0073e9SAndroid Build Coastguard Worker else: 5466*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError(layout) 5467*da0073e9SAndroid Build Coastguard Worker 5468*da0073e9SAndroid Build Coastguard Worker for args, kwargs in self.generate_simple_inputs( 5469*da0073e9SAndroid Build Coastguard Worker layout, device=device, dtype=torch.float64, 5470*da0073e9SAndroid Build Coastguard Worker enable_zero_sized=False, # pinning zero-sized tensors is a no-op 5471*da0073e9SAndroid Build Coastguard Worker enable_batch=False, # TODO: remove after gh-104868 is resolved 5472*da0073e9SAndroid Build Coastguard Worker output_tensor=False): 5473*da0073e9SAndroid Build Coastguard Worker 5474*da0073e9SAndroid Build Coastguard Worker # indices are pinned, values is a non-pinned tensor 5475*da0073e9SAndroid Build Coastguard Worker args1 = (args[0].pin_memory(), *args[1:]) 5476*da0073e9SAndroid Build Coastguard Worker 5477*da0073e9SAndroid Build Coastguard Worker # indices are non-pinned, values is a pinned tensor 5478*da0073e9SAndroid Build Coastguard Worker args2 = (*args[:-1], args[-1].pin_memory()) 5479*da0073e9SAndroid Build Coastguard Worker 
5480*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 5481*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"memory pinning of \w*indices \(=1\) must match memory pinning of values \(=0\)"): 5482*da0073e9SAndroid Build Coastguard Worker generic_constructor(*args1, **kwargs) 5483*da0073e9SAndroid Build Coastguard Worker 5484*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 5485*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"memory pinning of \w*indices \(=0\) must match memory pinning of values \(=1\)"): 5486*da0073e9SAndroid Build Coastguard Worker generic_constructor(*args2, **kwargs) 5487*da0073e9SAndroid Build Coastguard Worker 5488*da0073e9SAndroid Build Coastguard Worker 5489*da0073e9SAndroid Build Coastguard Worker# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA 5490*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta') 5491*da0073e9SAndroid Build Coastguard Worker 5492*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestSparseMaskedReductions, globals(), except_for='meta') 5493*da0073e9SAndroid Build Coastguard Worker 5494*da0073e9SAndroid Build Coastguard Worker# e.g., TestSparseCPU and TestSparseCUDA 5495*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestSparse, globals(), except_for='meta') 5496*da0073e9SAndroid Build Coastguard Worker 5497*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestSparseAny, globals(), except_for='meta') 5498*da0073e9SAndroid Build Coastguard Worker 5499*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestSparseMeta) 5500*da0073e9SAndroid Build Coastguard Worker 5501*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestSparseLegacyAndDeprecation) 5502*da0073e9SAndroid Build Coastguard Worker 5503*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__': 
5504*da0073e9SAndroid Build Coastguard Worker run_tests() 5505