# xref: /aosp_15_r20/external/pytorch/test/test_dlpack.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: tests"]
2
3import torch
4from torch.testing import make_tensor
5from torch.testing._internal.common_device_type import (
6    dtypes,
7    instantiate_device_type_tests,
8    onlyCUDA,
9    onlyNativeDeviceTypes,
10    skipCUDAIfRocm,
11    skipMeta,
12)
13from torch.testing._internal.common_dtype import all_types_and_complex_and
14from torch.testing._internal.common_utils import IS_JETSON, run_tests, TestCase
15from torch.utils.dlpack import from_dlpack, to_dlpack
16
17
18class TestTorchDlPack(TestCase):
19    exact_dtype = True
20
21    @skipMeta
22    @onlyNativeDeviceTypes
23    @dtypes(
24        *all_types_and_complex_and(
25            torch.half,
26            torch.bfloat16,
27            torch.bool,
28            torch.uint16,
29            torch.uint32,
30            torch.uint64,
31        )
32    )
33    def test_dlpack_capsule_conversion(self, device, dtype):
34        x = make_tensor((5,), dtype=dtype, device=device)
35        z = from_dlpack(to_dlpack(x))
36        self.assertEqual(z, x)
37
38    @skipMeta
39    @onlyNativeDeviceTypes
40    @dtypes(
41        *all_types_and_complex_and(
42            torch.half,
43            torch.bfloat16,
44            torch.bool,
45            torch.uint16,
46            torch.uint32,
47            torch.uint64,
48        )
49    )
50    def test_dlpack_protocol_conversion(self, device, dtype):
51        x = make_tensor((5,), dtype=dtype, device=device)
52        z = from_dlpack(x)
53        self.assertEqual(z, x)
54
55    @skipMeta
56    @onlyNativeDeviceTypes
57    def test_dlpack_shared_storage(self, device):
58        x = make_tensor((5,), dtype=torch.float64, device=device)
59        z = from_dlpack(to_dlpack(x))
60        z[0] = z[0] + 20.0
61        self.assertEqual(z, x)
62
63    @skipMeta
64    @onlyCUDA
65    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
66    def test_dlpack_conversion_with_streams(self, device, dtype):
67        # Create a stream where the tensor will reside
68        stream = torch.cuda.Stream()
69        with torch.cuda.stream(stream):
70            # Do an operation in the actual stream
71            x = make_tensor((5,), dtype=dtype, device=device) + 1
72        # DLPack protocol helps establish a correct stream order
73        # (hence data dependency) at the exchange boundary.
74        # DLPack manages this synchronization for us, so we don't need to
75        # explicitly wait until x is populated
76        if IS_JETSON:
77            # DLPack protocol that establishes correct stream order
78            # does not behave as expected on Jetson
79            stream.synchronize()
80        stream = torch.cuda.Stream()
81        with torch.cuda.stream(stream):
82            z = from_dlpack(x)
83        stream.synchronize()
84        self.assertEqual(z, x)
85
86    @skipMeta
87    @onlyNativeDeviceTypes
88    @dtypes(
89        *all_types_and_complex_and(
90            torch.half,
91            torch.bfloat16,
92            torch.bool,
93            torch.uint16,
94            torch.uint32,
95            torch.uint64,
96        )
97    )
98    def test_from_dlpack(self, device, dtype):
99        x = make_tensor((5,), dtype=dtype, device=device)
100        y = torch.from_dlpack(x)
101        self.assertEqual(x, y)
102
103    @skipMeta
104    @onlyNativeDeviceTypes
105    @dtypes(
106        *all_types_and_complex_and(
107            torch.half,
108            torch.bfloat16,
109            torch.bool,
110            torch.uint16,
111            torch.uint32,
112            torch.uint64,
113        )
114    )
115    def test_from_dlpack_noncontinguous(self, device, dtype):
116        x = make_tensor((25,), dtype=dtype, device=device).reshape(5, 5)
117
118        y1 = x[0]
119        y1_dl = torch.from_dlpack(y1)
120        self.assertEqual(y1, y1_dl)
121
122        y2 = x[:, 0]
123        y2_dl = torch.from_dlpack(y2)
124        self.assertEqual(y2, y2_dl)
125
126        y3 = x[1, :]
127        y3_dl = torch.from_dlpack(y3)
128        self.assertEqual(y3, y3_dl)
129
130        y4 = x[1]
131        y4_dl = torch.from_dlpack(y4)
132        self.assertEqual(y4, y4_dl)
133
134        y5 = x.t()
135        y5_dl = torch.from_dlpack(y5)
136        self.assertEqual(y5, y5_dl)
137
138    @skipMeta
139    @onlyCUDA
140    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
141    def test_dlpack_conversion_with_diff_streams(self, device, dtype):
142        stream_a = torch.cuda.Stream()
143        stream_b = torch.cuda.Stream()
144        # DLPack protocol helps establish a correct stream order
145        # (hence data dependency) at the exchange boundary.
146        # the `tensor.__dlpack__` method will insert a synchronization event
147        # in the current stream to make sure that it was correctly populated.
148        with torch.cuda.stream(stream_a):
149            x = make_tensor((5,), dtype=dtype, device=device) + 1
150            z = torch.from_dlpack(x.__dlpack__(stream_b.cuda_stream))
151            stream_a.synchronize()
152        stream_b.synchronize()
153        self.assertEqual(z, x)
154
155    @skipMeta
156    @onlyNativeDeviceTypes
157    @dtypes(
158        *all_types_and_complex_and(
159            torch.half,
160            torch.bfloat16,
161            torch.bool,
162            torch.uint16,
163            torch.uint32,
164            torch.uint64,
165        )
166    )
167    def test_from_dlpack_dtype(self, device, dtype):
168        x = make_tensor((5,), dtype=dtype, device=device)
169        y = torch.from_dlpack(x)
170        assert x.dtype == y.dtype
171
172    @skipMeta
173    @onlyCUDA
174    def test_dlpack_default_stream(self, device):
175        class DLPackTensor:
176            def __init__(self, tensor):
177                self.tensor = tensor
178
179            def __dlpack_device__(self):
180                return self.tensor.__dlpack_device__()
181
182            def __dlpack__(self, stream=None):
183                if torch.version.hip is None:
184                    assert stream == 1
185                else:
186                    assert stream == 0
187                capsule = self.tensor.__dlpack__(stream)
188                return capsule
189
190        # CUDA-based tests runs on non-default streams
191        with torch.cuda.stream(torch.cuda.default_stream()):
192            x = DLPackTensor(make_tensor((5,), dtype=torch.float32, device=device))
193            from_dlpack(x)
194
195    @skipMeta
196    @onlyCUDA
197    @skipCUDAIfRocm
198    def test_dlpack_convert_default_stream(self, device):
199        # tests run on non-default stream, so _sleep call
200        # below will run on a non-default stream, causing
201        # default stream to wait due to inserted syncs
202        torch.cuda.default_stream().synchronize()
203        # run _sleep call on a non-default stream, causing
204        # default stream to wait due to inserted syncs
205        side_stream = torch.cuda.Stream()
206        with torch.cuda.stream(side_stream):
207            x = torch.zeros(1, device=device)
208            torch.cuda._sleep(2**20)
209            self.assertTrue(torch.cuda.default_stream().query())
210            d = x.__dlpack__(1)
211        # check that the default stream has work (a pending cudaStreamWaitEvent)
212        self.assertFalse(torch.cuda.default_stream().query())
213
214    @skipMeta
215    @onlyNativeDeviceTypes
216    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
217    def test_dlpack_tensor_invalid_stream(self, device, dtype):
218        with self.assertRaises(TypeError):
219            x = make_tensor((5,), dtype=dtype, device=device)
220            x.__dlpack__(stream=object())
221
222    # TODO: add interchange tests once NumPy 1.22 (dlpack support) is required
223    @skipMeta
224    def test_dlpack_export_requires_grad(self):
225        x = torch.zeros(10, dtype=torch.float32, requires_grad=True)
226        with self.assertRaisesRegex(RuntimeError, r"require gradient"):
227            x.__dlpack__()
228
229    @skipMeta
230    def test_dlpack_export_is_conj(self):
231        x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
232        y = torch.conj(x)
233        with self.assertRaisesRegex(RuntimeError, r"conjugate bit"):
234            y.__dlpack__()
235
236    @skipMeta
237    def test_dlpack_export_non_strided(self):
238        x = torch.sparse_coo_tensor([[0]], [1], size=(1,))
239        y = torch.conj(x)
240        with self.assertRaisesRegex(RuntimeError, r"strided"):
241            y.__dlpack__()
242
243    @skipMeta
244    def test_dlpack_normalize_strides(self):
245        x = torch.rand(16)
246        y = x[::3][:1]
247        self.assertEqual(y.shape, (1,))
248        self.assertEqual(y.stride(), (3,))
249        z = from_dlpack(y)
250        self.assertEqual(z.shape, (1,))
251        # gh-83069, make sure __dlpack__ normalizes strides
252        self.assertEqual(z.stride(), (1,))
253
254
# Hand the generic TestTorchDlPack template to PyTorch's device-type test
# machinery together with this module's globals.
# NOTE(review): presumably this is what publishes the per-device variants of
# the template for test discovery; confirm against common_device_type.
instantiate_device_type_tests(TestTorchDlPack, globals())

# Run the suite through PyTorch's test runner when executed as a script.
if __name__ == "__main__":
    run_tests()
259