1# Owner(s): ["oncall: distributed"] 2 3import copy 4import logging 5import math 6import operator 7import os 8import random 9import sys 10import tempfile 11from functools import reduce 12 13import torch 14import torch.distributed as c10d 15 16 17if not c10d.is_available() or not c10d.is_ucc_available(): 18 print("c10d UCC not available, skipping tests", file=sys.stderr) 19 sys.exit(0) 20 21import test_c10d_common 22from test_c10d_common import ( 23 gpus_for_rank, 24 ModuleForDdpCommHook, 25 SparseGradientModule, 26 Task, 27) 28 29import torch.distributed as dist 30import torch.nn.functional as F 31import torch.testing._internal.common_utils as common 32from torch import nn 33from torch.nn.parallel import DistributedDataParallel 34from torch.testing._internal.common_distributed import ( 35 MultiProcessTestCase, 36 requires_ucc, 37 skip_if_lt_x_gpu, 38 verify_ddp_error_logged, 39) 40from torch.testing._internal.common_utils import ( 41 retry_on_connect_failures, 42 run_tests, 43 skip_but_pass_in_sandcastle, 44 TestCase, 45) 46 47 48def simple_reduce_tests(rank, world_size): 49 tests = [ 50 ( 51 c10d.ReduceOp.SUM, 52 torch.tensor([rank + 1.0]), 53 torch.tensor([float(world_size * (world_size + 1) / 2)]), 54 ), 55 ( 56 c10d.ReduceOp.PRODUCT, 57 torch.tensor([rank + 1.0]), 58 torch.tensor([float(math.factorial(world_size))]), 59 ), 60 ( 61 c10d.ReduceOp.MIN, 62 torch.tensor([rank + 1.0]), 63 torch.tensor([1.0]), 64 ), 65 ( 66 c10d.ReduceOp.MAX, 67 torch.tensor([rank + 1.0]), 68 torch.tensor([world_size]), 69 ), 70 ] 71 72 # Generate tests for BAND. 73 # The bit that is set changes in every iteration to check 74 # that the output changes accordingly. 75 for i in range(4): 76 vin = rank | (1 << i) 77 vout = 1 << i 78 tests.append( 79 ( 80 c10d.ReduceOp.BAND, 81 torch.tensor([vin], dtype=torch.int32), 82 torch.tensor([vout], dtype=torch.int32), 83 ), 84 ) 85 86 # Generate tests for BOR. 87 # These emulate a larger world size per iteration by having every 88 # rank contribute multiple values that are pre-OR'ed. 89 for i in range(1, 5): 90 vin = reduce(operator.or_, [rank * i + j for j in range(i)]) 91 vout = reduce(operator.or_, range(world_size * i)) 92 tests.append( 93 ( 94 c10d.ReduceOp.BOR, 95 torch.tensor([vin], dtype=torch.int32), 96 torch.tensor([vout], dtype=torch.int32), 97 ), 98 ) 99 100 # Generate tests for XOR. 101 # These emulate a larger world size per iteration by having every 102 # rank contribute multiple values that are pre-XOR'ed. 103 for i in range(1, 5): 104 vin = reduce(operator.xor, [rank * i + j for j in range(i)]) 105 vout = reduce(operator.xor, range(world_size * i)) 106 tests.append( 107 ( 108 c10d.ReduceOp.BXOR, 109 torch.tensor([vin], dtype=torch.int32), 110 torch.tensor([vout], dtype=torch.int32), 111 ), 112 ) 113 114 return tests 115 116 117class RendezvousEnvTest(TestCase): 118 @requires_ucc() 119 @retry_on_connect_failures 120 def test_logging_init(self): 121 os.environ["WORLD_SIZE"] = "1" 122 os.environ["MASTER_ADDR"] = "127.0.0.1" 123 os.environ["MASTER_PORT"] = str(common.find_free_port()) 124 os.environ["RANK"] = "0" 125 126 previous_handlers = logging.root.handlers 127 128 c10d.init_process_group(backend="ucc", init_method="env://") 129 130 current_handlers = logging.root.handlers 131 self.assertEqual(len(previous_handlers), len(current_handlers)) 132 for current, previous in zip(current_handlers, previous_handlers): 133 self.assertEqual(current, previous) 134 135 c10d.destroy_process_group() 136 137 138class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase): 139 @requires_ucc() 140 @retry_on_connect_failures 141 def test_default_store_timeout_ucc(self): 142 self._test_default_store_timeout("ucc") 143 144 145class ProcessGroupUCCTest(MultiProcessTestCase): 146 def _create_process_group_ucc(self): 147 store = c10d.FileStore(self.file_name, self.world_size) 148 return c10d.ProcessGroupUCC(store, self.rank, self.world_size) 149 150 def setUp(self): 151 super().setUp() 152 self._spawn_processes() 153 154 def tearDown(self): 155 super().tearDown() 156 try: 157 os.remove(self.file_name) 158 except OSError: 159 pass 160 161 @requires_ucc() 162 def test_empty_tensors(self): 163 pg = self._create_process_group_ucc() 164 165 xs = [torch.FloatTensor([])] 166 fut = pg.broadcast(xs).get_future() 167 fut.wait() 168 output = fut.value() 169 self.assertEqual(0, output[0].numel()) 170 self.assertEqual(xs[0], output[0], exact_dtype=False) 171 172 # TODO: add error check testing 173 174 def _test_broadcast_basics(self, fn): 175 pg = self._create_process_group_ucc() 176 177 def broadcast(xs, rootRank, rootTensor): 178 opts = c10d.BroadcastOptions() 179 opts.rootRank = rootRank 180 opts.rootTensor = rootTensor 181 fut = pg.broadcast(xs, opts).get_future() 182 fut.wait() 183 return fut.value() 184 185 # Every rank is root once 186 for i in range(self.world_size): 187 # Run with 1 input tensor 188 x = fn(torch.tensor([self.rank])) 189 output = broadcast([x], i, 0) 190 self.assertEqual(torch.tensor([i]), output[0], exact_dtype=False) 191 192 # TODO: UCC currently does not support multi tensor input 193 194 # Test overloaded convenience function 195 x = torch.tensor([self.rank + 1.0]) 196 fut = pg.broadcast(x, root=0).get_future() 197 fut.wait() 198 result = fut.value() 199 self.assertEqual(torch.tensor([1.0]), result[0]) 200 201 @requires_ucc() 202 def test_broadcast_basics(self): 203 self._test_broadcast_basics(lambda t: t.clone()) 204 205 # TODO: test_broadcast_basics_cuda times out locally 206 207 def _test_allreduce_basics(self, fn): 208 pg = self._create_process_group_ucc() 209 210 # Single input tests 211 tests = simple_reduce_tests(self.rank, self.world_size) 212 for op, input, expected in tests: 213 opts = c10d.AllreduceOptions() 214 opts.reduceOp = op 215 tensor = fn(input) 216 fut = pg.allreduce([tensor], opts).get_future() 217 fut.wait() 218 result = fut.value() 219 self.assertEqual(expected, result[0], exact_dtype=False) 220 221 # TODO: UCC currently does not support multi tensor input 222 223 # Test overloaded convenience function (defaults to using sum) 224 x = fn(torch.tensor([self.rank + 1.0])) 225 fut = pg.allreduce(x).get_future() 226 fut.wait() 227 result = fut.value() 228 self.assertEqual( 229 torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]), 230 result[0], 231 ) 232 233 @requires_ucc() 234 def test_allreduce_basics(self): 235 self._test_allreduce_basics(lambda t: t.clone()) 236 237 # TODO: test_allreduce_basics_cuda times out locally 238 239 def _test_allgather_basics(self, fn): 240 pg = self._create_process_group_ucc() 241 242 # TODO: Run with N input tensor per rank; for now, UCC only supports single tensor input so N=1 243 for n in [1]: 244 input = [fn(torch.tensor([n * self.rank + i])) for i in range(n)] 245 output = [ 246 [fn(torch.tensor([-1])) for _ in range(n * self.world_size)] 247 for _ in range(n) 248 ] 249 expected_output = [ 250 [fn(torch.tensor([i])) for i in range(n * self.world_size)] 251 for _ in range(n) 252 ] 253 fut = pg.allgather(output, input).get_future() 254 fut.wait() 255 result = fut.value() 256 if n == 1: 257 result = [result] 258 self.assertEqual(expected_output, result) 259 260 def test_allgather_basics(self): 261 self._test_allgather_basics(lambda t: t.clone()) 262 263 def _test_reduce_basics(self, fn): 264 pg = self._create_process_group_ucc() 265 for op, input, output in simple_reduce_tests(self.rank, self.world_size): 266 for root in range(self.world_size): 267 opts = c10d.ReduceOptions() 268 opts.reduceOp = op 269 opts.rootRank = root 270 tmp = fn(input) 271 fut = pg.reduce([tmp], opts).get_future() 272 fut.wait() 273 result = fut.value() 274 if root == self.rank: 275 self.assertEqual(output, result[0], exact_dtype=False) 276 277 @requires_ucc() 278 def test_reduce_basics(self): 279 self._test_reduce_basics(lambda t: t.clone()) 280 281 # TODO: test_reduce_basics_cuda times out locally 282 283 @requires_ucc() 284 def test_send_recv_all_to_all(self): 285 pg = self._create_process_group_ucc() 286 287 # Preallocate tensors for input/output 288 inputs = [torch.tensor([self.rank]) for _ in range(self.world_size)] 289 outputs = [torch.tensor([-1]) for _ in range(self.world_size)] 290 291 # Issue sends 292 send_work = [] 293 for i in range(self.world_size): 294 if i == self.rank: 295 continue 296 send_work.append(pg.send([inputs[i]], i, 0)) 297 298 # Issue recvs 299 recv_work = [] 300 for i in range(self.world_size): 301 if i == self.rank: 302 continue 303 recv_work.append(pg.recv([outputs[i]], i, 0)) 304 305 # Wait for sends to complete 306 for work in send_work: 307 work.wait() 308 self.assertTrue(work.is_completed()) 309 310 # Wait for recvs to complete 311 for work in recv_work: 312 work.wait() 313 self.assertTrue(work.is_completed()) 314 315 # Test that every output other than our own contains the respective rank 316 for i in range(self.world_size): 317 if i == self.rank: 318 continue 319 self.assertEqual(torch.tensor([i]), outputs[i]) 320 321 # TODO: test_barrier_implies_wait fails with numerical mismatch, will investigate later 322 @skip_but_pass_in_sandcastle("fails with numerical mismatch, skip for now") 323 @requires_ucc() 324 def test_barrier_implies_wait(self): 325 pg = self._create_process_group_ucc() 326 327 # Kick off allreduce operations 328 size = (100, 100) 329 num = 16 330 tensors = [torch.full(size, float(i)) for i in range(num)] 331 for tensor in tensors: 332 # Note: leak the returned work handle 333 pg.allreduce(tensor) 334 335 # Barrier should ensure all previous work has completed 336 pg.barrier().get_future().wait() 337 338 for i, tensor in enumerate(tensors): 339 self.assertEqual(torch.full(size, float(i * self.world_size)), tensor) 340 341 342class DistributedDataParallelTest( 343 test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase 344): 345 def setUp(self): 346 super().setUp() 347 self._spawn_processes() 348 349 def _get_process_group(self): 350 store = self._get_store() 351 c10d.init_process_group( 352 "ucc", store=store, rank=self.rank, world_size=self.world_size 353 ) 354 return c10d.distributed_c10d._get_default_group() 355 356 def _test_ucc_backend( 357 self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False 358 ): 359 process_group = self._get_process_group() 360 self._test_ddp_with_process_group( 361 process_group, devices, device_ids, multi_device, gradient_as_bucket_view 362 ) 363 364 @requires_ucc() 365 def test_ucc_backend_cpu_module(self): 366 self._test_ucc_backend([torch.device("cpu")], None) 367 368 @requires_ucc() 369 def test_ucc_backend_cpu_module_grad_is_view(self): 370 self._test_ucc_backend( 371 [torch.device("cpu")], None, gradient_as_bucket_view=True 372 ) 373 374 @requires_ucc() 375 @skip_if_lt_x_gpu(2) 376 def test_ucc_backend_1gpu_module_device_ids_integer_list(self): 377 int_devices = gpus_for_rank(self.world_size)[self.rank][:1] 378 devices = [torch.device("cuda:" + str(i)) for i in int_devices] 379 self._test_ucc_backend(devices, int_devices) 380 381 @requires_ucc() 382 @skip_if_lt_x_gpu(2) 383 def test_ucc_backend_1gpu_module_device_ids_torch_device_list(self): 384 int_devices = gpus_for_rank(self.world_size)[self.rank][:1] 385 devices = [torch.device("cuda:" + str(i)) for i in int_devices] 386 self._test_ucc_backend(devices, devices) 387 388 # TODO: test_ucc_backend_2gpu_module and test_ucc_backend_4gpu_module 389 # require broadcast_coalesced which is not supported by ucc currently 390 @skip_but_pass_in_sandcastle( 391 "requires broadcast coalesced, which is not supported by ucc currently" 392 ) 393 @requires_ucc() 394 @skip_if_lt_x_gpu(4) 395 def test_ucc_backend_2gpu_module(self): 396 int_devices = gpus_for_rank(self.world_size)[self.rank][:2] 397 devices = [torch.device("cuda:" + str(i)) for i in int_devices] 398 self._test_ucc_backend(devices, None, multi_device=True) 399 400 @skip_but_pass_in_sandcastle( 401 "requires broadcast coalesced, which is not supported by ucc currently" 402 ) 403 @requires_ucc() 404 @skip_if_lt_x_gpu(8) 405 def test_ucc_backend_4gpu_module(self): 406 int_devices = gpus_for_rank(self.world_size)[self.rank][:4] 407 devices = [torch.device("cuda:" + str(i)) for i in int_devices] 408 self._test_ucc_backend(devices, None, multi_device=True) 409 410 def _test_global_local_unused_params_grad( 411 self, gradient_as_bucket_view=False, static_graph=False 412 ): 413 """ 414 By simulating a multi-task training, this test is to make sure: 415 1) DDP does not touch the grad of globally unused parameters. 416 2) DDP does update the grad of locally unused parameters. 417 """ 418 419 class GlobalLocalUnusedParamModule(nn.Module): 420 def __init__(self) -> None: 421 super().__init__() 422 self.t0 = Task() 423 self.t1 = Task() 424 self.task_unused = Task() 425 426 def task_parameters(self): 427 return (self.t0.p, self.t1.p, self.task_unused.p) 428 429 def forward(self, x, rank): 430 return self.t0(x) if rank == 0 else self.t1(x) 431 432 def run_and_verify_grad(model): 433 # Run forward 434 output = model(8, self.rank) 435 436 # The grads of all parameters should be None at this point. 437 t0_p, t1_p, task_unused_p = model.module.task_parameters() 438 self.assertIsNone(t0_p.grad) 439 self.assertIsNone(t1_p.grad) 440 self.assertIsNone(task_unused_p.grad) 441 442 # Run backward 443 output.mean().backward() 444 445 # Now locally unused parameter should have grad updated on all ranks. 446 # However the globally unused parameter should still have None grad. 447 self.assertIsNotNone(t0_p.grad) 448 self.assertIsNotNone(t1_p.grad) 449 self.assertIsNone(task_unused_p.grad) 450 451 process_group = self._get_process_group() 452 453 # Test on CPU 454 cpu_model = DistributedDataParallel( 455 GlobalLocalUnusedParamModule().cpu(), 456 process_group=process_group, 457 find_unused_parameters=True, 458 gradient_as_bucket_view=gradient_as_bucket_view, 459 static_graph=static_graph, 460 ) 461 run_and_verify_grad(cpu_model) 462 463 # Test on GPU 464 device_id = gpus_for_rank(self.world_size)[self.rank][0] 465 gpu_model = DistributedDataParallel( 466 GlobalLocalUnusedParamModule().to(device_id), 467 device_ids=[device_id], 468 process_group=process_group, 469 find_unused_parameters=True, 470 gradient_as_bucket_view=gradient_as_bucket_view, 471 static_graph=static_graph, 472 ) 473 run_and_verify_grad(gpu_model) 474 475 # TODO: times out 476 @skip_but_pass_in_sandcastle("times out") 477 @requires_ucc() 478 @skip_if_lt_x_gpu(2) 479 def test_global_local_unused_params_grad(self): 480 self._test_global_local_unused_params_grad() 481 482 # TODO: times out 483 @skip_but_pass_in_sandcastle("times out") 484 @requires_ucc() 485 @skip_if_lt_x_gpu(2) 486 def test_global_local_unused_params_grad_with_grad_is_view(self): 487 self._test_global_local_unused_params_grad(gradient_as_bucket_view=True) 488 489 # TODO: times out 490 @skip_but_pass_in_sandcastle("times out") 491 @requires_ucc() 492 @skip_if_lt_x_gpu(2) 493 def test_global_local_unused_params_grad_with_static_graph(self): 494 self._test_global_local_unused_params_grad(static_graph=True) 495 496 # TODO: times out 497 @skip_but_pass_in_sandcastle("times out") 498 @requires_ucc() 499 @skip_if_lt_x_gpu(2) 500 def test_find_unused_parameters_when_unused_parameters_empty(self): 501 """ 502 An empty unused_parameters array does not imply find_unused_parameters = 503 false. This test makes sure that DDP allreduces unused parameters 504 accordingly where the forward pass in some process uses all parameters. 505 This unit test creates a module that uses all parameters in rank = 0, and 506 has unused parameters in other ranks. 507 """ 508 509 class FindUnusedParamModule(nn.Module): 510 def __init__(self) -> None: 511 super().__init__() 512 self.t0 = Task() 513 self.t1 = Task() 514 515 def task_parameters(self): 516 return (self.t0.p, self.t1.p) 517 518 def forward(self, x, rank): 519 return self.t1(self.t0(x)) if rank == 0 else self.t1(x) 520 521 def run_and_verify_grad(model): 522 # Run forward 523 output = model(8, self.rank) 524 525 # The grads of all parameters should be None at this point. 526 [self.assertIsNone(t_p.grad) for t_p in model.module.task_parameters()] 527 528 # Run backward 529 output.mean().backward() 530 531 # Now locally unused parameter should have grad updated on all ranks. 532 [self.assertIsNotNone(t_p.grad) for t_p in model.module.task_parameters()] 533 534 process_group = self._get_process_group() 535 536 # Test on CPU 537 cpu_model = DistributedDataParallel( 538 FindUnusedParamModule().cpu(), 539 process_group=process_group, 540 find_unused_parameters=True, 541 ) 542 run_and_verify_grad(cpu_model) 543 544 # Test on GPU 545 device_id = gpus_for_rank(self.world_size)[self.rank][0] 546 gpu_model = DistributedDataParallel( 547 FindUnusedParamModule().to(device_id), 548 device_ids=[device_id], 549 process_group=process_group, 550 find_unused_parameters=True, 551 ) 552 run_and_verify_grad(gpu_model) 553 554 @requires_ucc() 555 def test_ignored_output(self): 556 """ 557 Test that the output of a model can be ignored and that there is no 558 implicit requirement that `backward` gets called. 559 """ 560 process_group = self._get_process_group() 561 562 class IgnoredOutput(nn.Module): 563 def __init__(self) -> None: 564 super().__init__() 565 self.fc1 = nn.Linear(2, 10, bias=False) 566 self.fc2 = nn.Linear(10, 4, bias=False) 567 self.relu = nn.ReLU() 568 569 def forward(self, x): 570 x = self.relu(self.fc1(x)) 571 x = self.relu(self.fc2(x)) 572 return F.softmax(x, dim=1) 573 574 model = DistributedDataParallel( 575 IgnoredOutput().float(), 576 process_group=process_group, 577 ) 578 579 batch_size = 4 580 criterion = nn.CrossEntropyLoss() 581 input = torch.rand([batch_size, 2], dtype=torch.float) 582 target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]) 583 584 # Run a few iterations where we ignore the output. 585 for _ in range(4): 586 output = model(input) 587 del output 588 589 # Run a few iterations where we use the output. 590 for _ in range(4): 591 output = model(input) 592 loss = criterion(output, target) 593 loss.backward() 594 595 @requires_ucc() 596 def test_ignored_output_with_unused_parameters(self): 597 """ 598 Test that the output of a model can be ignored and that there is no 599 implicit requirement that `backward` gets called, if not all model 600 parameters participated in computing the model output. 601 """ 602 process_group = self._get_process_group() 603 604 class IgnoredOutputWithUnusedParameters(nn.Module): 605 def __init__(self) -> None: 606 super().__init__() 607 self.fc1 = nn.Linear(2, 10, bias=False) 608 self.fc2 = nn.Linear(10, 4, bias=False) 609 self.fc3 = nn.Linear(4, 4, bias=False) 610 self.relu = nn.ReLU() 611 612 def forward(self, x): 613 x = self.relu(self.fc1(x)) 614 x = self.relu(self.fc2(x)) 615 return F.softmax(x, dim=1) 616 617 model = DistributedDataParallel( 618 IgnoredOutputWithUnusedParameters().float(), 619 process_group=process_group, 620 find_unused_parameters=True, 621 ) 622 623 batch_size = 4 624 criterion = nn.CrossEntropyLoss() 625 input = torch.rand([batch_size, 2], dtype=torch.float) 626 target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]) 627 628 # Run a few iterations where we ignore the output. 629 for _ in range(4): 630 output = model(input) 631 del output 632 633 # Run a few iterations where we use the output. 634 for _ in range(4): 635 output = model(input) 636 loss = criterion(output, target) 637 loss.backward() 638 639 def _run_and_verify_sparse_gradients(self, vanilla_model, ddp_model): 640 mult = 2 641 batch_size = mult * self.world_size 642 criterion = nn.CrossEntropyLoss() 643 input = torch.randint(0, 10, [batch_size, 2]) 644 target = torch.randint(0, 10, [batch_size]) 645 646 # Run with entire batch against single process version 647 criterion(vanilla_model(input), target).backward() 648 649 # Run with partial batch against multi process version 650 partial_input = input.split(mult)[self.rank] 651 partial_target = target.split(mult)[self.rank] 652 criterion(ddp_model(partial_input), partial_target).backward() 653 654 # Check that the gradients are sparse and identical 655 vanilla_parameter = next(vanilla_model.parameters()) 656 ddp_parameter = next(ddp_model.parameters()) 657 self.assertEqual( 658 vanilla_parameter.grad.coalesce(), ddp_parameter.grad.coalesce() 659 ) 660 661 @requires_ucc() 662 @skip_if_lt_x_gpu(2) 663 def test_save_load_checkpoint(self): 664 dist.init_process_group( 665 "ucc", 666 init_method=f"file://{self.file_name}", 667 world_size=self.world_size, 668 rank=self.rank, 669 ) 670 671 class TestModel(nn.Module): 672 def __init__(self) -> None: 673 super().__init__() 674 self.fc1 = nn.Linear(2, 10, bias=False) 675 self.fc2 = nn.Linear(10, 4, bias=False) 676 self.relu = nn.ReLU() 677 678 def forward(self, x): 679 x = self.relu(self.fc1(x)) 680 x = self.relu(self.fc2(x)) 681 return F.softmax(x, dim=1) 682 683 def train_loop(model, optimizer, iterations): 684 for _ in range(iterations): 685 optimizer.zero_grad() 686 output = model(input) 687 loss = criterion(output, target) 688 loss.backward() 689 optimizer.step() 690 691 device_id = gpus_for_rank(self.world_size)[self.rank][0] 692 693 model_withload = TestModel().float().to(device_id) 694 model_withoutload = TestModel().float().to(device_id) 695 696 ddp_withload = DistributedDataParallel( 697 model_withload, 698 device_ids=[device_id], 699 ) 700 ddp_withoutload = DistributedDataParallel( 701 model_withoutload, 702 device_ids=[device_id], 703 ) 704 705 # ensure that all the three models start with the same set of parameters. By default they are randomized on construction 706 for p in ddp_withload.parameters(): 707 with torch.no_grad(): 708 p.zero_() 709 for p in model_withload.parameters(): 710 with torch.no_grad(): 711 p.zero_() 712 for p in ddp_withoutload.parameters(): 713 with torch.no_grad(): 714 p.zero_() 715 716 batch_size = 4 717 criterion = nn.CrossEntropyLoss() 718 719 optimizer_withload = torch.optim.SGD(ddp_withload.parameters(), lr=0.001) 720 optimizer_non_ddp_withload = torch.optim.SGD( 721 model_withload.parameters(), lr=0.001 722 ) 723 optimizer_withoutload = torch.optim.SGD(ddp_withoutload.parameters(), lr=0.001) 724 725 input = torch.rand([batch_size, 2], dtype=torch.float).to(device_id) 726 target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to( 727 device_id 728 ) 729 730 # run the model for 6 iterations, with a checkpoint in the middle 731 train_loop(ddp_withload, optimizer_withload, 3) 732 733 # zero out parameters of both DDP and non-DDP models and reload them from the DDP state dict 734 checkpoint_path = tempfile.gettempdir() + "/model.checkpoint" 735 if self.rank == 0: 736 torch.save(ddp_withload.state_dict(), checkpoint_path) 737 738 dist.barrier() 739 map_location = {"cuda:%d" % 0: "cuda:%d" % self.rank} 740 ddp_state_dict = torch.load(checkpoint_path, map_location=map_location) 741 742 for model in [ddp_withload, model_withload]: 743 for p in ddp_withload.parameters(): 744 with torch.no_grad(): 745 p.zero_() 746 ddp_withload.load_state_dict(ddp_state_dict) 747 # the non-DDP model needs to first remove the prefix of "module." from the DDP state dict 748 torch.nn.modules.utils.consume_prefix_in_state_dict_if_present( 749 ddp_state_dict, "module." 750 ) 751 model_withload.load_state_dict(ddp_state_dict) 752 753 train_loop(ddp_withload, optimizer_withload, 3) 754 train_loop(model_withload, optimizer_non_ddp_withload, 3) 755 756 # re-run the model with the same inputs for 6 iterations with no checkpoint 757 train_loop(ddp_withoutload, optimizer_withoutload, 6) 758 759 for p_withload, p_withoutload, p_non_ddp_withload in zip( 760 ddp_withload.parameters(), 761 ddp_withoutload.parameters(), 762 model_withload.parameters(), 763 ): 764 self.assertEqual(p_withload, p_withoutload) 765 self.assertEqual(p_non_ddp_withload, p_withoutload) 766 767 def _test_sparse_gradients(self, gradient_as_bucket_view=False): 768 process_group = self._get_process_group() 769 770 # Ensure initialized weights and inputs are identical across processes 771 torch.manual_seed(1337) 772 773 vanilla_model = SparseGradientModule() 774 ddp_model = DistributedDataParallel( 775 copy.deepcopy(vanilla_model), 776 process_group=process_group, 777 gradient_as_bucket_view=gradient_as_bucket_view, 778 ) 779 780 self._run_and_verify_sparse_gradients(vanilla_model, ddp_model) 781 782 # TODO: backward pass: input tensor has to be dense 783 @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense") 784 @requires_ucc() 785 def test_sparse_gradients(self): 786 self._test_sparse_gradients() 787 788 # TODO: backward pass: input tensor has to be dense 789 @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense") 790 @requires_ucc() 791 def test_sparse_gradients_grad_is_view(self): 792 self._test_sparse_gradients(gradient_as_bucket_view=True) 793 794 @requires_ucc() 795 def test_ddp_comm_hook_future_passing_cpu(self): 796 """ 797 This unit test verifies whether the Future object is passed properly. 798 The callback function creates a Future object and sets a value to it. 799 """ 800 process_group = self._get_process_group() 801 802 # Test on CPU 803 cpu_model = DistributedDataParallel( 804 ModuleForDdpCommHook().cpu(), process_group=process_group 805 ) 806 807 # Register DDP Communication Hook 808 cpu_model.register_comm_hook(None, self._simple_hook) 809 810 # check whether the grads are equal to what then callback returns. 811 # without the comm_hook, result would be 0.25 * torch.ones(2, 2). 812 self._run_and_verify_hook(cpu_model, 8, 2 * torch.ones(2, 2)) 813 814 def _gpu_model_with_ddp_comm_hook( 815 self, process_group, hook=None, gradient_as_bucket_view=False, state=None 816 ): 817 device_id = gpus_for_rank(self.world_size)[self.rank][0] 818 gpu_model = DistributedDataParallel( 819 ModuleForDdpCommHook().to(device_id), 820 device_ids=[device_id], 821 process_group=process_group, 822 gradient_as_bucket_view=gradient_as_bucket_view, 823 ) 824 825 # Register a DDP communication hook if any. 826 if hook is not None: 827 gpu_model.register_comm_hook(state, hook) 828 829 return gpu_model 830 831 @requires_ucc() 832 @skip_if_lt_x_gpu(2) 833 def test_ddp_comm_hook_future_passing_gpu_ucc(self): 834 """ 835 This unit test verifies whether the Future object is passed properly using ucc backend. 836 The hook callback function creates a Future object and sets a value to it. 837 """ 838 process_group = self._get_process_group() 839 840 # Get GPU model with simple_hook registered. 841 gpu_model = self._gpu_model_with_ddp_comm_hook(process_group, self._simple_hook) 842 843 # check whether the grads are equal to what simple_hook's then callback returns. 844 # without the comm_hook, result would be 0.25 * torch.ones(2, 2). 845 self._run_and_verify_hook(gpu_model, 8, 2 * torch.ones(2, 2)) 846 847 @requires_ucc() 848 def test_ddp_invalid_comm_hook_init(self): 849 """ 850 This unit test makes sure that register_comm_hook properly checks the format 851 of hook defined by user. The Python hook must be callable. This test also 852 checks whether bucket annotation checked properly if defined. 853 """ 854 process_group = self._get_process_group() 855 856 model = DistributedDataParallel( 857 ModuleForDdpCommHook(), process_group=process_group 858 ) 859 860 with self.assertRaisesRegex(TypeError, "Communication hook must be callable."): 861 model.register_comm_hook(state=None, hook=1) 862 863 with self.assertRaisesRegex( 864 ValueError, "bucket annotation should be dist.GradBucket." 865 ): 866 867 def comm_hook( 868 state: object, bucket: int 869 ) -> torch.futures.Future[torch.Tensor]: 870 return torch.futures.Future() 871 872 model.register_comm_hook(state=None, hook=comm_hook) 873 874 @requires_ucc() 875 def test_ddp_invalid_comm_hook_return_type(self): 876 """ 877 This test checks whether return annotation checked properly if defined. It also 878 checks whether an internal error is thrown if return type is incorrect and user 879 hasn't specified any return type annotation. 880 """ 881 process_group = self._get_process_group() 882 883 model = DistributedDataParallel( 884 ModuleForDdpCommHook(), process_group=process_group 885 ) 886 887 expected_err = ( 888 "Communication hook: return annotation should be torch.futures.Future" 889 ) 890 with self.assertRaisesRegex( 891 ValueError, 892 expected_err, 893 ): 894 895 def comm_hook(state: object, bucket: dist.GradBucket) -> int: 896 return torch.futures.Future() 897 898 model.register_comm_hook(state=None, hook=comm_hook) 899 900 verify_ddp_error_logged(model, expected_err) 901 902 with self.assertRaisesRegex( 903 RuntimeError, 904 "callback must return a torch.futures.Future object, but got", 905 ): 906 907 def comm_hook(state: object, bucket: dist.GradBucket): 908 return 1 909 910 model.register_comm_hook(state=None, hook=comm_hook) 911 912 # Run forward 913 output = model(8, self.rank) 914 915 # Run backward 916 output.mean().backward() 917 918 @requires_ucc() 919 def test_ddp_comm_hook_register_just_once(self): 920 """ 921 DDP communication hook can only be registered once. This test validates whether 922 the error is thrown properly when register_comm_hook is called more than once. 923 """ 924 process_group = self._get_process_group() 925 926 model = DistributedDataParallel( 927 ModuleForDdpCommHook(), process_group=process_group 928 ) 929 930 def dummy_hook(state, bucket): 931 fut = torch.futures.Future() 932 fut.set_result([bucket.buffer()]) 933 return fut 934 935 model.register_comm_hook(None, dummy_hook) 936 937 with self.assertRaisesRegex( 938 RuntimeError, 939 "register_comm_hook or register_builtin_comm_hook can only be called once.", 940 ): 941 model.register_comm_hook(None, dummy_hook) 942 943 # TODO: backward pass: input tensor must be dense 944 @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense") 945 @requires_ucc() 946 def test_ddp_comm_hook_sparse_gradients(self): 947 """ 948 Runs "test_sparse_gradients" unit test with DDP communication hook. We define a 949 simple hook that does allreduce and works with ucc backend for this test. 950 """ 951 process_group = self._get_process_group() 952 953 # Ensure initialized weights and inputs are identical across processes 954 torch.manual_seed(1337) 955 956 vanilla_model = SparseGradientModule() 957 ddp_model = DistributedDataParallel( 958 copy.deepcopy(vanilla_model), 959 process_group=process_group, 960 ) 961 962 def allreduce_hook_ucc( 963 state: object, bucket: dist.GradBucket 964 ) -> torch.futures.Future[torch.Tensor]: 965 def div_by_world_size(fut): 966 # Divide the result by 2 * world_size. 967 return fut.wait()[0] / self.world_size 968 969 # Prepare allreduced grad bucket tensors by running an async work. 970 fut = process_group.allreduce([bucket.buffer()]).get_future() 971 return fut.then(div_by_world_size) 972 973 ddp_model.register_comm_hook(None, allreduce_hook_ucc) 974 975 self._run_and_verify_sparse_gradients(vanilla_model, ddp_model) 976 977 978class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase): 979 @property 980 def device(self): 981 return "cpu" 982 983 def setUp(self): 984 super().setUp() 985 self._spawn_processes() 986 987 def tearDown(self): 988 super().tearDown() 989 try: 990 os.remove(self.file_name) 991 except OSError: 992 pass 993 994 @requires_ucc() 995 @skip_if_lt_x_gpu(2) 996 def test_sequence_num_set_default_pg_ucc(self): 997 self._test_sequence_num_set_default_pg(backend="ucc") 998 999 @requires_ucc() 1000 @skip_if_lt_x_gpu(2) 1001 def test_sequence_num_set_ucc_new_group(self): 1002 self._test_sequence_num_set_new_group(backend="ucc") 1003 1004 @skip_if_lt_x_gpu(2) 1005 @requires_ucc() 1006 def test_sequence_num_incremented_ucc_default(self): 1007 self._test_sequence_num_incremented_default_group("ucc") 1008 1009 @skip_if_lt_x_gpu(4) 1010 @requires_ucc() 1011 def test_sequence_num_incremented_ucc_subgroup(self): 1012 if self.world_size < 4: 1013 return skip_but_pass_in_sandcastle("Test requires world_size of at least 4") 1014 self._test_sequence_num_incremented_subgroup("ucc") 1015 1016 @skip_but_pass_in_sandcastle("Fails on M60") 1017 @requires_ucc() 1018 def test_ucc_barrier_device_ids(self): 1019 store = c10d.FileStore(self.file_name, self.world_size) 1020 c10d.init_process_group( 1021 backend="ucc", rank=self.rank, world_size=self.world_size, store=store 1022 ) 1023 1024 with self.assertRaisesRegex(RuntimeError, "device_ids not supported"): 1025 c10d.barrier(device_ids=[self.rank]) 1026 1027 @skip_but_pass_in_sandcastle("Fails on M60") 1028 @skip_if_lt_x_gpu(2) 1029 @requires_ucc() 1030 def test_ucc_warn_not_in_group(self): 1031 self._test_warn_not_in_group(backend="ucc") 1032 1033 @skip_if_lt_x_gpu(2) 1034 @requires_ucc() 1035 def test_ucc_rank_membership(self): 1036 self._test_rank_membership(backend="ucc") 1037 1038 @skip_if_lt_x_gpu(2) 1039 @requires_ucc() 1040 def test_tensor_dtype_mismatch(self): 1041 self._test_tensor_dtype_mismatch(backend="ucc") 1042 1043 @skip_if_lt_x_gpu(2) 1044 @requires_ucc() 1045 def test_tensor_dtype_complex(self): 1046 self._test_tensor_dtype_complex(backend="ucc") 1047 1048 1049class UccProcessGroupWithDispatchedCollectivesTests( 1050 test_c10d_common.ProcessGroupWithDispatchedCollectivesTests 1051): 1052 @skip_but_pass_in_sandcastle("Fails on M60") 1053 @requires_ucc() 1054 @skip_if_lt_x_gpu(1) 1055 def test_collectives(self): 1056 # includes reduce, broadcast, all_reduce, all_gather, reduce_scatter, barrier, all_to_all, scatter 1057 self._test_collectives(backend="ucc") 1058 1059 @skip_but_pass_in_sandcastle("Fails on M60") 1060 @requires_ucc() 1061 @skip_if_lt_x_gpu(1) 1062 def test_allgather_base(self): 1063 store = dist.FileStore(self.file_name, self.world_size) 1064 dist.init_process_group( 1065 "ucc", 1066 world_size=self.world_size, 1067 rank=self.rank, 1068 store=store, 1069 ) 1070 device = "cuda" 1071 tensor = torch.ones(10, 10, device=torch.device(device)) 1072 output_tensor = torch.zeros(10, 10, device=torch.device(device)) 1073 dist.all_gather_into_tensor(output_tensor, tensor) 1074 self.assertEqual(output_tensor, tensor) 1075 1076 1077if __name__ == "__main__": 1078 assert ( 1079 not torch.cuda._initialized 1080 ), "test_distributed must not have initialized CUDA context on main process" 1081 1082 run_tests() 1083