1# Owner(s): ["oncall: distributed"] 2 3import contextlib 4import os 5import sys 6from typing import Any, Optional 7 8import torch 9import torch.distributed as dist 10 11 12if not dist.is_available(): 13 print("Distributed not available, skipping tests", file=sys.stderr) 14 sys.exit(0) 15 16from torch.distributed.algorithms.join import Join, Joinable, JoinHook 17from torch.testing._internal.common_distributed import ( 18 MultiProcessTestCase, 19 require_n_gpus_for_nccl_backend, 20) 21from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN 22 23 24if TEST_WITH_DEV_DBG_ASAN: 25 print( 26 "Skip dev-asan as torch + multiprocessing spawn have known issues", 27 file=sys.stderr, 28 ) 29 sys.exit(0) 30 31BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO 32WORLD_SIZE = min(4, max(2, torch.cuda.device_count())) 33 34# Constants used for testing post-hooks 35BEFORE_CONSTANT = 41 36AFTER_CONSTANT = 42 37 38 39class AllReducerJoinHook(JoinHook): 40 r""" 41 Join hook for :class:`AllReducer`. 42 43 Arguments: 44 allreducer (AllReducer): the :class:`AllReducer` object using this 45 hook. 46 num_allreduces (int): the number of all-reduces to shadow per 47 iteration. 48 run_post_hook (bool): a flag enabling the post-hook logic. 49 """ 50 51 def __init__(self, allreducer, num_allreduces, run_post_hook): 52 self.allreducer = allreducer 53 self.num_allreduces = num_allreduces 54 self.run_post_hook = run_post_hook 55 56 def main_hook(self): 57 r""" 58 Shadows each all-reduce; the number of all-reduces is passed into the 59 constructor as ``num_allreduces``. 60 """ 61 device = self.allreducer.device 62 for _ in range(self.num_allreduces): 63 t = torch.zeros(1, device=device) 64 dist.all_reduce(t) 65 66 def post_hook(self, is_last_joiner: bool): 67 r""" 68 Broadcasts a tensor containing a magic constant ``AFTER_CONSTANT`` from 69 the last joiner to all other processes. 70 """ 71 if not self.run_post_hook: 72 return 73 rank = dist.get_rank(self.allreducer.process_group) 74 common_rank = self.allreducer.find_common_rank(rank, is_last_joiner) 75 device = self.allreducer.device 76 if rank == common_rank: 77 self.allreducer.post_hook_tensor = torch.tensor( 78 [AFTER_CONSTANT], device=device 79 ) 80 dist.broadcast(self.allreducer.post_hook_tensor, src=common_rank) 81 82 83class AllReducer(Joinable): 84 r""" 85 Example :class:`Joinable` that performs some number of all-reduces as its 86 per-iteration collective communication. 87 """ 88 89 def __init__(self, device, process_group): 90 super().__init__() 91 self.device = device 92 self.process_group = process_group 93 self.post_hook_tensor = torch.tensor([BEFORE_CONSTANT], device=self.device) 94 95 def __call__(self, num_allreduces=1): 96 r""" 97 All-reduces a dim-1 one tensor ``num_allreduces``-many times, and 98 returns the total result. 99 """ 100 Join.notify_join_context(self) 101 device = self.device 102 total = 0 103 for _ in range(num_allreduces): 104 t = torch.ones(1, device=device) 105 dist.all_reduce(t) 106 total += t.item() 107 return total 108 109 def join_hook(self, **kwargs) -> JoinHook: 110 r""" 111 Returns a join hook that shadows some number of all-reduces; by default, 112 this number is 1. 
113 """ 114 num_allreduces = kwargs.get("num_allreduces", 1) 115 run_post_hook = kwargs.get("run_post_hooks", False) 116 return AllReducerJoinHook(self, num_allreduces, run_post_hook) 117 118 @property 119 def join_device(self) -> torch.device: 120 return self.device 121 122 @property 123 def join_process_group(self) -> Any: 124 return self.process_group 125 126 def find_common_rank(self, rank, to_consider): 127 r""" 128 Returns the max rank of the ones to consider over the process group. 129 """ 130 common_rank = torch.tensor([rank if to_consider else -1], device=self.device) 131 dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group) 132 common_rank = common_rank.item() 133 assert common_rank >= 0 134 return common_rank 135 136 137class TestJoin(MultiProcessTestCase): 138 r"""Test cases for the generic join context.""" 139 140 def setUp(self): 141 super().setUp() 142 os.environ["WORLD_SIZE"] = str(self.world_size) 143 os.environ["BACKEND"] = BACKEND 144 self._spawn_processes() 145 146 @property 147 def device(self): 148 return ( 149 torch.device(self.rank) 150 if BACKEND == dist.Backend.NCCL 151 else torch.device("cpu") 152 ) 153 154 @property 155 def world_size(self): 156 return WORLD_SIZE 157 158 @property 159 def process_group(self): 160 return dist.group.WORLD 161 162 def tearDown(self): 163 try: 164 dist.destroy_process_group() 165 except AssertionError: 166 pass 167 try: 168 os.remove(self.file_name) 169 except OSError: 170 pass 171 172 def dist_init(self, rank, world_size, backend=BACKEND): 173 store = dist.FileStore(self.file_name, world_size) 174 return dist.init_process_group( 175 backend=backend, store=store, rank=rank, world_size=world_size 176 ) 177 178 def construct_uneven_inputs(self, base, offset, device=None): 179 r""" 180 Returns uneven inputs: rank i gets ``base`` + i * ``offset`` inputs. 181 """ 182 if device is None: 183 device = self.device 184 return [torch.zeros(1, device=device) for _ in range(base + self.rank * offset)] 185 186 def construct_even_inputs(self, base, device=None): 187 r"""Returns even inputs: each rank gets ``base`` inputs.""" 188 if device is None: 189 device = self.device 190 return [torch.zeros(1, device=device) for _ in range(base)] 191 192 @property 193 def base_num_inputs(self): 194 r"""Base number of inputs to be used by all ranks.""" 195 return 3 196 197 @property 198 def offset(self): 199 r"""Rank i gets i * ``offset`` additional inputs.""" 200 return 1 201 202 def _test_join_base( 203 self, 204 uneven_inputs: bool, 205 num_joinables: int, 206 enable: bool, 207 throw_on_early_termination: bool, 208 num_allreduces: int, 209 run_post_hooks: bool, 210 expected_total: Optional[int] = None, 211 ): 212 r""" 213 Skeleton for all :class:`Join` tests. 214 215 Arguments: 216 uneven_inputs (bool): ``True`` to use uneven inputs; ``False`` 217 otherwise. 218 num_joinables (int): number of :class:`AllReducer` s to construct. 219 enable (bool): ``True`` to enable the join context manager; 220 ``False`` otherwise. 221 throw_on_early_termination (bool): ``True`` to raise an exception 222 upon detecting uneven inputs; ``False`` otherwise. 223 num_allreduces (int): number of all-reduces to perform per input. 224 run_post_hooks (bool): ``True`` to run post-hooks; ``False`` 225 otherwise. 226 expected_total (Optional[int]): ``None`` to not check the expected 227 all-reduce total; otherwise, the expected total; default is 228 ``None``. 
229 """ 230 self.dist_init(self.rank, self.world_size) 231 232 allreducers = [ 233 AllReducer(self.device, self.process_group) for _ in range(num_joinables) 234 ] 235 for allreducer in allreducers: 236 self.assertEqual(allreducer.post_hook_tensor.item(), BEFORE_CONSTANT) 237 238 inputs = ( 239 self.construct_uneven_inputs(self.base_num_inputs, self.offset) 240 if uneven_inputs 241 else self.construct_even_inputs(self.base_num_inputs) 242 ) 243 allreduce_total = 0 244 245 # Expect a `RuntimeError` if `throw_on_early_termination=True` 246 # Rank 0 exhausts its inputs first 247 expected_msg = ( 248 "Rank 0 exhausted all inputs." 249 if self.rank == 0 250 else "Detected at least one rank that exhausted inputs. " 251 "Throwing across all ranks." 252 ) 253 with self.assertRaisesRegex( 254 RuntimeError, expected_msg 255 ) if throw_on_early_termination else contextlib.nullcontext(): 256 with Join( 257 allreducers, 258 enable=enable, 259 throw_on_early_termination=throw_on_early_termination, 260 num_allreduces=num_allreduces, 261 run_post_hooks=run_post_hooks, 262 ): 263 for _ in inputs: 264 for allreducer in allreducers: 265 allreduce_total += allreducer(num_allreduces) 266 267 if throw_on_early_termination: 268 return 269 270 # Check `expected_total` if not `None` 271 if expected_total: 272 self.assertEqual(allreduce_total, expected_total) 273 274 # All `AllReduce` instances should receive the updated 275 # `post_hook_tensor` from the last-joined process 276 if run_post_hooks: 277 for allreducer in allreducers: 278 self.assertEqual(allreducer.post_hook_tensor.item(), AFTER_CONSTANT) 279 280 @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND) 281 def test_single_joinable_main_hooks(self): 282 r"""Tests the main hooks of a single :class:`Joinable`.""" 283 num_joinables = 1 284 num_allreduces = 1 285 run_post_hooks = False 286 # Non-joined processes all-reduce a 1, so this rank's all-reduce total 287 # should be precisely equal to the total number of inputs processed 288 # before it joined 289 expected_total = self.world_size * self.base_num_inputs 290 # Rank i runs for i additional iterations 291 for num_joined in range(1, self.rank + 1): 292 expected_total += (self.world_size - num_joined) * self.offset 293 294 self._test_join_base( 295 uneven_inputs=True, 296 num_joinables=num_joinables, 297 enable=True, 298 throw_on_early_termination=False, 299 num_allreduces=num_allreduces, 300 run_post_hooks=run_post_hooks, 301 expected_total=expected_total, 302 ) 303 304 @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND) 305 def test_single_joinable_post_hooks(self): 306 r"""Tests the post-hooks of a single :class:`Joinable`.""" 307 num_joinables = 1 308 num_allreduces = 0 # set to 0 to skip the main hooks 309 run_post_hooks = False 310 311 self._test_join_base( 312 uneven_inputs=True, 313 num_joinables=num_joinables, 314 enable=True, 315 throw_on_early_termination=False, 316 num_allreduces=num_allreduces, 317 run_post_hooks=run_post_hooks, 318 expected_total=None, 319 ) 320 321 @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND) 322 def test_single_joinable(self): 323 r""" 324 Tests the main hooks and post-hooks of a single :class:`Joinable` 325 together. 326 327 This combines ``test_single_joinable_main_hooks()`` and 328 ``test_single_joinable_post_hooks()`` into a single test to ensure that 329 main hooks and post-hooks operate correctly together. 
330 """ 331 num_joinables = 1 332 num_allreduces = 1 333 run_post_hooks = True 334 335 expected_total = self.world_size * self.base_num_inputs 336 for num_joined in range(1, self.rank + 1): 337 expected_total += (self.world_size - num_joined) * self.offset 338 339 self._test_join_base( 340 uneven_inputs=True, 341 num_joinables=num_joinables, 342 enable=True, 343 throw_on_early_termination=False, 344 num_allreduces=num_allreduces, 345 run_post_hooks=run_post_hooks, 346 expected_total=expected_total, 347 ) 348 349 @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND) 350 def test_multiple_joinables(self): 351 r""" 352 Tests the main hooks and post-hooks of multiple :class:`Joinable` s 353 together. 354 355 This generalizes ``test_single_joinable()`` to multiple 356 :class:`Joinable` s. 357 """ 358 num_joinables = 3 359 num_allreduces = 1 360 run_post_hooks = True 361 362 expected_total = self.world_size * self.base_num_inputs 363 for num_joined in range(1, self.rank + 1): 364 expected_total += (self.world_size - num_joined) * self.offset 365 # The expected total is now multiplied by a factor of `NUM_JOINABLES` 366 expected_total *= num_joinables 367 368 self._test_join_base( 369 uneven_inputs=True, 370 num_joinables=num_joinables, 371 enable=True, 372 throw_on_early_termination=False, 373 num_allreduces=num_allreduces, 374 run_post_hooks=run_post_hooks, 375 expected_total=expected_total, 376 ) 377 378 @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND) 379 def test_single_joinable_disable(self): 380 r"""Tests ``enable=False`` for a single :class:`Joinable`.""" 381 num_joinables = 1 382 num_allreduces = 1 383 uneven_inputs = False 384 enable = False 385 run_post_hooks = False 386 387 expected_total = self.world_size * self.base_num_inputs 388 389 self._test_join_base( 390 uneven_inputs=uneven_inputs, 391 num_joinables=num_joinables, 392 enable=enable, 393 throw_on_early_termination=False, 394 num_allreduces=num_allreduces, 395 run_post_hooks=run_post_hooks, 396 expected_total=expected_total, 397 ) 398 399 @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND) 400 def test_multiple_joinable_disable(self): 401 r""" 402 Tests ``enable=False`` for multiple :class:`Joinable` s. 403 404 This generalizes ``test_single_joinable_disable`` to multiple 405 :class:`Joinable` s. 406 """ 407 num_joinables = 3 408 num_allreduces = 1 409 uneven_inputs = False 410 enable = False 411 run_post_hooks = False 412 413 expected_total = self.world_size * self.base_num_inputs * num_joinables 414 415 self._test_join_base( 416 uneven_inputs=uneven_inputs, 417 num_joinables=num_joinables, 418 enable=enable, 419 throw_on_early_termination=False, 420 num_allreduces=num_allreduces, 421 run_post_hooks=run_post_hooks, 422 expected_total=expected_total, 423 ) 424 425 @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND) 426 def test_single_joinable_throw(self): 427 r""" 428 Tests ``throw_on_early_termination=True`` for a single 429 :class:`Joinable`. 
430 """ 431 num_joinables = 1 432 num_allreduces = 1 433 throw_on_early_termination = True 434 run_post_hooks = False 435 436 self._test_join_base( 437 uneven_inputs=True, 438 num_joinables=num_joinables, 439 enable=True, 440 throw_on_early_termination=throw_on_early_termination, 441 num_allreduces=num_allreduces, 442 run_post_hooks=run_post_hooks, 443 expected_total=None, 444 ) 445 446 @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND) 447 def test_multiple_joinables_throw(self): 448 r""" 449 Tests ``throw_on_early_termination=True`` for multiple 450 :class:`Joinable` s together. 451 452 This generalizes ``test_single_joinable_throw`` to multiple 453 :class:`Joinable` s. 454 """ 455 num_joinables = 3 456 num_allreduces = 1 457 throw_on_early_termination = True 458 run_post_hooks = False 459 460 self._test_join_base( 461 uneven_inputs=True, 462 num_joinables=num_joinables, 463 enable=True, 464 throw_on_early_termination=throw_on_early_termination, 465 num_allreduces=num_allreduces, 466 run_post_hooks=run_post_hooks, 467 expected_total=None, 468 ) 469 470 @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND) 471 def test_join_kwargs(self): 472 r""" 473 Tests passing keyword arguments to the context manager. 474 """ 475 num_joinables = 1 476 num_allreduces = 2 477 run_post_hooks = False 478 479 expected_total = self.world_size * self.base_num_inputs 480 for num_joined in range(1, self.rank + 1): 481 expected_total += (self.world_size - num_joined) * self.offset 482 # The expected total is now multiplied by a factor of `NUM_ALLREDUCES` 483 expected_total *= num_allreduces 484 485 self._test_join_base( 486 uneven_inputs=True, 487 num_joinables=num_joinables, 488 enable=True, 489 throw_on_early_termination=False, 490 num_allreduces=num_allreduces, 491 run_post_hooks=run_post_hooks, 492 expected_total=expected_total, 493 ) 494 495 496if __name__ == "__main__": 497 run_tests() 498