# mypy: allow-untyped-defs

import contextlib
import enum
import logging
import os
import threading
from typing import NamedTuple, Optional

import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.nn as nn
from torch.distributed import rpc
from torch.distributed.nn import RemoteModule
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
    requires_gloo,
    requires_nccl,
    skip_if_lt_x_gpu,
    skip_if_rocm_multiprocess,
)
from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)


NUM_EM_ROW = 2
D_SPARSE = 3
D_DENSE = 2
D_HID = 3
D_OUT = 1
NUM_TRAINERS = 4
# Trainers + the master + the remote worker
WORLD_SIZE = NUM_TRAINERS + 2
TRAINER_RANKS = list(range(NUM_TRAINERS))
REMOTE_WORKER_RANK = TRAINER_RANKS[-1] + 1
MASTER_RANK = REMOTE_WORKER_RANK + 1


class DdpMode(enum.Enum):
    # Don't apply DDP
    NONE = enum.auto()
    # Apply DDP to the top level nn.Module
    OUTSIDE = enum.auto()
    # Embed DDP inside the top level nn.Module
    INSIDE = enum.auto()


def init_logger():
    logger = logging.getLogger(__name__)
    level = logging.DEBUG if "debug" in os.environ else logging.INFO
    logger.setLevel(level)
    console = logging.StreamHandler()
    formatter = logging.Formatter(
        "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
    )
    console.setFormatter(formatter)
    console.setLevel(level)
    # add the handlers to the logger
    logger.addHandler(console)
    logger.propagate = False
    return logger


gLogger = init_logger()


class FeatureSet(NamedTuple):
    """A feature set has 2 types of features."""

    dense_features: torch.Tensor
    sparse_features: torch.LongTensor
    values: torch.Tensor


def _call_method(method, rref, *args, **kwargs):
    return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
    args_tup = tuple([method, rref] + list(args))
    return rpc.rpc_sync(rref.owner(), _call_method, args=args_tup, kwargs=kwargs)


def _remote_method_async(method, rref, *args, **kwargs):
    args_tup = tuple([method, rref] + list(args))
    return rpc.rpc_async(rref.owner(), _call_method, args=args_tup, kwargs=kwargs)


class RemoteEM(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        gLogger.info("Initializing RemoteEM with %s %s", num_embeddings, embedding_dim)
        super().__init__()
        init_em = [0.5] * embedding_dim
        self.em = nn.EmbeddingBag(
            num_embeddings,
            embedding_dim,
            _weight=torch.tensor([init_em] * num_embeddings),
        )

    def forward(self, input: torch.Tensor):
        gLogger.debug("Running RemoteEM.forward() on: %s", input)
        return self.em(input, offsets=torch.LongTensor(range(input.shape[0])))
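

# Illustrative sketch (not exercised by the tests below) of how the
# `_remote_method` helper above is meant to be used: it ships the unbound
# method plus the RRef to the RRef's owner, where `_call_method` unwraps the
# RRef with `local_value()` and invokes the method on the local instance.
# The worker name "worker1" and the input ids are placeholders for
# illustration only; RPC must already be initialized for this to run.
def _example_remote_em_forward():
    # Construct RemoteEM on the remote worker and get back an RRef to it.
    remote_em_rref = rpc.remote("worker1", RemoteEM, args=(NUM_EM_ROW, D_SPARSE))
    # Run RemoteEM.forward on the RRef's owner and block for the result.
    sparse_ids = torch.LongTensor([0, 1])
    return _remote_method(RemoteEM.forward, remote_em_rref, sparse_ids)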


# Return a linear module with predefined parameters.
def getLinear(d_in, d_out):
    linear = nn.Linear(d_in, d_out, bias=False)
    w = torch.ones((d_out, d_in))
    w[0][0] = -1
    w.requires_grad_()
    linear.weight.data = w
    return linear


class RemoteNet(nn.Module):
    def __init__(self, d_in: int, d_out: int):
        gLogger.info("Initializing RemoteNet with %s %s", d_in, d_out)
        super().__init__()
        self.fc = getLinear(d_in, d_out)
        self.relu = nn.ReLU()

    def forward(self, input: torch.Tensor):
        gLogger.debug("Running RemoteNet.forward() on: %s", input)
        return self.relu(self.fc(input))


class HybridModel(nn.Module):
    def __init__(
        self,
        remote_em_rref: rpc.RRef,
        remote_net_rref: rpc.RRef,
        process_group_for_ddp: Optional[dist.ProcessGroup] = None,
    ):
        super().__init__()
        self.remote_em_rref = remote_em_rref
        self.remote_net_rref = remote_net_rref
        self.fc1 = getLinear(D_DENSE, D_DENSE)
        self.fc2 = getLinear(D_HID, D_OUT)

        self.non_ddp_params = tuple(self.fc1.parameters()) + tuple(
            self.fc2.parameters()
        )
        self.ddp_params = ()

        if process_group_for_ddp is not None:
            self.non_ddp_params, self.ddp_params = (
                tuple(self.fc1.parameters()),
                tuple(self.fc2.parameters()),
            )
            gLogger.info("Use DDP for the second local net.")
            self.fc2 = DistributedDataParallel(
                self.fc2, check_reduction=True, process_group=process_group_for_ddp
            )

        gLogger.info(
            "HybridModel has %s groups of parameters.", len(list(self.parameters()))
        )

    def forward(self, input: FeatureSet):
        gLogger.debug("Running HybridModel.forward on %s", input)
        sparse = _remote_method(
            RemoteEM.forward, self.remote_em_rref, input.sparse_features
        )
        # The sparse and dense features must have the same mini-batch size.
        assert sparse.shape[0] == input.dense_features.shape[0]
        dense = self.fc1(input.dense_features)
        x = torch.cat((dense, sparse), 1)
        gLogger.debug("Concatenated feature: %s", x)
        x = _remote_method(RemoteNet.forward, self.remote_net_rref, x)
        return self.fc2(x)
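

# Distilled sketch of the distributed-autograd step that Trainer.train_batch
# below performs. Gradients computed under a dist_autograd context live in
# that context's gradient map, not in the parameters' `.grad` fields. The
# `model` and `batch` arguments are placeholders for illustration only.
def _example_dist_autograd_step(model: HybridModel, batch: FeatureSet):
    with dist_autograd.context() as context_id:
        loss = (model(batch) * batch.values).sum()
        # Kick off the distributed backward pass across the RPC boundaries.
        dist_autograd.backward(context_id, [loss])
        # Fetch the gradients accumulated in this context (param -> grad).
        return dist_autograd.get_gradients(context_id)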


class Trainer:
    def __init__(
        self,
        remote_em_rref: rpc.RRef,
        remote_net_rref: rpc.RRef,
        ddp_mode: DdpMode,
        rank: int,
    ):
        self.rank = rank
        self.trainer_group = (
            dist.new_group(TRAINER_RANKS)
            if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE)
            else None
        )
        self.remote_em_rref = remote_em_rref
        self.remote_net_rref = remote_net_rref
        self.hybrid_module = HybridModel(
            self.remote_em_rref,
            self.remote_net_rref,
            self.trainer_group if ddp_mode in (DdpMode.INSIDE,) else None,
        )
        self.ddp_params, self.non_ddp_params = (
            self.hybrid_module.ddp_params,
            self.hybrid_module.non_ddp_params,
        )
        if ddp_mode == DdpMode.OUTSIDE:
            gLogger.info("Wrapping the whole hybrid module into DDP.")
            self.ddp_params += self.non_ddp_params
            self.non_ddp_params = ()
            self.hybrid_module = DistributedDataParallel(
                self.hybrid_module,
                check_reduction=True,
                process_group=self.trainer_group,
            )
        gLogger.info(
            "Succeeded in creating a HybridModel instance with "
            "%s ddp params and %s other local params.",
            len(self.ddp_params),
            len(self.non_ddp_params),
        )

    def destroy_pg(self):
        if self.trainer_group:
            dist.destroy_process_group(self.trainer_group)

    def train_batch(
        self,
        mini_batch: FeatureSet,
        trainer_has_less_inputs: bool,
        simulate_uneven_inputs: bool,
    ):
        grads_dict = None

        if not simulate_uneven_inputs:
            input_batches = [mini_batch]
        else:
            # Split into microbatches, and trim to simulate uneven inputs.
            dense_features = mini_batch.dense_features
            sparse_features = mini_batch.sparse_features
            values = mini_batch.values

            dense_microbatch = torch.split(dense_features, 2)
            sparse_microbatch = torch.split(sparse_features, 2)
            values_microbatch = torch.split(values, 2)
            batches = []
            for d, s, v in zip(dense_microbatch, sparse_microbatch, values_microbatch):
                feature_set = FeatureSet(dense_features=d, sparse_features=s, values=v)
                batches.append(feature_set)

            if trainer_has_less_inputs:
                input_batches = batches[: len(batches) // 2]
                gLogger.info(
                    "Trainer reduced input batches from %s "
                    "to %s to simulate uneven inputs.",
                    len(batches),
                    len(input_batches),
                )
            else:
                input_batches = batches

        with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext():
            for b in input_batches:
                with dist_autograd.context() as context_id:
                    output = self.hybrid_module.forward(b)
                    loss = (output * mini_batch.values).sum()
                    dist_autograd.backward(context_id, [loss])
                    grads_dict = dist_autograd.get_gradients(context_id)
                    gLogger.info(
                        "Loss is %s for mini batch: %s. "
                        "Grads dict has %s entries: %s",
                        loss,
                        mini_batch,
                        len(grads_dict),
                        grads_dict,
                    )
        return (
            tuple(grads_dict[param] for param in self.ddp_params),
            tuple(grads_dict[param] for param in self.non_ddp_params),
        )


def get_training_examples():
    n = 16
    training_examples = FeatureSet(
        dense_features=torch.zeros((n, D_DENSE)),
        sparse_features=torch.zeros(n, dtype=torch.long),
        values=torch.zeros(n),
    )
    idx = 0
    # Every example has another one that has exactly the same features but an
    # opposite value. Therefore, their grads cancel each other in all-reduce.
    for value in (-1, 1):
        for x in (-1.0 * value, 1.0 * value):
            for y in (1.0 * value, -1.0 * value):
                for z in (0, 1):
                    training_examples.dense_features[idx, :] = torch.tensor((x, y))
                    training_examples.sparse_features[idx] = z
                    training_examples.values[idx] = value
                    idx += 1

    # Split the examples among NUM_TRAINERS trainers
    assert 0 == (n % NUM_TRAINERS)
    examples_per_trainer = int(n / NUM_TRAINERS)
    return [
        FeatureSet(
            dense_features=training_examples.dense_features[
                start : start + examples_per_trainer, :
            ],
            sparse_features=training_examples.sparse_features[
                start : start + examples_per_trainer
            ],
            values=training_examples.values[start : start + examples_per_trainer],
        )
        for start in range(0, n, examples_per_trainer)
    ]


shutdown_signal = threading.Condition()


def set_shutdown_signal():
    global shutdown_signal
    with shutdown_signal:
        shutdown_signal.notify()


class DdpUnderDistAutogradTest(RpcAgentTestFixture):
    @property
    def world_size(self) -> int:
        return WORLD_SIZE

    def remote_worker_name(self) -> str:
        # The name has to be consistent with that in 'dist_init' decorator.
        return f"worker{REMOTE_WORKER_RANK}"

    def trainer_name(self, rank):
        # The name has to be consistent with that in 'dist_init' decorator.
        return f"worker{rank}"

    def _remote_worker_process(self, ddp_mode):
        gLogger.info("The remote worker is running.")
        dist.init_process_group(
            backend="gloo",
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

        if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
            # new_group needs to be called on all ranks.
            dist.new_group(TRAINER_RANKS)

        global shutdown_signal
        with shutdown_signal:
            shutdown_signal.wait()
        gLogger.info("Exiting remote worker.")
        dist.destroy_process_group()

    def _trainer_process(self, rank: int):
        gLogger.info("Running the trainer #%s...", rank)
        gLogger.info(
            "Initializing trainer process group by trainer #%s with ranks %s",
            rank,
            TRAINER_RANKS,
        )
        dist.init_process_group(
            backend="gloo",
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

        gLogger.info("Waiting for shutdown signal on trainer #%s...", rank)

        global shutdown_signal
        with shutdown_signal:
            shutdown_signal.wait()
        gLogger.info("Exiting the trainer #%s...", rank)
        dist.destroy_process_group()

    def _master_process(self, ddp_mode: DdpMode, simulate_uneven_inputs: bool):
        gLogger.info("Running the master process...")
        dist.init_process_group(
            backend="gloo",
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

        remote_em_rref = rpc.remote(
            self.remote_worker_name(), RemoteEM, args=(NUM_EM_ROW, D_SPARSE)
        )
        remote_net_rref = rpc.remote(
            self.remote_worker_name(), RemoteNet, args=(D_DENSE + D_SPARSE, D_HID)
        )
        gLogger.info("Created remote rrefs on master")
        self.do_test_on_master(
            ddp_mode, simulate_uneven_inputs, remote_em_rref, remote_net_rref
        )

    def do_test_on_master(
        self,
        ddp_mode: DdpMode,
        simulate_uneven_inputs: bool,
        remote_em_rref: rpc.RRef,
        remote_net_rref: rpc.RRef,
    ):
        if simulate_uneven_inputs:
            gLogger.info(
                "Running DDP + RPC test with simulated uneven inputs across trainers."
            )

        trainer_rrefs = []
        for rank in TRAINER_RANKS:
            trainer = self.trainer_name(rank)
            trainer_rrefs.append(
                rpc.remote(
                    trainer,
                    Trainer,
                    args=(remote_em_rref, remote_net_rref, ddp_mode, rank),
                )
            )

        if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
            # new_group needs to be called on all ranks.
            dist.new_group(TRAINER_RANKS)

        training_examples = get_training_examples()
        for _ in range(3):
            futures = []
            num_trainers = len(trainer_rrefs)
            for idx, trainer_rref in enumerate(trainer_rrefs):
                # Half the trainers will deplete inputs earlier than the rest.
                trainer_has_less_inputs = (
                    simulate_uneven_inputs and idx < num_trainers // 2
                )
                futures.append(
                    _remote_method_async(
                        Trainer.train_batch,
                        trainer_rref,
                        training_examples[idx],
                        trainer_has_less_inputs,
                        simulate_uneven_inputs,
                    )
                )

            for future in futures:
                ddp_grads, non_ddp_grads = future.wait()
                # When there are uneven inputs, it is not necessary that grads
                # cancel each other out, since some trainers contribute 0 grad.
                if not simulate_uneven_inputs:
                    for grad in ddp_grads:
                        self.assertEqual(
                            grad,
                            torch.zeros_like(grad),
                            msg="The grad for any ddp parameter should be zero, "
                            "because the training examples' grads cancel each other. "
                            f"Received gradient {grad}",
                        )
                    for grad in non_ddp_grads:
                        self.assertNotEqual(
                            grad,
                            torch.zeros_like(grad),
                            msg="The grad for any non-ddp parameter shouldn't be zero",
                        )

        # Destroy process groups
        for idx, trainer_rref in enumerate(trainer_rrefs):
            _remote_method_async(Trainer.destroy_pg, trainer_rref).wait()

        # Send shutdown signals.
        for rank in TRAINER_RANKS:
            trainer = self.trainer_name(rank)
            rpc.rpc_sync(trainer, set_shutdown_signal, args=())

        rpc.rpc_sync(self.remote_worker_name(), set_shutdown_signal, args=())

    def _do_test(self, ddp_mode, simulate_uneven_inputs=False):
        if self.rank == MASTER_RANK:
            self._master_process(ddp_mode, simulate_uneven_inputs)
        elif self.rank == REMOTE_WORKER_RANK:
            self._remote_worker_process(ddp_mode)
        elif self.rank in TRAINER_RANKS:
            self._trainer_process(self.rank)
        else:
            raise RuntimeError(f"Unknown process rank: {self.rank}")

    @requires_gloo()
    @dist_init
    def test_backward_no_ddp(self):
        self._do_test(DdpMode.NONE)

    @requires_gloo()
    @dist_init
    def test_backward_ddp_outside(self):
        self._do_test(DdpMode.OUTSIDE)

    @requires_gloo()
    @dist_init
    def test_backward_ddp_outside_uneven_inputs(self):
        self._do_test(DdpMode.OUTSIDE, simulate_uneven_inputs=True)

    @requires_gloo()
    @dist_init
    def test_backward_ddp_inside(self):
        self._do_test(DdpMode.INSIDE)
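

# Minimal sketch of the uneven-input pattern used by the tests above and
# below. Assumptions for illustration only: a process group is already
# initialized, `ddp_model` wraps some nn.Module, and `inputs_list` may have a
# different length on each rank. Wrapping the loop in
# DistributedDataParallel.join() lets ranks that run out of inputs early
# shadow the collective calls of the ranks that are still training.
def _example_ddp_join_loop(ddp_model: DistributedDataParallel, inputs_list):
    with ddp_model.join():
        for inputs in inputs_list:
            # Each rank may run a different number of iterations here.
            loss = ddp_model(inputs).sum()
            loss.backward()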


# Common utils for both CPU and CUDA test suites
class CommonDdpComparisonTest(RpcAgentTestFixture):
    @property
    def world_size(self) -> int:
        return NUM_TRAINERS

    def trainer_name(self, rank):
        # The name has to be consistent with that in 'dist_init' decorator.
        return f"worker{rank}"

    @staticmethod
    def get_remote_grads(rref, context_id):
        return dist_autograd.get_gradients(context_id)[rref.local_value().weight]


class DdpComparisonTest(CommonDdpComparisonTest):
    def _run_test_ddp_comparison(self, simulate_uneven_inputs=False):
        gLogger.info("Running trainer rank: %s", self.rank)
        # Each trainer uses a different random seed. Otherwise, they are going
        # to have exactly the same initial model parameters, input, and
        # therefore grads. That means the grads will be the same before and
        # after DDP's all-reduce.
        torch.manual_seed(self.rank)
        dist.init_process_group(
            backend="gloo",
            # Postfix file_name with "pg" since file_name is also used by RPC agent
            init_method=INIT_METHOD_TEMPLATE.format(file_name=f"{self.file_name}_pg"),
            world_size=self.world_size,
            rank=self.rank,
        )
        net = nn.Linear(2, 3)
        ddp_net = DistributedDataParallel(net)

        # Odd ranks join early if simulate_uneven_inputs.
        num_inputs = 1
        if simulate_uneven_inputs:
            if self.rank % 2 == 0:
                num_inputs += 2
        inputs_list = [torch.rand((3, 2)) for _ in range(num_inputs)]

        if simulate_uneven_inputs:
            gLogger.info(
                "Rank %s training with %s inputs.", self.rank, len(inputs_list)
            )

        # Use distributed autograd. The gradients will be in the RPC context map.
        grads_dict = {}
        with ddp_net.join(simulate_uneven_inputs):
            for i, inputs in enumerate(inputs_list):
                with dist_autograd.context() as context_id:
                    loss = ddp_net(inputs).norm()
                    dist_autograd.backward(context_id, [loss])
                    grads_dict = dist_autograd.get_gradients(context_id)
                gLogger.info("Trainer #%s got grad dict: %s", self.rank, grads_dict)

                # Use local autograd. The gradients will be in each variable's '.grad'.
                ddp_net.zero_grad()
                loss = ddp_net(inputs).norm()
                loss.backward()

                # The gradients should be the same
                for param in net.parameters():
                    self.assertTrue(
                        param in grads_dict,
                        msg=f"Param {param} is not in dist_autograd grad dict {grads_dict} for iteration {i}",
                    )
                    self.assertEqual(
                        grads_dict[param],
                        param.grad,
                        msg=f"The grads for param {param} are different under local "
                        f"and dist autograd: {param.grad} \n---\n {grads_dict[param]} for iteration {i}",
                    )
        dist.destroy_process_group()

    @requires_gloo()
    @dist_init
    def test_ddp_comparison(self):
        self._run_test_ddp_comparison()

    @requires_gloo()
    @dist_init
    def test_ddp_comparison_uneven_inputs(self):
        # test with simulating uneven inputs in DDP
        self._run_test_ddp_comparison(simulate_uneven_inputs=True)

    @requires_gloo()
    @dist_init
    def test_ddp_dist_autograd_sparse_grads(self):
        # Each trainer uses a different random seed. Otherwise, they are going
        # to have exactly the same initial model parameters, input, and
        # therefore grads. That means the grads will be the same before and
        # after DDP's all-reduce.
        torch.manual_seed(self.rank)
        dist.init_process_group(
            backend="gloo",
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

        model = nn.EmbeddingBag(10, 3, sparse=True)
        ddp_model = DistributedDataParallel(model)

        # Different inputs for each rank.
        input = torch.LongTensor(10).random_(0, 10)
        offsets = torch.LongTensor([0, 4])

        # Run local.
        loss = ddp_model(input, offsets).sum()
        loss.backward()

        with dist_autograd.context() as context_id:
            loss = ddp_model(input, offsets).sum()
            dist_autograd.backward(context_id, [loss])
            grads_dict = dist_autograd.get_gradients(context_id)
            self.assertEqual(1, len(grads_dict))
            self.assertEqual(model.weight.grad, grads_dict[model.weight])

    @requires_gloo()
    @dist_init
    def test_ddp_dist_autograd_local_vs_remote(self):
        # Each trainer uses a different random seed. Otherwise, they are going
        # to have exactly the same initial model parameters, input, and
        # therefore grads. That means the grads will be the same before and
        # after DDP's all-reduce.
        torch.manual_seed(self.rank)
        dist.init_process_group(
            backend="gloo",
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

        # Use two different remote device input strings, with and without the
        # default device string "cpu", respectively.
        for remote_device in ["worker0/cpu", "worker0"]:
            remote_layer1 = RemoteModule(
                remote_device=remote_device, module_cls=nn.Linear, args=(10, 5, False)
            )
            layer1 = nn.Linear(10, 5, False)
            # Start with the same parameters for remote and local
            layer1.weight = remote_layer1.module_rref.to_here().weight

            # Run local case.
            layer2 = nn.Linear(5, 1)
            inputs = torch.rand((10, 10))
            ddp_model = DistributedDataParallel(layer2)
            loss = ddp_model(layer1(inputs)).sum()
            loss.backward()

            # Run remote case.
            with dist_autograd.context() as context_id:
                loss = ddp_model(remote_layer1(inputs)).sum()
                dist_autograd.backward(context_id, [loss])
                grads_dict = dist_autograd.get_gradients(context_id)
                dist.barrier()
                self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
                self.assertEqual(
                    layer1.weight.grad,
                    rpc.rpc_sync(
                        "worker0",
                        CommonDdpComparisonTest.get_remote_grads,
                        args=(remote_layer1.module_rref, context_id),
                    ),
                )


class CudaDdpComparisonTest(CommonDdpComparisonTest):
    @skip_if_lt_x_gpu(NUM_TRAINERS)
    @requires_nccl()
    @dist_init
    @skip_if_rocm_multiprocess
    def test_ddp_dist_autograd_local_vs_remote_gpu(self):
        # Each trainer uses a different random seed. Otherwise, they are going
        # to have exactly the same initial model parameters, input, and
        # therefore grads. That means the grads will be the same before and
        # after DDP's all-reduce.
        torch.manual_seed(self.rank)
        dist.init_process_group(
            backend="gloo",
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

        remote_layer1 = RemoteModule(
            remote_device="worker0/cpu", module_cls=nn.Linear, args=(10, 7, False)
        )
        layer1 = nn.Linear(10, 7, False)
        # Start with the same parameters for remote and local
        layer1.weight = remote_layer1.module_rref.to_here().weight

        layer2 = nn.Linear(7, 5).cuda(self.rank)
        ddp_layer2 = DistributedDataParallel(layer2, device_ids=[self.rank])

        remote_layer3 = RemoteModule(
            remote_device="worker0/cpu", module_cls=nn.Linear, args=(5, 3, False)
        )
        layer3 = nn.Linear(5, 3, False)
        # Start with the same parameters for remote and local
        layer3.weight = remote_layer3.module_rref.to_here().weight

        layer4 = nn.Linear(3, 1).cuda(self.rank)
        ddp_layer4 = DistributedDataParallel(layer4, device_ids=[self.rank])

        # Run local case.
        inputs = torch.rand((10, 10))
        loss = ddp_layer4(
            layer3(ddp_layer2(layer1(inputs).cuda(self.rank)).cpu()).cuda(self.rank)
        ).sum()
        loss.backward()

        # Run remote case.
        with dist_autograd.context() as context_id:
            loss = ddp_layer4(
                remote_layer3(
                    ddp_layer2(remote_layer1(inputs).cuda(self.rank)).cpu()
                ).cuda(self.rank)
            ).sum()
            dist_autograd.backward(context_id, [loss])
            grads_dict = dist_autograd.get_gradients(context_id)
            dist.barrier()
            self.assertEqual(
                layer1.weight.grad,
                rpc.rpc_sync(
                    "worker0",
                    CommonDdpComparisonTest.get_remote_grads,
                    args=(remote_layer1.module_rref, context_id),
                ),
            )
            self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
            self.assertEqual(
                layer3.weight.grad,
                rpc.rpc_sync(
                    "worker0",
                    CommonDdpComparisonTest.get_remote_grads,
                    args=(remote_layer3.module_rref, context_id),
                ),
            )
            self.assertEqual(layer4.weight.grad, grads_dict[layer4.weight])
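

# Distilled sketch of the local-vs-remote gradient comparison exercised by the
# tests above. Assumptions for illustration only: RPC is initialized, "worker0"
# hosts the remote module, and the layer shapes are placeholders. The remote
# parameter's gradient lives in the owner's dist_autograd context map, so it is
# fetched over RPC via the `get_remote_grads` helper on CommonDdpComparisonTest.
def _example_remote_module_grads():
    remote_fc = RemoteModule(
        remote_device="worker0/cpu", module_cls=nn.Linear, args=(4, 2, False)
    )
    inputs = torch.rand((8, 4))
    with dist_autograd.context() as context_id:
        loss = remote_fc(inputs).sum()
        dist_autograd.backward(context_id, [loss])
        # Retrieve the remote weight's gradient from the owner's context map.
        return rpc.rpc_sync(
            "worker0",
            CommonDdpComparisonTest.get_remote_grads,
            args=(remote_fc.module_rref, context_id),
        )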