1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: tests"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport torch 4*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor 5*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 6*da0073e9SAndroid Build Coastguard Worker dtypes, 7*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 8*da0073e9SAndroid Build Coastguard Worker onlyCUDA, 9*da0073e9SAndroid Build Coastguard Worker onlyNativeDeviceTypes, 10*da0073e9SAndroid Build Coastguard Worker skipCUDAIfRocm, 11*da0073e9SAndroid Build Coastguard Worker skipMeta, 12*da0073e9SAndroid Build Coastguard Worker) 13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_dtype import all_types_and_complex_and 14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import IS_JETSON, run_tests, TestCase 15*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.dlpack import from_dlpack, to_dlpack 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Workerclass TestTorchDlPack(TestCase): 19*da0073e9SAndroid Build Coastguard Worker exact_dtype = True 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker @skipMeta 22*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 23*da0073e9SAndroid Build Coastguard Worker @dtypes( 24*da0073e9SAndroid Build Coastguard Worker *all_types_and_complex_and( 25*da0073e9SAndroid Build Coastguard Worker torch.half, 26*da0073e9SAndroid Build Coastguard Worker torch.bfloat16, 27*da0073e9SAndroid Build Coastguard Worker torch.bool, 28*da0073e9SAndroid Build Coastguard Worker torch.uint16, 29*da0073e9SAndroid Build Coastguard Worker torch.uint32, 30*da0073e9SAndroid Build Coastguard Worker torch.uint64, 31*da0073e9SAndroid Build Coastguard Worker ) 32*da0073e9SAndroid Build Coastguard Worker ) 33*da0073e9SAndroid Build Coastguard Worker def test_dlpack_capsule_conversion(self, device, dtype): 34*da0073e9SAndroid Build Coastguard Worker x = make_tensor((5,), dtype=dtype, device=device) 35*da0073e9SAndroid Build Coastguard Worker z = from_dlpack(to_dlpack(x)) 36*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, x) 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker @skipMeta 39*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 40*da0073e9SAndroid Build Coastguard Worker @dtypes( 41*da0073e9SAndroid Build Coastguard Worker *all_types_and_complex_and( 42*da0073e9SAndroid Build Coastguard Worker torch.half, 43*da0073e9SAndroid Build Coastguard Worker torch.bfloat16, 44*da0073e9SAndroid Build Coastguard Worker torch.bool, 45*da0073e9SAndroid Build Coastguard Worker torch.uint16, 46*da0073e9SAndroid Build Coastguard Worker torch.uint32, 47*da0073e9SAndroid Build Coastguard Worker torch.uint64, 48*da0073e9SAndroid Build Coastguard Worker ) 49*da0073e9SAndroid Build Coastguard Worker ) 50*da0073e9SAndroid Build Coastguard Worker def test_dlpack_protocol_conversion(self, device, dtype): 51*da0073e9SAndroid Build Coastguard Worker x = make_tensor((5,), dtype=dtype, device=device) 52*da0073e9SAndroid Build Coastguard Worker z = from_dlpack(x) 53*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, x) 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker @skipMeta 56*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 57*da0073e9SAndroid Build Coastguard Worker def test_dlpack_shared_storage(self, device): 58*da0073e9SAndroid Build Coastguard Worker x = make_tensor((5,), dtype=torch.float64, device=device) 59*da0073e9SAndroid Build Coastguard Worker z = from_dlpack(to_dlpack(x)) 60*da0073e9SAndroid Build Coastguard Worker z[0] = z[0] + 20.0 61*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, x) 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker @skipMeta 64*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 65*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) 66*da0073e9SAndroid Build Coastguard Worker def test_dlpack_conversion_with_streams(self, device, dtype): 67*da0073e9SAndroid Build Coastguard Worker # Create a stream where the tensor will reside 68*da0073e9SAndroid Build Coastguard Worker stream = torch.cuda.Stream() 69*da0073e9SAndroid Build Coastguard Worker with torch.cuda.stream(stream): 70*da0073e9SAndroid Build Coastguard Worker # Do an operation in the actual stream 71*da0073e9SAndroid Build Coastguard Worker x = make_tensor((5,), dtype=dtype, device=device) + 1 72*da0073e9SAndroid Build Coastguard Worker # DLPack protocol helps establish a correct stream order 73*da0073e9SAndroid Build Coastguard Worker # (hence data dependency) at the exchange boundary. 74*da0073e9SAndroid Build Coastguard Worker # DLPack manages this synchronization for us, so we don't need to 75*da0073e9SAndroid Build Coastguard Worker # explicitly wait until x is populated 76*da0073e9SAndroid Build Coastguard Worker if IS_JETSON: 77*da0073e9SAndroid Build Coastguard Worker # DLPack protocol that establishes correct stream order 78*da0073e9SAndroid Build Coastguard Worker # does not behave as expected on Jetson 79*da0073e9SAndroid Build Coastguard Worker stream.synchronize() 80*da0073e9SAndroid Build Coastguard Worker stream = torch.cuda.Stream() 81*da0073e9SAndroid Build Coastguard Worker with torch.cuda.stream(stream): 82*da0073e9SAndroid Build Coastguard Worker z = from_dlpack(x) 83*da0073e9SAndroid Build Coastguard Worker stream.synchronize() 84*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, x) 85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard Worker @skipMeta 87*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 88*da0073e9SAndroid Build Coastguard Worker @dtypes( 89*da0073e9SAndroid Build Coastguard Worker *all_types_and_complex_and( 90*da0073e9SAndroid Build Coastguard Worker torch.half, 91*da0073e9SAndroid Build Coastguard Worker torch.bfloat16, 92*da0073e9SAndroid Build Coastguard Worker torch.bool, 93*da0073e9SAndroid Build Coastguard Worker torch.uint16, 94*da0073e9SAndroid Build Coastguard Worker torch.uint32, 95*da0073e9SAndroid Build Coastguard Worker torch.uint64, 96*da0073e9SAndroid Build Coastguard Worker ) 97*da0073e9SAndroid Build Coastguard Worker ) 98*da0073e9SAndroid Build Coastguard Worker def test_from_dlpack(self, device, dtype): 99*da0073e9SAndroid Build Coastguard Worker x = make_tensor((5,), dtype=dtype, device=device) 100*da0073e9SAndroid Build Coastguard Worker y = torch.from_dlpack(x) 101*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, y) 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker @skipMeta 104*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 105*da0073e9SAndroid Build Coastguard Worker @dtypes( 106*da0073e9SAndroid Build Coastguard Worker *all_types_and_complex_and( 107*da0073e9SAndroid Build Coastguard Worker torch.half, 108*da0073e9SAndroid Build Coastguard Worker torch.bfloat16, 109*da0073e9SAndroid Build Coastguard Worker torch.bool, 110*da0073e9SAndroid Build Coastguard Worker torch.uint16, 111*da0073e9SAndroid Build Coastguard Worker torch.uint32, 112*da0073e9SAndroid Build Coastguard Worker torch.uint64, 113*da0073e9SAndroid Build Coastguard Worker ) 114*da0073e9SAndroid Build Coastguard Worker ) 115*da0073e9SAndroid Build Coastguard Worker def test_from_dlpack_noncontinguous(self, device, dtype): 116*da0073e9SAndroid Build Coastguard Worker x = make_tensor((25,), dtype=dtype, device=device).reshape(5, 5) 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker y1 = x[0] 119*da0073e9SAndroid Build Coastguard Worker y1_dl = torch.from_dlpack(y1) 120*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y1, y1_dl) 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker y2 = x[:, 0] 123*da0073e9SAndroid Build Coastguard Worker y2_dl = torch.from_dlpack(y2) 124*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y2, y2_dl) 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker y3 = x[1, :] 127*da0073e9SAndroid Build Coastguard Worker y3_dl = torch.from_dlpack(y3) 128*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y3, y3_dl) 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Worker y4 = x[1] 131*da0073e9SAndroid Build Coastguard Worker y4_dl = torch.from_dlpack(y4) 132*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y4, y4_dl) 133*da0073e9SAndroid Build Coastguard Worker 134*da0073e9SAndroid Build Coastguard Worker y5 = x.t() 135*da0073e9SAndroid Build Coastguard Worker y5_dl = torch.from_dlpack(y5) 136*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y5, y5_dl) 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker @skipMeta 139*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 140*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) 141*da0073e9SAndroid Build Coastguard Worker def test_dlpack_conversion_with_diff_streams(self, device, dtype): 142*da0073e9SAndroid Build Coastguard Worker stream_a = torch.cuda.Stream() 143*da0073e9SAndroid Build Coastguard Worker stream_b = torch.cuda.Stream() 144*da0073e9SAndroid Build Coastguard Worker # DLPack protocol helps establish a correct stream order 145*da0073e9SAndroid Build Coastguard Worker # (hence data dependency) at the exchange boundary. 146*da0073e9SAndroid Build Coastguard Worker # the `tensor.__dlpack__` method will insert a synchronization event 147*da0073e9SAndroid Build Coastguard Worker # in the current stream to make sure that it was correctly populated. 148*da0073e9SAndroid Build Coastguard Worker with torch.cuda.stream(stream_a): 149*da0073e9SAndroid Build Coastguard Worker x = make_tensor((5,), dtype=dtype, device=device) + 1 150*da0073e9SAndroid Build Coastguard Worker z = torch.from_dlpack(x.__dlpack__(stream_b.cuda_stream)) 151*da0073e9SAndroid Build Coastguard Worker stream_a.synchronize() 152*da0073e9SAndroid Build Coastguard Worker stream_b.synchronize() 153*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, x) 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker @skipMeta 156*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 157*da0073e9SAndroid Build Coastguard Worker @dtypes( 158*da0073e9SAndroid Build Coastguard Worker *all_types_and_complex_and( 159*da0073e9SAndroid Build Coastguard Worker torch.half, 160*da0073e9SAndroid Build Coastguard Worker torch.bfloat16, 161*da0073e9SAndroid Build Coastguard Worker torch.bool, 162*da0073e9SAndroid Build Coastguard Worker torch.uint16, 163*da0073e9SAndroid Build Coastguard Worker torch.uint32, 164*da0073e9SAndroid Build Coastguard Worker torch.uint64, 165*da0073e9SAndroid Build Coastguard Worker ) 166*da0073e9SAndroid Build Coastguard Worker ) 167*da0073e9SAndroid Build Coastguard Worker def test_from_dlpack_dtype(self, device, dtype): 168*da0073e9SAndroid Build Coastguard Worker x = make_tensor((5,), dtype=dtype, device=device) 169*da0073e9SAndroid Build Coastguard Worker y = torch.from_dlpack(x) 170*da0073e9SAndroid Build Coastguard Worker assert x.dtype == y.dtype 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker @skipMeta 173*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 174*da0073e9SAndroid Build Coastguard Worker def test_dlpack_default_stream(self, device): 175*da0073e9SAndroid Build Coastguard Worker class DLPackTensor: 176*da0073e9SAndroid Build Coastguard Worker def __init__(self, tensor): 177*da0073e9SAndroid Build Coastguard Worker self.tensor = tensor 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker def __dlpack_device__(self): 180*da0073e9SAndroid Build Coastguard Worker return self.tensor.__dlpack_device__() 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker def __dlpack__(self, stream=None): 183*da0073e9SAndroid Build Coastguard Worker if torch.version.hip is None: 184*da0073e9SAndroid Build Coastguard Worker assert stream == 1 185*da0073e9SAndroid Build Coastguard Worker else: 186*da0073e9SAndroid Build Coastguard Worker assert stream == 0 187*da0073e9SAndroid Build Coastguard Worker capsule = self.tensor.__dlpack__(stream) 188*da0073e9SAndroid Build Coastguard Worker return capsule 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker # CUDA-based tests runs on non-default streams 191*da0073e9SAndroid Build Coastguard Worker with torch.cuda.stream(torch.cuda.default_stream()): 192*da0073e9SAndroid Build Coastguard Worker x = DLPackTensor(make_tensor((5,), dtype=torch.float32, device=device)) 193*da0073e9SAndroid Build Coastguard Worker from_dlpack(x) 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker @skipMeta 196*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 197*da0073e9SAndroid Build Coastguard Worker @skipCUDAIfRocm 198*da0073e9SAndroid Build Coastguard Worker def test_dlpack_convert_default_stream(self, device): 199*da0073e9SAndroid Build Coastguard Worker # tests run on non-default stream, so _sleep call 200*da0073e9SAndroid Build Coastguard Worker # below will run on a non-default stream, causing 201*da0073e9SAndroid Build Coastguard Worker # default stream to wait due to inserted syncs 202*da0073e9SAndroid Build Coastguard Worker torch.cuda.default_stream().synchronize() 203*da0073e9SAndroid Build Coastguard Worker # run _sleep call on a non-default stream, causing 204*da0073e9SAndroid Build Coastguard Worker # default stream to wait due to inserted syncs 205*da0073e9SAndroid Build Coastguard Worker side_stream = torch.cuda.Stream() 206*da0073e9SAndroid Build Coastguard Worker with torch.cuda.stream(side_stream): 207*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(1, device=device) 208*da0073e9SAndroid Build Coastguard Worker torch.cuda._sleep(2**20) 209*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.cuda.default_stream().query()) 210*da0073e9SAndroid Build Coastguard Worker d = x.__dlpack__(1) 211*da0073e9SAndroid Build Coastguard Worker # check that the default stream has work (a pending cudaStreamWaitEvent) 212*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.cuda.default_stream().query()) 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker @skipMeta 215*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 216*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) 217*da0073e9SAndroid Build Coastguard Worker def test_dlpack_tensor_invalid_stream(self, device, dtype): 218*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 219*da0073e9SAndroid Build Coastguard Worker x = make_tensor((5,), dtype=dtype, device=device) 220*da0073e9SAndroid Build Coastguard Worker x.__dlpack__(stream=object()) 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Worker # TODO: add interchange tests once NumPy 1.22 (dlpack support) is required 223*da0073e9SAndroid Build Coastguard Worker @skipMeta 224*da0073e9SAndroid Build Coastguard Worker def test_dlpack_export_requires_grad(self): 225*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(10, dtype=torch.float32, requires_grad=True) 226*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"require gradient"): 227*da0073e9SAndroid Build Coastguard Worker x.__dlpack__() 228*da0073e9SAndroid Build Coastguard Worker 229*da0073e9SAndroid Build Coastguard Worker @skipMeta 230*da0073e9SAndroid Build Coastguard Worker def test_dlpack_export_is_conj(self): 231*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) 232*da0073e9SAndroid Build Coastguard Worker y = torch.conj(x) 233*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"conjugate bit"): 234*da0073e9SAndroid Build Coastguard Worker y.__dlpack__() 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker @skipMeta 237*da0073e9SAndroid Build Coastguard Worker def test_dlpack_export_non_strided(self): 238*da0073e9SAndroid Build Coastguard Worker x = torch.sparse_coo_tensor([[0]], [1], size=(1,)) 239*da0073e9SAndroid Build Coastguard Worker y = torch.conj(x) 240*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"strided"): 241*da0073e9SAndroid Build Coastguard Worker y.__dlpack__() 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker @skipMeta 244*da0073e9SAndroid Build Coastguard Worker def test_dlpack_normalize_strides(self): 245*da0073e9SAndroid Build Coastguard Worker x = torch.rand(16) 246*da0073e9SAndroid Build Coastguard Worker y = x[::3][:1] 247*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.shape, (1,)) 248*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.stride(), (3,)) 249*da0073e9SAndroid Build Coastguard Worker z = from_dlpack(y) 250*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z.shape, (1,)) 251*da0073e9SAndroid Build Coastguard Worker # gh-83069, make sure __dlpack__ normalizes strides 252*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z.stride(), (1,)) 253*da0073e9SAndroid Build Coastguard Worker 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestTorchDlPack, globals()) 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 258*da0073e9SAndroid Build Coastguard Worker run_tests() 259