# xref: /aosp_15_r20/external/pytorch/test/cpp_extensions/open_registration_extension/test/test_openreg.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: cpp"]

import os
import sys
import unittest

import psutil
import pytorch_openreg

import torch
from torch.testing._internal.common_utils import run_tests, TestCase
11
12
class TestOpenReg(TestCase):
    """Smoke tests for the ``openreg`` PrivateUse1 backend (``pytorch_openreg``).

    Each test runs a basic operation on the ``"openreg"`` device; where a
    numeric result is produced it is compared against the same computation
    on CPU.
    """

    def test_initializes(self):
        """The PrivateUse1 backend is registered under the name ``"openreg"``."""
        # Presumably the module-level ``import pytorch_openreg`` performs the
        # registration — TODO confirm against pytorch_openreg's __init__.
        self.assertEqual(torch._C._get_privateuse1_backend_name(), "openreg")

    # ``unittest.SkipTest`` is an exception class, not a decorator: applying it
    # replaced this method with a (non-callable) exception instance, so the
    # test loader silently dropped the test instead of reporting it as skipped.
    @unittest.skip("temporarily disabled")
    # The thread-name lookup below reads /proc, which only exists on Linux.
    @unittest.skipIf(
        not sys.platform.startswith("linux"),
        "reads thread names from /proc (Linux only)",
    )
    def test_autograd_init(self):
        """Each openreg device has an autograd thread named ``pt_autograd_<i>``."""
        # Run a backward pass through an openreg tensor so the autograd engine
        # has been initialized before the thread list is inspected.
        torch.ones(2, requires_grad=True, device="openreg").sum().backward()

        pid = os.getpid()
        task_dir = f"/proc/{pid}/task"

        # psutil supplies the thread ids; each thread's name is read from
        # /proc/<pid>/task/<tid>/comm.
        thread_names = set()
        for thread in psutil.Process(pid).threads():
            # The kernel truncates comm to 15 bytes, which can split a
            # multi-byte UTF-8 character in some other thread's name; decode
            # leniently so that cannot abort the test.
            with open(
                f"{task_dir}/{thread.id}/comm", encoding="utf-8", errors="replace"
            ) as comm_file:
                thread_names.add(comm_file.read().strip())

        for i in range(pytorch_openreg._device_daemon.NUM_DEVICES):
            self.assertIn(f"pt_autograd_{i}", thread_names)

    def test_factory(self):
        """Factory functions allocate on openreg and in-place fills are visible."""
        a = torch.empty(50, device="openreg")
        self.assertEqual(a.device.type, "openreg")

        a.fill_(3.5)

        self.assertTrue(a.eq(3.5).all())

    def test_printing(self):
        """``str()`` on an openreg tensor succeeds."""
        a = torch.ones(20, device="openreg")
        # Only checks that formatting does not raise; the text is not inspected.
        str(a)

    def test_cross_device_copy(self):
        """A CPU -> openreg -> CPU round trip through an op matches CPU math."""
        a = torch.rand(10)
        b = a.to(device="openreg").add(2).to(device="cpu")
        self.assertEqual(b, a + 2)

    def test_data_dependent_output(self):
        """An op whose output size depends on the data matches the CPU result."""
        cpu_a = torch.randn(10)
        a = cpu_a.to(device="openreg")
        mask = a.gt(0)
        # masked_select's output length equals the number of True entries in
        # mask, so it is only known after the mask has been computed.
        out = torch.masked_select(a, mask)

        self.assertEqual(out, cpu_a.masked_select(cpu_a.gt(0)))
61
62
if __name__ == "__main__":
    # Run this module's tests with PyTorch's test runner when executed directly.
    run_tests()
65