# Owner(s): ["module: cpp"]

import os
import unittest

import psutil
import pytorch_openreg

import torch
from torch.testing._internal.common_utils import run_tests, TestCase


class TestOpenReg(TestCase):
    def test_initializes(self):
        self.assertEqual(torch._C._get_privateuse1_backend_name(), "openreg")

    @unittest.SkipTest
    def test_autograd_init(self):
        # Make sure autograd is initialized
        torch.ones(2, requires_grad=True, device="openreg").sum().backward()

        pid = os.getpid()
        task_path = f"/proc/{pid}/task"
        all_threads = psutil.Process(pid).threads()

        all_thread_names = set()

        for t in all_threads:
            with open(f"{task_path}/{t.id}/comm") as file:
                thread_name = file.read().strip()
                all_thread_names.add(thread_name)

        for i in range(pytorch_openreg._device_daemon.NUM_DEVICES):
            self.assertIn(f"pt_autograd_{i}", all_thread_names)

    def test_factory(self):
        a = torch.empty(50, device="openreg")
        self.assertEqual(a.device.type, "openreg")

        a.fill_(3.5)

        self.assertTrue(a.eq(3.5).all())

    def test_printing(self):
        a = torch.ones(20, device="openreg")
        # Does not crash!
        str(a)

    def test_cross_device_copy(self):
        a = torch.rand(10)
        b = a.to(device="openreg").add(2).to(device="cpu")
        self.assertEqual(b, a + 2)

    def test_data_dependent_output(self):
        cpu_a = torch.randn(10)
        a = cpu_a.to(device="openreg")
        mask = a.gt(0)
        out = torch.masked_select(a, mask)

        self.assertEqual(out, cpu_a.masked_select(cpu_a.gt(0)))


if __name__ == "__main__":
    run_tests()