# Owner(s): ["module: tests"]

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCUDA,
    onlyNativeDeviceTypes,
    skipCUDAIfRocm,
    skipMeta,
)
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.testing._internal.common_utils import IS_JETSON, run_tests, TestCase
from torch.utils.dlpack import from_dlpack, to_dlpack


class TestTorchDlPack(TestCase):
    exact_dtype = True

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
        *all_types_and_complex_and(
            torch.half,
            torch.bfloat16,
            torch.bool,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        )
    )
    def test_dlpack_capsule_conversion(self, device, dtype):
        x = make_tensor((5,), dtype=dtype, device=device)
        z = from_dlpack(to_dlpack(x))
        self.assertEqual(z, x)

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
        *all_types_and_complex_and(
            torch.half,
            torch.bfloat16,
            torch.bool,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        )
    )
    def test_dlpack_protocol_conversion(self, device, dtype):
        x = make_tensor((5,), dtype=dtype, device=device)
        z = from_dlpack(x)
        self.assertEqual(z, x)

    @skipMeta
    @onlyNativeDeviceTypes
    def test_dlpack_shared_storage(self, device):
        x = make_tensor((5,), dtype=torch.float64, device=device)
        z = from_dlpack(to_dlpack(x))
        z[0] = z[0] + 20.0
        self.assertEqual(z, x)

    @skipMeta
    @onlyCUDA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_dlpack_conversion_with_streams(self, device, dtype):
        # Create a stream where the tensor will reside
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            # Do an operation in the actual stream
            x = make_tensor((5,), dtype=dtype, device=device) + 1
        # DLPack protocol helps establish a correct stream order
        # (hence data dependency) at the exchange boundary.
        # DLPack manages this synchronization for us, so we don't need to
        # explicitly wait until x is populated
        if IS_JETSON:
            # DLPack protocol that establishes correct stream order
            # does not behave as expected on Jetson
            stream.synchronize()
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            z = from_dlpack(x)
        stream.synchronize()
        self.assertEqual(z, x)

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
        *all_types_and_complex_and(
            torch.half,
            torch.bfloat16,
            torch.bool,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        )
    )
    def test_from_dlpack(self, device, dtype):
        x = make_tensor((5,), dtype=dtype, device=device)
        y = torch.from_dlpack(x)
        self.assertEqual(x, y)

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
        *all_types_and_complex_and(
            torch.half,
            torch.bfloat16,
            torch.bool,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        )
    )
    def test_from_dlpack_noncontiguous(self, device, dtype):
        x = make_tensor((25,), dtype=dtype, device=device).reshape(5, 5)

        y1 = x[0]
        y1_dl = torch.from_dlpack(y1)
        self.assertEqual(y1, y1_dl)

        y2 = x[:, 0]
        y2_dl = torch.from_dlpack(y2)
        self.assertEqual(y2, y2_dl)

        y3 = x[1, :]
        y3_dl = torch.from_dlpack(y3)
        self.assertEqual(y3, y3_dl)

        y4 = x[1]
        y4_dl = torch.from_dlpack(y4)
        self.assertEqual(y4, y4_dl)

        y5 = x.t()
        y5_dl = torch.from_dlpack(y5)
        self.assertEqual(y5, y5_dl)

    @skipMeta
    @onlyCUDA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_dlpack_conversion_with_diff_streams(self, device, dtype):
        stream_a = torch.cuda.Stream()
        stream_b = torch.cuda.Stream()
        # DLPack protocol helps establish a correct stream order
        # (hence data dependency) at the exchange boundary.
        # The `tensor.__dlpack__` method will insert a synchronization event
        # in the current stream to make sure that it was correctly populated.
        with torch.cuda.stream(stream_a):
            x = make_tensor((5,), dtype=dtype, device=device) + 1
            z = torch.from_dlpack(x.__dlpack__(stream_b.cuda_stream))
            stream_a.synchronize()
        stream_b.synchronize()
        self.assertEqual(z, x)

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
        *all_types_and_complex_and(
            torch.half,
            torch.bfloat16,
            torch.bool,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        )
    )
    def test_from_dlpack_dtype(self, device, dtype):
        x = make_tensor((5,), dtype=dtype, device=device)
        y = torch.from_dlpack(x)
        assert x.dtype == y.dtype

    @skipMeta
    @onlyCUDA
    def test_dlpack_default_stream(self, device):
        class DLPackTensor:
            def __init__(self, tensor):
                self.tensor = tensor

            def __dlpack_device__(self):
                return self.tensor.__dlpack_device__()

            def __dlpack__(self, stream=None):
                # from_dlpack passes the default stream as 1 on CUDA
                # and as 0 on ROCm/HIP
                if torch.version.hip is None:
                    assert stream == 1
                else:
                    assert stream == 0
                capsule = self.tensor.__dlpack__(stream)
                return capsule

        # CUDA-based tests run on non-default streams, so explicitly
        # switch to the default stream for this check
        with torch.cuda.stream(torch.cuda.default_stream()):
            x = DLPackTensor(make_tensor((5,), dtype=torch.float32, device=device))
            from_dlpack(x)

    @skipMeta
    @onlyCUDA
    @skipCUDAIfRocm
    def test_dlpack_convert_default_stream(self, device):
        # tests run on a non-default stream, so make sure the default
        # stream is idle before starting
        torch.cuda.default_stream().synchronize()
        # run the _sleep call on a side stream; exporting with __dlpack__(1)
        # (the default stream) inserts a sync that makes the default stream
        # wait on the side stream
        side_stream = torch.cuda.Stream()
        with torch.cuda.stream(side_stream):
            x = torch.zeros(1, device=device)
            torch.cuda._sleep(2**20)
            self.assertTrue(torch.cuda.default_stream().query())
            d = x.__dlpack__(1)
        # check that the default stream has work (a pending cudaStreamWaitEvent)
        self.assertFalse(torch.cuda.default_stream().query())

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_dlpack_tensor_invalid_stream(self, device, dtype):
        with self.assertRaises(TypeError):
            x = make_tensor((5,), dtype=dtype, device=device)
            x.__dlpack__(stream=object())

    # TODO: add interchange tests once NumPy 1.22 (dlpack support) is required
    @skipMeta
    def test_dlpack_export_requires_grad(self):
        x = torch.zeros(10, dtype=torch.float32, requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, r"require gradient"):
            x.__dlpack__()

    @skipMeta
    def test_dlpack_export_is_conj(self):
        x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
        y = torch.conj(x)
        with self.assertRaisesRegex(RuntimeError, r"conjugate bit"):
            y.__dlpack__()

    @skipMeta
    def test_dlpack_export_non_strided(self):
        x = torch.sparse_coo_tensor([[0]], [1], size=(1,))
        y = torch.conj(x)
        with self.assertRaisesRegex(RuntimeError, r"strided"):
            y.__dlpack__()

    @skipMeta
    def test_dlpack_normalize_strides(self):
        x = torch.rand(16)
        y = x[::3][:1]
        self.assertEqual(y.shape, (1,))
        self.assertEqual(y.stride(), (3,))
        z = from_dlpack(y)
        self.assertEqual(z.shape, (1,))
        # gh-83069, make sure __dlpack__ normalizes strides
        self.assertEqual(z.stride(), (1,))


instantiate_device_type_tests(TestTorchDlPack, globals())

if __name__ == "__main__":
    run_tests()