# Owner(s): ["oncall: distributed"]

import os
import sys
from datetime import timedelta
from unittest.mock import patch

import torch
import torch.distributed as c10d
from torch._C._distributed_c10d import _ProcessGroupWrapper


if not c10d.is_available():
    print("c10d not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from test_c10d_common import LOOPBACK

from torch.testing._internal.common_distributed import (
    create_device,
    MultiProcessTestCase,
    requires_gloo,
    requires_nccl,
    skip_if_lt_x_gpu,
    with_dist_debug_levels,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


class AbstractProcessGroupWrapperTest(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def _validate_error(self, exception, op_type, rank, tensor, verify_diff=True):
        err = str(exception)
        self.assertTrue(
            op_type in err, f"Got {err} but expected {op_type} to be in error."
        )
        # User doesn't call barrier with tensor.
        if op_type != "BARRIER":
            self.assertTrue(
                f"{list(tensor.shape)}" in err,
                f"Did not find shapes {list(tensor.shape)} in error {err}",
            )
            # For CUDA, only assert on device type, not index
            if "cuda" in str(tensor.device):
                self.assertTrue(
                    "cuda" in err, f"Did not find cuda device in error {err}"
                )
            else:
                self.assertTrue(
                    str(tensor.device) in err,
                    f"Did not find tensor device {str(tensor.device)} in error {err}",
                )
            # C++ and python type strings are not exactly the same.
            if "float" in str(tensor.dtype):
                self.assertTrue("Float" in err, "Expected Float type")
            elif "int" in str(tensor.dtype):
                self.assertTrue("Long" in err, "Expected Long type")
            else:
                self.fail(f"Unexpected dtype {str(tensor.dtype)} for error {err}")

        # Ensure sequence number is logged in error
        self.assertTrue("SequenceNumber" in err)
        # Ensure info about how collectives diff is in the error.
        if verify_diff:
            self.assertTrue(
                "Collectives differ in the following" in err, f"Got error {err}"
            )

    def _test_collective_hang(self, wrapper_pg, use_cuda=False):
        # All ranks besides 1 call allreduce and wrapper_pg should detect a hang
        # and report an issue with rank 1.
        faulty_rank = 1
        if self.rank != faulty_rank:
            tensor = torch.randn(20, 10)
            if use_cuda:
                tensor = tensor.to(self.rank)

            if self.rank == 0:
                # Rank 0 reports faulty ranks
                err = f"Ranks {faulty_rank} failed to pass monitoredBarrier"
            else:
                err = "Please check rank 0 logs for faulty rank"

            # Gloo can sometimes throw the following error if a rank exits early
            # before rank 0 calls into the allreduce.
            err += "|Connection closed by peer|Connection reset by peer"
            with self.assertRaisesRegex(RuntimeError, err):
                wrapper_pg.allreduce([tensor])

    def _test_collectives_op_mismatch(self, wrapper_pg, use_cuda=False):
        tensor = torch.randn(20, 10)
        if use_cuda:
            tensor = tensor.to(self.rank)
        works = []
        # Run a few successful collectives
        for _ in range(500):
            work = wrapper_pg.allreduce([tensor])
            works.append(work)

        for w in works:
            w.wait()

        # Simulate mismatch: allreduce vs reduce.
        # Error including info about inconsistent collective, rank, tensor
        # shape, device, and dtype should be raised.
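        # The wrapper (created explicitly via _create_process_group_wrapper or
        # implicitly under TORCH_DISTRIBUTED_DEBUG=DETAIL) compares the op type,
        # tensor shapes, dtypes, and devices across ranks before dispatching the
        # real collective (it relies on a Gloo helper group for this check, see
        # test_debug_level_detail_no_gloo below), so each mismatch surfaces as a
        # RuntimeError that _validate_error can inspect rather than a hang.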
        with self.assertRaisesRegex(RuntimeError, ".*") as cm:
            if self.rank == 0:
                wrapper_pg.allreduce([tensor])
            else:
                wrapper_pg.reduce([tensor])
        self._validate_error(
            exception=cm.exception,
            op_type="ALLREDUCE" if self.rank == 0 else "REDUCE",
            rank=self.rank,
            tensor=tensor,
        )

        with self.assertRaisesRegex(RuntimeError, ".*") as cm:
            if self.rank == 0:
                wrapper_pg.reduce([tensor])
            else:
                wrapper_pg.barrier()
        self._validate_error(
            exception=cm.exception,
            op_type="REDUCE" if self.rank == 0 else "BARRIER",
            rank=self.rank,
            tensor=tensor,
        )

        with self.assertRaisesRegex(RuntimeError, ".*") as cm:
            if self.rank == 0:
                wrapper_pg.broadcast(tensor, 0)
            else:
                output_tensors = [
                    torch.zeros_like(tensor) for _ in range(self.world_size)
                ]
                wrapper_pg.allgather([output_tensors], [tensor])
        self._validate_error(
            exception=cm.exception,
            op_type="BROADCAST" if self.rank == 0 else "ALLGATHER",
            rank=self.rank,
            tensor=tensor,
        )

    def _test_collective_shape_mismatch(self, wrapper_pg, use_cuda=False):
        wrapper_pg.barrier()
        dim = 2 if self.rank == 0 else 10
        tensor = torch.randn(20, dim)
        if use_cuda:
            tensor = tensor.to(self.rank)
        with self.assertRaisesRegex(RuntimeError, ".*") as cm:
            wrapper_pg.allreduce([tensor])
        self._validate_error(
            exception=cm.exception,
            op_type="ALLREDUCE",
            rank=self.rank,
            tensor=tensor,
        )

        # Check errors are raised when dimensionality of shapes is different
        tensor = torch.randn(20, 10, 2) if self.rank == 0 else torch.randn(20, 10)
        if use_cuda:
            tensor = tensor.to(self.rank)
        with self.assertRaisesRegex(RuntimeError, ".*") as cm:
            wrapper_pg.allreduce([tensor])
        self._validate_error(
            exception=cm.exception,
            op_type="ALLREDUCE",
            rank=self.rank,
            tensor=tensor,
        )

        # Check shape errors with scatter
        input = [
            torch.tensor(
                [self.rank] if self.rank == 0 else [self.rank, self.rank],
                device=self.rank if use_cuda else "cpu",
            )
            for _ in range(self.world_size)
        ]
        outputs = [
            torch.tensor(
                [-1] if self.rank == 0 else [-1, -1],
                device=self.rank if use_cuda else "cpu",
            )
            for _ in range(self.world_size)
        ]
        root_rank = 0
        opts = c10d.ScatterOptions()
        opts.rootRank = root_rank
        with self.assertRaisesRegex(RuntimeError, ".*") as cm:
            if self.rank == root_rank:
                wrapper_pg.scatter([outputs[self.rank]], [input], opts).wait()
            else:
                wrapper_pg.scatter([outputs[self.rank]], [], opts).wait()
        self._validate_error(
            exception=cm.exception,
            op_type="SCATTER",
            rank=self.rank,
            tensor=outputs[self.rank],
        )


# ASAN is not safe since we are spawning processes.
if not TEST_WITH_DEV_DBG_ASAN:

    @requires_gloo()
    @requires_nccl()
    class ProcessGroupNCCLWrapperTest(AbstractProcessGroupWrapperTest):
        def setUp(self):
            super(AbstractProcessGroupWrapperTest, self).setUp()
            self._spawn_processes()
            # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
            # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
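            # With async error handling enabled, NCCL failures are raised as
            # exceptions by the watchdog instead of hanging the collective.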
            os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"

        @property
        def world_size(self) -> int:
            return 2

        def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
            store = c10d.FileStore(self.file_name, self.world_size)
            c10d.init_process_group(
                backend="nccl",
                rank=self.rank,
                world_size=self.world_size,
                store=store,
                timeout=timedelta(seconds=timeout),
            )
            if with_new_group:
                pg = c10d.new_group(backend="nccl", timeout=timedelta(seconds=timeout))
            else:
                _pg = c10d.ProcessGroupNCCL(
                    store,
                    self.rank,
                    self.world_size,
                    timeout=timedelta(seconds=timeout),
                )
                pg = c10d._create_process_group_wrapper(
                    _pg,
                    "unused",
                    store,
                    self.rank,
                    self.world_size,
                    timeout=timeout,
                )
            return pg

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        def test_collective_hang(self):
            pg = self._create_wrapper_pg(timeout=2.0)
            self._test_collective_hang(pg)

        # NOTE: these tests are separated by debug level instead of combined into
        # one due to https://github.com/pytorch/pytorch/issues/55967, they can be
        # combined after that is resolved.
        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @with_dist_debug_levels(levels=["DETAIL"])
        def test_collectives_op_mismatch_debug_mode(self):
            pg = self._create_wrapper_pg(with_new_group=True)
            self._test_collectives_op_mismatch(pg, use_cuda=True)
            self._test_nccl_only_op_mismatch(pg)

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @with_dist_debug_levels(levels=["OFF"])
        def test_collectives_op_mismatch(self):
            pg = self._create_wrapper_pg(with_new_group=False)
            self._test_collectives_op_mismatch(pg, use_cuda=True)
            self._test_nccl_only_op_mismatch(pg)

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @with_dist_debug_levels(levels=["DETAIL"])
        def test_collective_shape_mismatch_debug_mode_detail(self):
            pg = self._create_wrapper_pg(with_new_group=True)
            self._test_collective_shape_mismatch(pg, use_cuda=True)
            self._test_nccl_only_shape_mismatch(pg)

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @with_dist_debug_levels(levels=["OFF"])
        def test_collective_shape_mismatch_debug_mode_off(self):
            pg = self._create_wrapper_pg(with_new_group=False)
            self._test_collective_shape_mismatch(pg, use_cuda=True)
            self._test_nccl_only_shape_mismatch(pg)

        def _test_nccl_only_op_mismatch(self, wrapper_pg):
            device = f"cuda:{self.rank}"
            with self.assertRaisesRegex(RuntimeError, ".*") as cm:
                output = torch.zeros(4 + self.rank, device=device)
                input = torch.ones(4 * self.world_size, device=device)
                if self.rank == 0:
                    wrapper_pg._allgather_base(output, input).wait()
                else:
                    wrapper_pg._reduce_scatter_base(output, input).wait()

            op_type = "ALLGATHER_BASE" if self.rank == 0 else "REDUCE_SCATTER_BASE"
            self._validate_error(
                exception=cm.exception,
                op_type=op_type,
                rank=self.rank,
                tensor=input,
            )

        def _test_nccl_only_shape_mismatch(self, wrapper_pg):
            device = f"cuda:{self.rank}"
            with self.assertRaisesRegex(RuntimeError, ".*") as cm:
                output = torch.zeros(4 + self.rank, device=device)
                input = torch.ones(4 * (self.world_size + 1), device=device)

                wrapper_pg._reduce_scatter_base(output, input).wait()
            self._validate_error(
                exception=cm.exception,
                op_type="REDUCE_SCATTER_BASE",
                rank=self.rank,
                tensor=input,
                verify_diff=False,
            )
            with self.assertRaisesRegex(RuntimeError, ".*") as cm:
                output = torch.zeros(4, device=device)
                input = torch.ones((4 + self.rank) * self.world_size, device=device)

                wrapper_pg._reduce_scatter_base(output, input).wait()
            self._validate_error(
                exception=cm.exception,
                op_type="REDUCE_SCATTER_BASE",
                rank=self.rank,
                tensor=input,
                verify_diff=False,
            )

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @with_dist_debug_levels(levels=["DETAIL"])
        def test_coalescing_manager_debug_mode_detail(self):
            """
            Tests that the coalescing manager with TORCH_DISTRIBUTED_DEBUG=DETAIL
            does not crash: https://github.com/pytorch/pytorch/issues/109520
            """
            torch.cuda.set_device(self.rank)
            pg = self._create_wrapper_pg(with_new_group=True)
            dev = torch.cuda.current_device()
            pg._start_coalescing(torch.device(dev))
            pg.allreduce([torch.ones(1, device=dev)])
            pg._end_coalescing(torch.device(dev))

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @with_dist_debug_levels(levels=["DETAIL"])
        @patch("torch.distributed.distributed_c10d._GLOO_AVAILABLE", False)
        def test_debug_level_detail_no_gloo(self):
            with self.assertRaisesRegex(
                AssertionError, "ProcessGroupWrapper unsupported without GLOO backend"
            ):
                self._create_wrapper_pg()

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @patch("torch.distributed.distributed_c10d._GLOO_AVAILABLE", False)
        def test_new_group_no_gloo(self):
            def patched_isinstance(obj, clazz):
                if clazz is _ProcessGroupWrapper:
                    raise NameError
                else:
                    return isinstance(obj, clazz)

            with patch(
                "torch.distributed.distributed_c10d.isinstance",
                side_effect=patched_isinstance,
            ):
                self._create_wrapper_pg(with_new_group=True)
                # Nothing to assert: isinstance(pg, _ProcessGroupWrapper) should
                # never be invoked because it is preceded by the _GLOO_AVAILABLE
                # check; this test fails with an unexpected NameError if it is.


@requires_gloo()
class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest):
    def opts(self, threads=2, timeout=10.0):
        opts = c10d.ProcessGroupGloo._Options()
        opts._timeout = timeout
        opts._devices = [create_device(interface=LOOPBACK)]
        opts._threads = threads
        return opts

    def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="gloo", rank=self.rank, world_size=self.world_size, store=store
        )
        if with_new_group:
            pg = c10d.new_group(backend="gloo")
        else:
            _pg = c10d.ProcessGroupGloo(
                store, self.rank, self.world_size, self.opts(timeout=timeout)
            )
            pg = c10d._create_process_group_wrapper(
                _pg,
                "unused",
                store,
                self.rank,
                self.world_size,
                timeout=timeout,
            )
        return pg

    def test_collective_hang(self):
        pg = self._create_wrapper_pg(timeout=2.0)
        self._test_collective_hang(pg)

    # NOTE: these tests are separated by debug level instead of combined into
    # one due to https://github.com/pytorch/pytorch/issues/55967, they can be
    # combined after that is resolved.
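    # The CPU variants come first; the *_cuda variants further below repeat the
    # same checks with GPU tensors and are gated on @skip_if_lt_x_gpu(4).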
    @with_dist_debug_levels(levels=["DETAIL"])
    def test_collectives_op_mismatch_debug_mode(self):
        pg = self._create_wrapper_pg(with_new_group=True)
        self._test_collectives_op_mismatch(pg)

    @with_dist_debug_levels(levels=["OFF"])
    def test_collectives_op_mismatch(self):
        pg = self._create_wrapper_pg(with_new_group=False)
        self._test_collectives_op_mismatch(pg)

    @with_dist_debug_levels(levels=["DETAIL"])
    def test_collective_shape_mismatch_debug_mode(self):
        pg = self._create_wrapper_pg(with_new_group=True)
        self._test_collective_shape_mismatch(pg)

    @with_dist_debug_levels(levels=["OFF"])
    def test_collective_shape_mismatch_debug_mode_off(self):
        pg = self._create_wrapper_pg(with_new_group=False)
        self._test_collective_shape_mismatch(pg)

    @skip_if_lt_x_gpu(4)
    @with_dist_debug_levels(levels=["DETAIL"])
    def test_collectives_op_mismatch_cuda_debug_mode(self):
        pg = self._create_wrapper_pg(with_new_group=True)
        self._test_collectives_op_mismatch(pg, use_cuda=True)

    @skip_if_lt_x_gpu(4)
    @with_dist_debug_levels(levels=["OFF"])
    def test_collectives_op_mismatch_cuda(self):
        pg = self._create_wrapper_pg(with_new_group=False)
        self._test_collectives_op_mismatch(pg, use_cuda=True)

    @skip_if_lt_x_gpu(4)
    @with_dist_debug_levels(levels=["DETAIL"])
    def test_collective_shape_mismatch_cuda_debug_mode(self):
        pg = self._create_wrapper_pg(with_new_group=True)
        self._test_collective_shape_mismatch(pg, use_cuda=True)

    @skip_if_lt_x_gpu(4)
    @with_dist_debug_levels(levels=["OFF"])
    def test_collective_shape_mismatch_cuda(self):
        pg = self._create_wrapper_pg(with_new_group=False)
        self._test_collective_shape_mismatch(pg, use_cuda=True)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_pg_wrapper must not have initialized CUDA context on main process"

    run_tests()