1# Owner(s): ["oncall: distributed"] 2 3from unittest import mock 4 5import torch.distributed as c10d 6from torch.distributed.collective_utils import all_gather, broadcast 7from torch.testing._internal.common_distributed import MultiProcessTestCase 8 9 10class TestCollectiveUtils(MultiProcessTestCase): 11 def setUp(self): 12 super().setUp() 13 self._spawn_processes() 14 15 def tearDown(self) -> None: 16 super().tearDown() 17 18 def opts(self, threads=2): 19 opts = c10d.ProcessGroupGloo._Options() 20 opts._timeout = 50.0 21 opts._threads = threads 22 return opts 23 24 def test_broadcast_result(self) -> None: 25 """ 26 Basic unit test for broadcast using a process group of default world size. 27 """ 28 store = c10d.FileStore(self.file_name, self.world_size) 29 c10d.init_process_group( 30 backend="gloo", store=store, rank=self.rank, world_size=self.world_size 31 ) 32 pg = c10d.new_group(pg_options=self.opts()) 33 34 func = mock.MagicMock() 35 func.return_value = pg.rank() 36 37 res = broadcast(data_or_fn=func, rank=0, pg=pg) 38 assert res == 0, f"Expect res to be 0 (got {res})" 39 40 if pg.rank() == 0: 41 func.assert_called_once() 42 else: 43 func.assert_not_called() 44 45 func.reset_mock() 46 47 res = broadcast(data_or_fn=func, rank=1, pg=pg) 48 assert res == 1, f"Expect res to be 1 (got {res})" 49 50 if pg.rank() == 1: 51 func.assert_called_once() 52 else: 53 func.assert_not_called() 54 55 def test_broadcast_result_no_pg(self) -> None: 56 """ 57 Ensure broadcast has no dependency on torch.distributed when run in single process. 58 """ 59 func = mock.MagicMock() 60 res = broadcast(data_or_fn=func, rank=0) 61 func.assert_called_once() 62 63 def test_broadcast_result_raises_exceptions_from_func( 64 self, 65 ) -> None: 66 """ 67 Ensure broadcast exception is propagated properly. 68 """ 69 # no process group 70 func = mock.MagicMock() 71 exc = Exception("test exception") 72 func.side_effect = exc 73 expected_exception = "test exception" 74 with self.assertRaisesRegex(Exception, expected_exception): 75 broadcast(data_or_fn=func, rank=0) 76 77 def test_all_gather_result(self) -> None: 78 """ 79 Basic unit test for all_gather using a process group of default world size. 80 """ 81 store = c10d.FileStore(self.file_name, self.world_size) 82 c10d.init_process_group( 83 backend="gloo", store=store, rank=self.rank, world_size=self.world_size 84 ) 85 pg = c10d.new_group(pg_options=self.opts()) 86 87 func = mock.MagicMock() 88 func.return_value = pg.rank() 89 90 res = all_gather(data_or_fn=func, pg=pg) 91 func.assert_called_once() 92 assert res == list( 93 range(self.world_size) 94 ), f"Expect res to be list of 0 through {self.world_size} (got {res})" 95 96 def test_all_gather_result_no_pg(self) -> None: 97 """ 98 Ensure all_gather has no dependency on torch.distributed when run in single process. 99 """ 100 func = mock.MagicMock() 101 res = all_gather(data_or_fn=func) 102 func.assert_called_once() 103 104 def test_all_gather_result_raises_exceptions_from_func( 105 self, 106 ) -> None: 107 """ 108 Ensure all_gather exception is propagated properly. 109 """ 110 # no process group 111 func = mock.MagicMock() 112 exc = Exception("test exception") 113 func.side_effect = exc 114 expected_exception = "test exception" 115 with self.assertRaisesRegex(Exception, expected_exception): 116 all_gather(data_or_fn=func) 117