# Owner(s): ["module: tests"]

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCUDA,
    onlyNativeDeviceTypes,
    skipCUDAIfRocm,
    skipMeta,
)
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.testing._internal.common_utils import IS_JETSON, run_tests, TestCase
from torch.utils.dlpack import from_dlpack, to_dlpack


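# DLPack is an open in-memory tensor specification used for zero-copy data
# exchange across frameworks. These tests exercise both exchange paths:
#   * the explicit capsule API (torch.utils.dlpack.to_dlpack / from_dlpack), and
#   * the __dlpack__ / __dlpack_device__ protocol consumed by torch.from_dlpack,
# including stream-ordering behavior on CUDA and the error cases where a tensor
# cannot be exported (requires_grad, conjugate bit, non-strided layouts).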
class TestTorchDlPack(TestCase):
    exact_dtype = True

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
        *all_types_and_complex_and(
            torch.half,
            torch.bfloat16,
            torch.bool,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        )
    )
    def test_dlpack_capsule_conversion(self, device, dtype):
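        # to_dlpack wraps the tensor in a DLPack capsule; from_dlpack consumes
        # the capsule and returns a tensor viewing the same memory, so the
        # round trip must preserve values, dtype, and device exactly.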
        x = make_tensor((5,), dtype=dtype, device=device)
        z = from_dlpack(to_dlpack(x))
        self.assertEqual(z, x)

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
        *all_types_and_complex_and(
            torch.half,
            torch.bfloat16,
            torch.bool,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        )
    )
    def test_dlpack_protocol_conversion(self, device, dtype):
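        # from_dlpack also accepts any object implementing __dlpack__ and
        # __dlpack_device__, so the tensor can be passed directly without an
        # explicit capsule.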
        x = make_tensor((5,), dtype=dtype, device=device)
        z = from_dlpack(x)
        self.assertEqual(z, x)

    @skipMeta
    @onlyNativeDeviceTypes
    def test_dlpack_shared_storage(self, device):
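        # The DLPack round trip must not copy: x and z alias the same storage,
        # so an in-place update through z is visible through x as well.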
        x = make_tensor((5,), dtype=torch.float64, device=device)
        z = from_dlpack(to_dlpack(x))
        z[0] = z[0] + 20.0
        self.assertEqual(z, x)

    @skipMeta
    @onlyCUDA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_dlpack_conversion_with_streams(self, device, dtype):
        # Create a stream where the tensor will reside
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            # Do an operation on that stream
            x = make_tensor((5,), dtype=dtype, device=device) + 1
        # The DLPack protocol establishes the correct stream order (and hence
        # the data dependency) at the exchange boundary, so we do not need to
        # wait explicitly until x is populated.
        if IS_JETSON:
            # The stream-order synchronization inserted by the DLPack protocol
            # does not behave as expected on Jetson, so synchronize manually.
            stream.synchronize()
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            z = from_dlpack(x)
        stream.synchronize()
        self.assertEqual(z, x)
    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
        *all_types_and_complex_and(
            torch.half,
            torch.bfloat16,
            torch.bool,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        )
    )
    def test_from_dlpack(self, device, dtype):
        x = make_tensor((5,), dtype=dtype, device=device)
        y = torch.from_dlpack(x)
        self.assertEqual(x, y)

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
        *all_types_and_complex_and(
            torch.half,
            torch.bfloat16,
            torch.bool,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        )
    )
    def test_from_dlpack_noncontiguous(self, device, dtype):
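        # DLPack carries shape and strides, so non-contiguous views (rows,
        # columns, transposes) should round-trip through torch.from_dlpack
        # without being copied or made contiguous.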
        x = make_tensor((25,), dtype=dtype, device=device).reshape(5, 5)

        y1 = x[0]
        y1_dl = torch.from_dlpack(y1)
        self.assertEqual(y1, y1_dl)

        y2 = x[:, 0]
        y2_dl = torch.from_dlpack(y2)
        self.assertEqual(y2, y2_dl)

        y3 = x[1, :]
        y3_dl = torch.from_dlpack(y3)
        self.assertEqual(y3, y3_dl)

        y4 = x[1]
        y4_dl = torch.from_dlpack(y4)
        self.assertEqual(y4, y4_dl)

        y5 = x.t()
        y5_dl = torch.from_dlpack(y5)
        self.assertEqual(y5, y5_dl)
    @skipMeta
    @onlyCUDA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_dlpack_conversion_with_diff_streams(self, device, dtype):
        stream_a = torch.cuda.Stream()
        stream_b = torch.cuda.Stream()
        # The DLPack protocol establishes the correct stream order (and hence
        # the data dependency) at the exchange boundary: `x.__dlpack__(stream)`
        # inserts a synchronization event into the producer's current stream so
        # that the consumer stream sees a fully populated tensor.
        with torch.cuda.stream(stream_a):
            x = make_tensor((5,), dtype=dtype, device=device) + 1
            z = torch.from_dlpack(x.__dlpack__(stream_b.cuda_stream))
            stream_a.synchronize()
        stream_b.synchronize()
        self.assertEqual(z, x)
    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(
        *all_types_and_complex_and(
            torch.half,
            torch.bfloat16,
            torch.bool,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        )
    )
    def test_from_dlpack_dtype(self, device, dtype):
        x = make_tensor((5,), dtype=dtype, device=device)
        y = torch.from_dlpack(x)
        self.assertEqual(x.dtype, y.dtype)
    @skipMeta
    @onlyCUDA
    def test_dlpack_default_stream(self, device):
        class DLPackTensor:
            def __init__(self, tensor):
                self.tensor = tensor

            def __dlpack_device__(self):
                return self.tensor.__dlpack_device__()

            def __dlpack__(self, stream=None):
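                # torch.from_dlpack passes a default-stream sentinel when the
                # consumer is running on the default stream: 1 (the legacy
                # default stream) on CUDA, 0 on ROCm/HIP.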
                if torch.version.hip is None:
                    assert stream == 1
                else:
                    assert stream == 0
                capsule = self.tensor.__dlpack__(stream)
                return capsule

        # CUDA device-type tests run on a non-default stream by default; switch
        # to the default stream so that from_dlpack hands __dlpack__ the
        # default-stream value asserted above.
        with torch.cuda.stream(torch.cuda.default_stream()):
            x = DLPackTensor(make_tensor((5,), dtype=torch.float32, device=device))
            from_dlpack(x)

    @skipMeta
    @onlyCUDA
    @skipCUDAIfRocm
    def test_dlpack_convert_default_stream(self, device):
        # Start with an idle default stream so the final query() check is
        # meaningful.
        torch.cuda.default_stream().synchronize()
        # Run _sleep on a side stream; exporting x with __dlpack__(1) (the
        # default stream) must make the default stream wait on the side stream
        # via an inserted synchronization event.
        side_stream = torch.cuda.Stream()
        with torch.cuda.stream(side_stream):
            x = torch.zeros(1, device=device)
            torch.cuda._sleep(2**20)
            self.assertTrue(torch.cuda.default_stream().query())
            d = x.__dlpack__(1)
        # check that the default stream has work (a pending cudaStreamWaitEvent)
        self.assertFalse(torch.cuda.default_stream().query())
    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_dlpack_tensor_invalid_stream(self, device, dtype):
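        # The __dlpack__ stream argument must be an integer (or None); any
        # other object is rejected with a TypeError.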
        with self.assertRaises(TypeError):
            x = make_tensor((5,), dtype=dtype, device=device)
            x.__dlpack__(stream=object())

    # TODO: add interchange tests once NumPy 1.22 (dlpack support) is required
    @skipMeta
    def test_dlpack_export_requires_grad(self):
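        # Autograd metadata cannot be represented in DLPack, so exporting a
        # tensor that requires grad is refused; callers are expected to
        # detach() first.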
        x = torch.zeros(10, dtype=torch.float32, requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, r"require gradient"):
            x.__dlpack__()

    @skipMeta
    def test_dlpack_export_is_conj(self):
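        # The lazy conjugate bit has no DLPack representation, so exporting a
        # conjugated view must fail rather than silently drop the conjugation.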
        x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
        y = torch.conj(x)
        with self.assertRaisesRegex(RuntimeError, r"conjugate bit"):
            y.__dlpack__()

    @skipMeta
    def test_dlpack_export_non_strided(self):
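        # Only strided (dense) tensors can be exported; sparse layouts have no
        # DLPack representation.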
        x = torch.sparse_coo_tensor([[0]], [1], size=(1,))
        y = torch.conj(x)
        with self.assertRaisesRegex(RuntimeError, r"strided"):
            y.__dlpack__()

    @skipMeta
    def test_dlpack_normalize_strides(self):
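        # A size-1 dimension can legally carry any stride (here 3); __dlpack__
        # is expected to normalize such strides to 1 so consumers see a
        # canonical layout (gh-83069, see comment below).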
        x = torch.rand(16)
        y = x[::3][:1]
        self.assertEqual(y.shape, (1,))
        self.assertEqual(y.stride(), (3,))
        z = from_dlpack(y)
        self.assertEqual(z.shape, (1,))
        # gh-83069, make sure __dlpack__ normalizes strides
        self.assertEqual(z.stride(), (1,))


instantiate_device_type_tests(TestTorchDlPack, globals())

if __name__ == "__main__":
    run_tests()