# Owner(s): ["module: intel"]

import collections
import sys
import tempfile
import unittest
import unittest.mock  # MagicMock is used by TestXpuTrace below

import torch
import torch.xpu._gpu_trace as gpu_trace
from torch.testing._internal.autocast_test_lists import AutocastTestLists
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyXPU,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import ops_and_refs
from torch.testing._internal.common_utils import (
    NoTest,
    run_tests,
    suppress_warnings,
    TEST_WITH_UBSAN,
    TEST_XPU,
    TestCase,
)

if not TEST_XPU:
    print("XPU not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811

TEST_MULTIXPU = torch.xpu.device_count() > 1

cpu_device = torch.device("cpu")
xpu_device = torch.device("xpu")

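# No XPU-specific OpDTypes entry is used here, so reuse the CPU/CUDA pairing;
# it selects one dtype common to both CPU and the device under test.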
any_common_cpu_xpu_one = OpDTypes.any_common_cpu_cuda_one
_xpu_computation_op_list = [
    "fill",
    "zeros",
    "zeros_like",
    "clone",
    "view_as_real",
    "view_as_complex",
    "view",
    "resize_",
    "resize_as_",
    "add",
    "sub",
    "mul",
    "div",
    "abs",
]
_xpu_tensor_factory_op_list = [
    "as_strided",
    "empty",
    "empty_strided",
]
_xpu_not_test_dtype_op_list = [
    "resize_",  # Skipped by CPU
    "resize_as_",  # Skipped by CPU
    "abs",  # dtype coverage not aligned with CPU
]
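# Resolve the op names above into OpInfo entries from the shared op database.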
_xpu_all_op_list = _xpu_computation_op_list + _xpu_tensor_factory_op_list
_xpu_all_ops = [op for op in ops_and_refs if op.name in _xpu_all_op_list]
_xpu_computation_ops = [
    op for op in ops_and_refs if op.name in _xpu_computation_op_list
]


class TestXpu(TestCase):
    def test_device_behavior(self):
        current_device = torch.xpu.current_device()
        torch.xpu.set_device(current_device)
        self.assertEqual(current_device, torch.xpu.current_device())

    @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
    def test_multi_device_behavior(self):
        current_device = torch.xpu.current_device()
        target_device = (current_device + 1) % torch.xpu.device_count()

        with torch.xpu.device(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

        with torch.xpu._DeviceGuard(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

    def test_get_device_properties(self):
        current_device = torch.xpu.current_device()
        device_properties = torch.xpu.get_device_properties(current_device)
        self.assertEqual(device_properties, torch.xpu.get_device_properties(None))
        self.assertEqual(device_properties, torch.xpu.get_device_properties())

        device_name = torch.xpu.get_device_name(current_device)
        self.assertEqual(device_name, torch.xpu.get_device_name(None))
        self.assertEqual(device_name, torch.xpu.get_device_name())

        device_capability = torch.xpu.get_device_capability(current_device)
        self.assertTrue(device_capability["max_work_group_size"] > 0)
        self.assertTrue(device_capability["max_num_sub_groups"] > 0)
        self.assertEqual(
            device_properties.driver_version, device_capability["driver_version"]
        )
        self.assertEqual(device_properties.has_fp16, device_capability["has_fp16"])
        self.assertEqual(device_properties.has_fp64, device_capability["has_fp64"])
        self.assertEqual(
            device_properties.has_atomic64, device_capability["has_atomic64"]
        )

    def test_wrong_xpu_fork(self):
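        # XPU cannot be re-initialized in a forked child once the parent has
        # initialized it; the spawned script below should surface that error.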
        stderr = TestCase.runWithPytorchAPIUsageStderr(
            """\
import torch
from torch.multiprocessing import Process
def run(rank):
    torch.xpu.set_device(rank)
if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        # it would work fine without the line below
        torch.xpu.set_device(0)
        p = Process(target=run, args=(rank,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
"""
        )
        self.assertRegex(stderr, "Cannot re-initialize XPU in forked subprocess.")

    def test_streams(self):
        s0 = torch.xpu.Stream()
        torch.xpu.set_stream(s0)
        s1 = torch.xpu.current_stream()
        self.assertEqual(s0, s1)
        s2 = torch.xpu.Stream()
        self.assertFalse(s0 == s2)
        torch.xpu.set_stream(s2)
        with torch.xpu.stream(s0):
            self.assertEqual(s0, torch.xpu.current_stream())
        self.assertEqual(s2, torch.xpu.current_stream())

    def test_stream_priority(self):
        low, high = torch.xpu.Stream.priority_range()
        s0 = torch.xpu.Stream(device=0, priority=low)

        self.assertEqual(low, s0.priority)
        self.assertEqual(torch.device("xpu:0"), s0.device)

        s1 = torch.xpu.Stream(device=0, priority=high)

        self.assertEqual(high, s1.priority)
        self.assertEqual(torch.device("xpu:0"), s1.device)

    def test_stream_event_repr(self):
        s = torch.xpu.current_stream()
        self.assertTrue("torch.xpu.Stream" in str(s))
        e = torch.xpu.Event()
        self.assertTrue("torch.xpu.Event(uninitialized)" in str(e))
        s.record_event(e)
        self.assertTrue("torch.xpu.Event" in str(e))

    def test_events(self):
        stream = torch.xpu.current_stream()
        event = torch.xpu.Event()
        self.assertTrue(event.query())
        stream.record_event(event)
        event.synchronize()
        self.assertTrue(event.query())

    def test_generic_stream_event(self):
        stream = torch.Stream("xpu")
        self.assertEqual(stream.device_index, torch.xpu.current_device())
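        # Rebuild a torch.xpu.Stream from the generic stream's identifiers;
        # both handles should refer to the same underlying stream.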
        xpu_stream = torch.xpu.Stream(
            stream_id=stream.stream_id,
            device_index=stream.device_index,
            device_type=stream.device_type,
        )
        self.assertEqual(stream.stream_id, xpu_stream.stream_id)
        self.assertNotEqual(stream.stream_id, torch.xpu.current_stream().stream_id)

        event1 = torch.Event("xpu")
        event2 = torch.Event("xpu")
        self.assertEqual(event1.event_id, 0)
        a = torch.randn(1000)
        b = torch.randn(1000)
        with torch.xpu.stream(xpu_stream):
            a_xpu = a.to("xpu", non_blocking=True)
            b_xpu = b.to("xpu", non_blocking=True)
            self.assertEqual(stream.stream_id, torch.xpu.current_stream().stream_id)
        event1.record(stream)
        event1.synchronize()
        self.assertTrue(event1.query())
        c_xpu = a_xpu + b_xpu
        event2.record()
        event2.synchronize()
        self.assertTrue(event2.query())
        self.assertNotEqual(event1.event_id, event2.event_id)
        self.assertEqual(c_xpu.cpu(), a + b)
        with self.assertRaisesRegex(
            NotImplementedError, "elapsedTime is not supported by XPU backend."
        ):
            event1.elapsed_time(event2)

    def test_generator(self):
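        # torch.manual_seed seeds every device's generator, XPU included, so a
        # different seed should produce a different XPU RNG state.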
        torch.manual_seed(2024)
        g_state0 = torch.xpu.get_rng_state()
        torch.manual_seed(1234)
        g_state1 = torch.xpu.get_rng_state()
        self.assertNotEqual(g_state0, g_state1)

        torch.xpu.manual_seed(2024)
        g_state2 = torch.xpu.get_rng_state()
        self.assertEqual(g_state0, g_state2)

        torch.xpu.set_rng_state(g_state1)
        self.assertEqual(g_state1, torch.xpu.get_rng_state())

        torch.manual_seed(1234)
        torch.xpu.set_rng_state(g_state0)
        self.assertEqual(2024, torch.xpu.initial_seed())

    @onlyXPU
    @suppress_warnings
    @ops(_xpu_computation_ops, dtypes=any_common_cpu_xpu_one)
    def test_compare_cpu(self, device, dtype, op):
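        # Run each reference sample on the XPU device and again on CPU, then
        # compare the two sets of results.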
        def to_cpu(arg):
            if isinstance(arg, torch.Tensor):
                return arg.to(device="cpu")
            return arg

        samples = op.reference_inputs(device, dtype)

        for sample in samples:
            cpu_sample = sample.transform(to_cpu)
            xpu_results = op(sample.input, *sample.args, **sample.kwargs)
            cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)

            xpu_results = sample.output_process_fn_grad(xpu_results)
            cpu_results = cpu_sample.output_process_fn_grad(cpu_results)

            # Use a looser tolerance so these periodic comparisons do not fail
            # frequently on small numerical differences between CPU and XPU.
            self.assertEqual(xpu_results, cpu_results, atol=1e-4, rtol=1e-4)

    @onlyXPU
    @ops(_xpu_computation_ops, allowed_dtypes=(torch.bool,))
    @unittest.skipIf(TEST_WITH_UBSAN, "Test uses undefined behavior")
    def test_non_standard_bool_values(self, device, dtype, op):
        # Test boolean values other than 0x00 and 0x01 (gh-54789)
        def convert_boolean_tensors(x):
            if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
                return x

            # Map False -> 0 and True -> a random value in [2, 255)
            true_vals = torch.randint(
                2, 255, x.shape, dtype=torch.uint8, device=x.device
            )
            false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
            x_int = torch.where(x, true_vals, false_vals)

            ret = x_int.view(torch.bool)
            self.assertEqual(ret, x)
            return ret

        for sample in op.sample_inputs(device, dtype):
            expect = op(sample.input, *sample.args, **sample.kwargs)

            transformed = sample.transform(convert_boolean_tensors)
            actual = op(transformed.input, *transformed.args, **transformed.kwargs)

            self.assertEqual(expect, actual)

    def test_serialization_array_with_storage(self):
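        # Round-trip XPU tensors and a storage through torch.save/torch.load,
        # then check that values, dtypes, and storage aliasing are preserved.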
        x = torch.randn(5, 5).xpu()
        y = torch.zeros(2, 5, dtype=torch.int, device="xpu")
        q = [x, y, x, y.storage()]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(q, f)
            f.seek(0)
            q_copy = torch.load(f)
        self.assertEqual(q_copy, q, atol=0, rtol=0)
        q_copy[0].fill_(5)
        self.assertEqual(q_copy[0], q_copy[2], atol=0, rtol=0)
        self.assertEqual(q_copy[0].dtype, torch.float)
        self.assertEqual(q_copy[1].dtype, torch.int)
        self.assertEqual(q_copy[2].dtype, torch.float)
        self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
        self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage))
        q_copy[1].fill_(10)
        y.fill_(10)
        self.assertEqual(q_copy[3], y.storage())

    def test_serialization_array_with_empty(self):
        x = [
            torch.randn(4, 4).xpu(),
            torch.tensor([], dtype=torch.float, device=torch.device("xpu")),
        ]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(x, f)
            f.seek(0)
            x_copy = torch.load(f)
        for original, copy in zip(x, x_copy):
            self.assertEqual(copy, original)
            self.assertIs(type(copy), type(original))
            self.assertEqual(copy.get_device(), original.get_device())


instantiate_device_type_tests(TestXpu, globals(), only_for="xpu")


class TestXpuAutocast(TestCase):
    # These operators are not implemented on the XPU backend and cannot fall
    # back to CPU, so we have to skip them for now.
    # TODO: remove these operators from the skip list once they are implemented on the XPU backend.
    skip_list = ["gru_cell"]

    def setUp(self):
        super().setUp()
        self.autocast_lists = AutocastTestLists(torch.device("xpu"))

    def tearDown(self):
        del self.autocast_lists
        super().tearDown()

    def _run_autocast_outofplace(
        self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None
    ):
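        # Run `op` under autocast through the torch.* and Tensor.* variants,
        # check the output dtype, then compare against a control computed with
        # autocast disabled on explicitly cast inputs.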
        # helper to cast args
        def cast(val, to_type):
            if isinstance(val, torch.Tensor):
                return val.to(to_type) if val.is_floating_point() else val
            elif isinstance(val, collections.abc.Iterable):
                return type(val)(cast(v, to_type) for v in val)
            else:
                return val

        if add_kwargs is None:
            add_kwargs = {}
        fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16
        self.assertFalse(torch.is_autocast_enabled("xpu"))
        with torch.amp.autocast("xpu", dtype=fast_dtype):
            self.assertTrue(torch.is_autocast_enabled("xpu"))

            out_type = out_type if out_type is not None else run_as_type
            output = output_method = None

            # Try module.* variant, if requested:
            if module is not None and hasattr(module, op):
                output = getattr(module, op)(*args, **add_kwargs)
                if isinstance(output, torch.Tensor):
                    self.assertTrue(
                        out_type == output.dtype,
                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
                    )

            # Try Tensor.* variant:
            if hasattr(torch.Tensor, op):
                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
                if isinstance(output_method, torch.Tensor):
                    self.assertTrue(
                        out_type == output_method.dtype,
                        f"autocast for torch.{op} produced {output_method.dtype}, should produce {out_type}",
                    )

            self.assertTrue(
                (output is not None) or (output_method is not None),
                f"{op} not found as an attribute on either Tensor or the requested module {module}",
            )

            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
            # For example, lstm_cell returns a tuple and equal returns bool.
            def compare(first, second):
                if isinstance(first, torch.Tensor):
                    return torch.equal(first, second)
                elif isinstance(first, collections.abc.Iterable):
                    return all(compare(f, s) for f, s in zip(first, second))
                else:
                    return first == second

            # If both torch.* and Tensor.* variants were found, check outputs are identical
            if (output is not None) and (output_method is not None):
                self.assertTrue(type(output) == type(output_method))
                comparison = compare(output, output_method)
                self.assertTrue(
                    comparison, f"torch.{op} result did not match Tensor.{op} result"
                )

            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
            # as the C++-side autocasting, and should be bitwise accurate.
            output_to_compare = output if output is not None else output_method
            with torch.amp.autocast("xpu", enabled=False):
                self.assertFalse(torch.is_autocast_enabled("xpu"))

                if module is not None and hasattr(module, op):
                    control = getattr(module, op)(
                        *cast(args, run_as_type), **add_kwargs
                    )
                else:
                    control = getattr(args[0].to(run_as_type), op)(
                        *cast(args[1:], run_as_type), **add_kwargs
                    )
                self.assertTrue(type(output_to_compare) == type(control))
                comparison = compare(output_to_compare, control)
                self.assertTrue(comparison, f"torch.{op} result did not match control")
            self.assertTrue(torch.is_autocast_enabled("xpu"))
        self.assertFalse(torch.is_autocast_enabled("xpu"))

    def test_autocast_torch_fp16(self):
        for op_with_args in self.autocast_lists.torch_fp16:
            skip_test = False
            op, args = op_with_args[0], op_with_args[1]
            if op in self.skip_list:
                skip_test = True  # skip unimplemented op
            if len(op_with_args) == 3:
                skip_test = True  # skip cudnn op
            if not skip_test:
                self._run_autocast_outofplace(op, args, torch.float16)

    def test_autocast_torch_bf16(self):
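        # Reuse the fp16 op list, but execute the ops as bfloat16.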
        for op_with_args in self.autocast_lists.torch_fp16:
            skip_test = False
            op, args = op_with_args[0], op_with_args[1]
            if op in self.skip_list:
                skip_test = True  # skip unimplemented op
            if len(op_with_args) == 3:
                skip_test = True  # skip cudnn op
            if not skip_test:
                self._run_autocast_outofplace(op, args, torch.bfloat16)

    def test_autocast_torch_need_autocast_promote(self):
        for op, args in self.autocast_lists.torch_need_autocast_promote:
            self._run_autocast_outofplace(op, args, torch.float32)

    def test_autocast_torch_expect_builtin_promote(self):
        for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
            self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)

    def test_xpu_autocast_dtype(self):
        dtype = torch.get_autocast_dtype("xpu")
        self.assertEqual(dtype, torch.float16)
        mat0_fp32 = torch.randn((10, 10), dtype=torch.float32, device="xpu")
        mat1_fp32 = torch.randn((10, 10), dtype=torch.float32, device="xpu")
        with torch.amp.autocast("xpu"):
            result = torch.mm(mat0_fp32, mat1_fp32)
            self.assertEqual(result.dtype, torch.float16)


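# The tests below register a MagicMock as a gpu_trace callback and verify it
# is invoked for XPU event and stream operations.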
class TestXpuTrace(TestCase):
    def setUp(self):
        torch._C._activate_gpu_trace()
        self.mock = unittest.mock.MagicMock()

    def test_event_creation_callback(self):
        gpu_trace.register_callback_for_event_creation(self.mock)

        event = torch.xpu.Event()
        event.record()
        self.mock.assert_called_once_with(event._as_parameter_.value)

    def test_event_deletion_callback(self):
        gpu_trace.register_callback_for_event_deletion(self.mock)

        event = torch.xpu.Event()
        event.record()
        event_id = event._as_parameter_.value
        del event
        self.mock.assert_called_once_with(event_id)

    def test_event_record_callback(self):
        gpu_trace.register_callback_for_event_record(self.mock)

        event = torch.xpu.Event()
        event.record()
        self.mock.assert_called_once_with(
            event._as_parameter_.value, torch.xpu.current_stream().sycl_queue
        )

    def test_event_wait_callback(self):
        gpu_trace.register_callback_for_event_wait(self.mock)

        event = torch.xpu.Event()
        event.record()
        event.wait()
        self.mock.assert_called_once_with(
            event._as_parameter_.value, torch.xpu.current_stream().sycl_queue
        )

    def test_device_synchronization_callback(self):
        gpu_trace.register_callback_for_device_synchronization(self.mock)

        torch.xpu.synchronize()
        self.mock.assert_called()

    def test_stream_synchronization_callback(self):
        gpu_trace.register_callback_for_stream_synchronization(self.mock)

        stream = torch.xpu.Stream()
        stream.synchronize()
        self.mock.assert_called_once_with(stream.sycl_queue)

    def test_event_synchronization_callback(self):
        gpu_trace.register_callback_for_event_synchronization(self.mock)

        event = torch.xpu.Event()
        event.record()
        event.synchronize()
        self.mock.assert_called_once_with(event._as_parameter_.value)


if __name__ == "__main__":
    run_tests()