# mypy: allow-untyped-defs
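"""
Tests exercising DistributedDataParallel (DDP) together with distributed
autograd and RPC: trainer processes hold a hybrid local/remote model whose
remote submodules live on a separate worker, and gradients computed under
distributed autograd are compared against plain local autograd.
"""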

import contextlib
import enum
import logging
import os
import threading
from typing import NamedTuple, Optional

import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.nn as nn
from torch.distributed import rpc
from torch.distributed.nn import RemoteModule
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
    requires_gloo,
    requires_nccl,
    skip_if_lt_x_gpu,
    skip_if_rocm_multiprocess,
)
from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)


NUM_EM_ROW = 2
D_SPARSE = 3
D_DENSE = 2
D_HID = 3
D_OUT = 1
NUM_TRAINERS = 4
# Trainers + the master + the remote worker
WORLD_SIZE = NUM_TRAINERS + 2
TRAINER_RANKS = list(range(NUM_TRAINERS))
REMOTE_WORKER_RANK = TRAINER_RANKS[-1] + 1
MASTER_RANK = REMOTE_WORKER_RANK + 1
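# With the defaults above, ranks 0..NUM_TRAINERS - 1 are trainers, the next
# rank hosts the remote modules, and the highest rank acts as the master that
# drives the test.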


class DdpMode(enum.Enum):
    # Don't apply DDP
    NONE = enum.auto()
    # Apply DDP to the top level nn.Module
    OUTSIDE = enum.auto()
    # Embed DDP inside the top level nn.Module
    INSIDE = enum.auto()


def init_logger():
    logger = logging.getLogger(__name__)
    level = logging.DEBUG if "debug" in os.environ else logging.INFO
    logger.setLevel(level)
    console = logging.StreamHandler()
    formatter = logging.Formatter(
        "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
    )
    console.setFormatter(formatter)
    console.setLevel(level)
    # add the handlers to the logger
    logger.addHandler(console)
    logger.propagate = False
    return logger


gLogger = init_logger()


class FeatureSet(NamedTuple):
    """A feature set has 2 types of features."""

    dense_features: torch.Tensor
    sparse_features: torch.LongTensor
    values: torch.Tensor


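# Helpers for invoking an instance method on the module behind an RRef: the
# call is forwarded via RPC to the RRef's owner, where the method runs on the
# RRef's local value.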
def _call_method(method, rref, *args, **kwargs):
    return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
    args_tup = tuple([method, rref] + list(args))
    return rpc.rpc_sync(rref.owner(), _call_method, args=args_tup, kwargs=kwargs)


def _remote_method_async(method, rref, *args, **kwargs):
    args_tup = tuple([method, rref] + list(args))
    return rpc.rpc_async(rref.owner(), _call_method, args=args_tup, kwargs=kwargs)


class RemoteEM(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        gLogger.info("Initializing RemoteEM with %s %s", num_embeddings, embedding_dim)
        super().__init__()
        init_em = [0.5] * embedding_dim
        self.em = nn.EmbeddingBag(
            num_embeddings,
            embedding_dim,
            _weight=torch.tensor([init_em] * num_embeddings),
        )

    def forward(self, input: torch.Tensor):
        gLogger.debug("Running RemoteEM.forward() on: %s", input)
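        # One offset per input row, so every index forms its own bag of size 1.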
        return self.em(input, offsets=torch.LongTensor(range(input.shape[0])))


# Return a linear module with predefined parameters.
def getLinear(d_in, d_out):
    l = nn.Linear(d_in, d_out, bias=False)
    w = torch.ones((d_out, d_in))
    w[0][0] = -1
    w.requires_grad_()
    l.weight.data = w
    return l


class RemoteNet(nn.Module):
    def __init__(self, d_in: int, d_out: int):
        gLogger.info("Initializing RemoteNet with %s %s", d_in, d_out)
        super().__init__()
        self.fc = getLinear(d_in, d_out)
        self.relu = nn.ReLU()

    def forward(self, input: torch.Tensor):
        gLogger.debug("Running RemoteNet.forward() on: %s", input)
        return self.relu(self.fc(input))


class HybridModel(nn.Module):
    def __init__(
        self,
        remote_em_rref: rpc.RRef,
        remote_net_rref: rpc.RRef,
        process_group_for_ddp: Optional[dist.ProcessGroup] = None,
    ):
        super().__init__()
        self.remote_em_rref = remote_em_rref
        self.remote_net_rref = remote_net_rref
        self.fc1 = getLinear(D_DENSE, D_DENSE)
        self.fc2 = getLinear(D_HID, D_OUT)

        self.non_ddp_params = tuple(self.fc1.parameters()) + tuple(
            self.fc2.parameters()
        )
        self.ddp_params = ()

        if process_group_for_ddp is not None:
            self.non_ddp_params, self.ddp_params = (
                tuple(self.fc1.parameters()),
                tuple(self.fc2.parameters()),
            )
            gLogger.info("Use DDP for the second local net.")
            self.fc2 = DistributedDataParallel(
                self.fc2, check_reduction=True, process_group=process_group_for_ddp
            )

        gLogger.info(
            "HybridModel has %s groups of parameters.", len(list(self.parameters()))
        )

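    # Forward pass: sparse features are embedded on the remote worker, dense
    # features go through the local fc1, the two are concatenated and fed to
    # the remote net, and the (possibly DDP-wrapped) fc2 produces the output.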
    def forward(self, input: FeatureSet):
        gLogger.debug("Running HybridModel.forward on %s", input)
        sparse = _remote_method(
            RemoteEM.forward, self.remote_em_rref, input.sparse_features
        )
        # The sparse and dense features must have the same mini-batch size.
        assert sparse.shape[0] == input.dense_features.shape[0]
        dense = self.fc1(input.dense_features)
        x = torch.cat((dense, sparse), 1)
        gLogger.debug("Concatenated feature: %s", x)
        x = _remote_method(RemoteNet.forward, self.remote_net_rref, x)
        return self.fc2(x)


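# A Trainer owns a HybridModel and applies DDP according to the requested
# DdpMode: INSIDE wraps only fc2 inside the HybridModel, OUTSIDE wraps the
# whole hybrid module, and NONE uses no DDP at all.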
class Trainer:
    def __init__(
        self,
        remote_em_rref: rpc.RRef,
        remote_net_rref: rpc.RRef,
        ddp_mode: DdpMode,
        rank: int,
    ):
        self.rank = rank
        self.trainer_group = (
            dist.new_group(TRAINER_RANKS)
            if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE)
            else None
        )
        self.remote_em_rref = remote_em_rref
        self.remote_net_rref = remote_net_rref
        self.hybrid_module = HybridModel(
            self.remote_em_rref,
            self.remote_net_rref,
            self.trainer_group if ddp_mode in (DdpMode.INSIDE,) else None,
        )
        self.ddp_params, self.non_ddp_params = (
            self.hybrid_module.ddp_params,
            self.hybrid_module.non_ddp_params,
        )
        if ddp_mode == DdpMode.OUTSIDE:
            gLogger.info("Wrapping the whole hybrid module into DDP.")
            self.ddp_params += self.non_ddp_params
            self.non_ddp_params = ()
            self.hybrid_module = DistributedDataParallel(
                self.hybrid_module,
                check_reduction=True,
                process_group=self.trainer_group,
            )
        gLogger.info(
            "Succeeded in creating a HybridModel instance with "
            "%s ddp params and %s other local params.",
            len(self.ddp_params), len(self.non_ddp_params)
        )

    def destroy_pg(self):
        if self.trainer_group:
            dist.destroy_process_group(self.trainer_group)

    def train_batch(
        self,
        mini_batch: FeatureSet,
        trainer_has_less_inputs: bool,
        simulate_uneven_inputs: bool,
    ):
        grads_dict = None

        if not simulate_uneven_inputs:
            input_batches = [mini_batch]
        else:
            # Split into microbatches, and trim to simulate uneven inputs.
            dense_features = mini_batch.dense_features
            sparse_features = mini_batch.sparse_features
            values = mini_batch.values

            dense_microbatch = torch.split(dense_features, 2)
            sparse_microbatch = torch.split(sparse_features, 2)
            values_microbatch = torch.split(values, 2)
            batches = []
            for d, s, v in zip(dense_microbatch, sparse_microbatch, values_microbatch):
                feature_set = FeatureSet(dense_features=d, sparse_features=s, values=v)
                batches.append(feature_set)

            if trainer_has_less_inputs:
                input_batches = batches[: len(batches) // 2]
                gLogger.info(
                    "Trainer reduced input batches from %s "
                    "to %s to simulate uneven inputs.",
                    len(batches), len(input_batches)
                )
            else:
                input_batches = batches

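        # When simulating uneven inputs, run under DDP's join() context so that
        # trainers with fewer batches keep participating in the collective
        # calls issued by the trainers that still have inputs.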
        with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext():
            for b in input_batches:
                with dist_autograd.context() as context_id:
                    output = self.hybrid_module.forward(b)
                    loss = (output * mini_batch.values).sum()
                    dist_autograd.backward(context_id, [loss])
                    grads_dict = dist_autograd.get_gradients(context_id)
                    gLogger.info(
                        "Loss is %s for mini batch: %s. "
                        "Grads dict has %s entries: %s", loss, mini_batch, len(grads_dict), grads_dict
                    )
        return (
            tuple(grads_dict[param] for param in self.ddp_params),
            tuple(grads_dict[param] for param in self.non_ddp_params),
        )


def get_training_examples():
    n = 16
    training_examples = FeatureSet(
        dense_features=torch.zeros((n, D_DENSE)),
        sparse_features=torch.zeros(n, dtype=torch.long),
        values=torch.zeros(n),
    )
    idx = 0
    # Every example has another one that has exactly the same features but an
    # opposite value. Therefore, their grads cancel each other in all-reduce.
    for value in (-1, 1):
        for x in (-1.0 * value, 1.0 * value):
            for y in (1.0 * value, -1.0 * value):
                for z in (0, 1):
                    training_examples.dense_features[idx, :] = torch.tensor((x, y))
                    training_examples.sparse_features[idx] = z
                    training_examples.values[idx] = value
                    idx += 1

    # Split the examples among NUM_TRAINERS trainers
    assert 0 == (n % NUM_TRAINERS)
    examples_per_trainer = int(n / NUM_TRAINERS)
    return [
        FeatureSet(
            dense_features=training_examples.dense_features[
                start : start + examples_per_trainer, :
            ],
            sparse_features=training_examples.sparse_features[
                start : start + examples_per_trainer
            ],
            values=training_examples.values[start : start + examples_per_trainer],
        )
        for start in range(0, n, examples_per_trainer)
    ]


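# Trainer and remote worker processes block on this condition until the master
# signals that the test is done.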
shutdown_signal = threading.Condition()


def set_shutdown_signal():
    global shutdown_signal
    with shutdown_signal:
        shutdown_signal.notify()


class DdpUnderDistAutogradTest(RpcAgentTestFixture):
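    """
    Spawns a master, a remote worker, and NUM_TRAINERS trainer processes
    (dispatched by rank in _do_test) and checks the gradients that DDP
    produces under distributed autograd.
    """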
    @property
    def world_size(self) -> int:
        return WORLD_SIZE

    def remote_worker_name(self) -> str:
        # The name has to be consistent with that in the 'dist_init' decorator.
        return f"worker{REMOTE_WORKER_RANK}"

    def trainer_name(self, rank):
        # The name has to be consistent with that in the 'dist_init' decorator.
        return f"worker{rank}"

    def _remote_worker_process(self, ddp_mode):
        gLogger.info("The remote worker is running.")
        dist.init_process_group(
            backend="gloo",
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

        if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
            # new_group needs to be called on all ranks.
            dist.new_group(TRAINER_RANKS)

        global shutdown_signal
        with shutdown_signal:
            shutdown_signal.wait()
        gLogger.info("Exiting remote worker.")
        dist.destroy_process_group()

    def _trainer_process(self, rank: int):
        gLogger.info("Running the trainer #%s...", rank)
        gLogger.info(
            "Initializing trainer process group by trainer #%s with ranks %s", rank, TRAINER_RANKS
        )
        dist.init_process_group(
            backend="gloo",
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

        gLogger.info("Waiting for shutdown signal on trainer #%s...", rank)

        global shutdown_signal
        with shutdown_signal:
            shutdown_signal.wait()
        gLogger.info("Exiting the trainer #%s...", rank)
        dist.destroy_process_group()

    def _master_process(self, ddp_mode: DdpMode, simulate_uneven_inputs: bool):
        gLogger.info("Running the master process...")
        dist.init_process_group(
            backend="gloo",
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

        remote_em_rref = rpc.remote(
            self.remote_worker_name(), RemoteEM, args=(NUM_EM_ROW, D_SPARSE)
        )
        remote_net_rref = rpc.remote(
            self.remote_worker_name(), RemoteNet, args=(D_DENSE + D_SPARSE, D_HID)
        )
        gLogger.info("Created remote rrefs on master")
        self.do_test_on_master(
            ddp_mode, simulate_uneven_inputs, remote_em_rref, remote_net_rref
        )

    def do_test_on_master(
        self,
        ddp_mode: DdpMode,
        simulate_uneven_inputs: bool,
        remote_em_rref: rpc.RRef,
        remote_net_rref: rpc.RRef,
    ):
        if simulate_uneven_inputs:
            gLogger.info(
                "Running DDP + RPC test, simulating uneven inputs across trainers."
            )

        trainer_rrefs = []
        for rank in TRAINER_RANKS:
            trainer = self.trainer_name(rank)
            trainer_rrefs.append(
                rpc.remote(
                    trainer,
                    Trainer,
                    args=(remote_em_rref, remote_net_rref, ddp_mode, rank),
                )
            )

        if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
            # new_group needs to be called on all ranks.
            dist.new_group(TRAINER_RANKS)

        training_examples = get_training_examples()
        for _ in range(3):
            futures = []
            num_trainers = len(trainer_rrefs)
            for idx, trainer_rref in enumerate(trainer_rrefs):
                # Half the trainers will deplete inputs earlier than the rest.
                trainer_has_less_inputs = (
                    simulate_uneven_inputs and idx < num_trainers // 2
                )
                futures.append(
                    _remote_method_async(
                        Trainer.train_batch,
                        trainer_rref,
                        training_examples[idx],
                        trainer_has_less_inputs,
                        simulate_uneven_inputs,
                    )
                )

            for future in futures:
                ddp_grads, non_ddp_grads = future.wait()
                # When there are uneven inputs, it is not necessary that grads
                # cancel each other out, since some trainers contribute 0 grad.
                if not simulate_uneven_inputs:
                    for grad in ddp_grads:
                        self.assertEqual(
                            grad,
                            torch.zeros_like(grad),
                            msg=f"The grad for any ddp parameter should be zeros, because "
                            "the training examples' grads cancel each other. Received "
                            f"gradient {grad}",
                        )
                for grad in non_ddp_grads:
                    self.assertNotEqual(
                        grad,
                        torch.zeros_like(grad),
                        msg="The grad for any non-ddp parameter shouldn't be zeros",
                    )

        # Destroy process groups
        for idx, trainer_rref in enumerate(trainer_rrefs):
            _remote_method_async(Trainer.destroy_pg, trainer_rref).wait()

        # Send shutdown signals.
        for rank in TRAINER_RANKS:
            trainer = self.trainer_name(rank)
            rpc.rpc_sync(trainer, set_shutdown_signal, args=())

        rpc.rpc_sync(self.remote_worker_name(), set_shutdown_signal, args=())

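    # Each process picks its role based on its rank: the master drives the
    # test, the remote worker hosts the remote modules, and the remaining
    # ranks run as trainers.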
    def _do_test(self, ddp_mode, simulate_uneven_inputs=False):
        if self.rank == MASTER_RANK:
            self._master_process(ddp_mode, simulate_uneven_inputs)
        elif self.rank == REMOTE_WORKER_RANK:
            self._remote_worker_process(ddp_mode)
        elif self.rank in TRAINER_RANKS:
            self._trainer_process(self.rank)
        else:
            raise RuntimeError(f"Unknown process rank: {self.rank}")

    @requires_gloo()
    @dist_init
    def test_backward_no_ddp(self):
        self._do_test(DdpMode.NONE)

    @requires_gloo()
    @dist_init
    def test_backward_ddp_outside(self):
        self._do_test(DdpMode.OUTSIDE)

    @requires_gloo()
    @dist_init
    def test_backward_ddp_outside_uneven_inputs(self):
        self._do_test(DdpMode.OUTSIDE, simulate_uneven_inputs=True)

    @requires_gloo()
    @dist_init
    def test_backward_ddp_inside(self):
        self._do_test(DdpMode.INSIDE)


# Common utils for both CPU and CUDA test suites
class CommonDdpComparisonTest(RpcAgentTestFixture):
    @property
    def world_size(self) -> int:
        return NUM_TRAINERS

    def trainer_name(self, rank):
        # The name has to be consistent with that in the 'dist_init' decorator.
        return f"worker{rank}"

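    # Runs on the remote module's owner: looks up the gradient of the module's
    # weight in the given distributed autograd context.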
    @staticmethod
    def get_remote_grads(rref, context_id):
        return dist_autograd.get_gradients(context_id)[rref.local_value().weight]


class DdpComparisonTest(CommonDdpComparisonTest):
    def _run_test_ddp_comparison(self, simulate_uneven_inputs=False):
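        """
        Run a small DDP-wrapped nn.Linear under both distributed autograd and
        local autograd, and check that the two produce identical gradients.
        """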
        gLogger.info("Running trainer rank: %s", self.rank)
        # Each trainer uses a different random seed. Otherwise, they are going
        # to have exactly the same initial model parameters, input, and
        # therefore grads. That means the grads will be the same before and
        # after DDP's all-reduce.
        torch.manual_seed(self.rank)
        dist.init_process_group(
            backend="gloo",
            # Postfix file_name with "pg" since file_name is also used by RPC agent
            init_method=INIT_METHOD_TEMPLATE.format(file_name=f"{self.file_name}_pg"),
            world_size=self.world_size,
            rank=self.rank,
        )
        net = nn.Linear(2, 3)
        ddp_net = DistributedDataParallel(net)

        # Odd ranks join early if simulate_uneven_inputs.
        num_inputs = 1
        if simulate_uneven_inputs:
            if self.rank % 2 == 0:
                num_inputs += 2
        inputs_list = [torch.rand((3, 2)) for _ in range(num_inputs)]

        if simulate_uneven_inputs:
            gLogger.info("Rank %s training with %s inputs.", self.rank, len(inputs_list))

        # Use distributed autograd. The gradients will be in RPC context map.
        grads_dict = {}
        with ddp_net.join(simulate_uneven_inputs):
            for i, inputs in enumerate(inputs_list):
                with dist_autograd.context() as context_id:
                    loss = ddp_net(inputs).norm()
                    dist_autograd.backward(context_id, [loss])
                    grads_dict = dist_autograd.get_gradients(context_id)
                gLogger.info("Trainer #%s got grad dict: %s", self.rank, grads_dict)

                # Use local autograd. The gradients will be in each variable's '.grad'.
                ddp_net.zero_grad()
                loss = ddp_net(inputs).norm()
                loss.backward()

                # The gradients should be the same
                for param in net.parameters():
                    self.assertTrue(
                        param in grads_dict,
                        msg=f"Param {param} is not in the dist autograd grad dict {grads_dict} for iteration {i}",
                    )
                    self.assertEqual(
                        grads_dict[param],
                        param.grad,
                        msg=f"The grads for param {param} are different under local "
                        f"and dist autograd: {param.grad} \n---\n {grads_dict[param]} for iteration {i}",
                    )
        dist.destroy_process_group()

    @requires_gloo()
    @dist_init
    def test_ddp_comparison(self):
        self._run_test_ddp_comparison()

    @requires_gloo()
    @dist_init
    def test_ddp_comparison_uneven_inputs(self):
        # Test DDP while simulating uneven inputs across ranks.
        self._run_test_ddp_comparison(simulate_uneven_inputs=True)

    @requires_gloo()
    @dist_init
    def test_ddp_dist_autograd_sparse_grads(self):
        # Each trainer uses a different random seed. Otherwise, they are going
        # to have exactly the same initial model parameters, input, and
        # therefore grads. That means the grads will be the same before and
        # after DDP's all-reduce.
        torch.manual_seed(self.rank)
        dist.init_process_group(
            backend="gloo",
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

        model = nn.EmbeddingBag(10, 3, sparse=True)
        ddp_model = DistributedDataParallel(model)

        # Different inputs for each rank.
        input = torch.LongTensor(10).random_(0, 10)
        offsets = torch.LongTensor([0, 4])

        # Run local.
        loss = ddp_model(input, offsets).sum()
        loss.backward()

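        # Repeat the same forward/backward under distributed autograd and
        # compare the sparse gradient with the one from local autograd above.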
        with dist_autograd.context() as context_id:
            loss = ddp_model(input, offsets).sum()
            dist_autograd.backward(context_id, [loss])
            grads_dict = dist_autograd.get_gradients(context_id)
            self.assertEqual(1, len(grads_dict))
            self.assertEqual(model.weight.grad, grads_dict[model.weight])

    @requires_gloo()
    @dist_init
    def test_ddp_dist_autograd_local_vs_remote(self):
        # Each trainer uses a different random seed. Otherwise, they are going
        # to have exactly the same initial model parameters, input, and
        # therefore grads. That means the grads will be the same before and
        # after DDP's all-reduce.
        torch.manual_seed(self.rank)
        dist.init_process_group(
            backend="gloo",
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

        # Use two different remote device strings, with and without the default
        # device "cpu", respectively.
        for remote_device in ["worker0/cpu", "worker0"]:
            remote_layer1 = RemoteModule(
                remote_device=remote_device, module_cls=nn.Linear, args=(10, 5, False)
            )
            layer1 = nn.Linear(10, 5, False)
            # Start with the same parameters for remote and local
            layer1.weight = remote_layer1.module_rref.to_here().weight

            # Run local case.
            layer2 = nn.Linear(5, 1)
            inputs = torch.rand((10, 10))
            ddp_model = DistributedDataParallel(layer2)
            loss = ddp_model(layer1(inputs)).sum()
            loss.backward()

            # Run remote case.
            with dist_autograd.context() as context_id:
                loss = ddp_model(remote_layer1(inputs)).sum()
                dist_autograd.backward(context_id, [loss])
                grads_dict = dist_autograd.get_gradients(context_id)
                dist.barrier()
                self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
                self.assertEqual(
                    layer1.weight.grad,
                    rpc.rpc_sync(
                        "worker0",
                        CommonDdpComparisonTest.get_remote_grads,
                        args=(remote_layer1.module_rref, context_id),
                    ),
                )


class CudaDdpComparisonTest(CommonDdpComparisonTest):
    @skip_if_lt_x_gpu(NUM_TRAINERS)
    @requires_nccl()
    @dist_init
    @skip_if_rocm_multiprocess
    def test_ddp_dist_autograd_local_vs_remote_gpu(self):
        # Each trainer uses a different random seed. Otherwise, they are going
        # to have exactly the same initial model parameters, input, and
        # therefore grads. That means the grads will be the same before and
        # after DDP's all-reduce.
        torch.manual_seed(self.rank)
        dist.init_process_group(
            backend="gloo",
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

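        # Build a pipeline that alternates between remote CPU modules hosted on
        # worker0 and local DDP-wrapped GPU layers, moving activations between
        # CPU and the local GPU at each boundary.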
        remote_layer1 = RemoteModule(
            remote_device="worker0/cpu", module_cls=nn.Linear, args=(10, 7, False)
        )
        layer1 = nn.Linear(10, 7, False)
        # Start with the same parameters for remote and local
        layer1.weight = remote_layer1.module_rref.to_here().weight

        layer2 = nn.Linear(7, 5).cuda(self.rank)
        ddp_layer2 = DistributedDataParallel(layer2, device_ids=[self.rank])

        remote_layer3 = RemoteModule(
            remote_device="worker0/cpu", module_cls=nn.Linear, args=(5, 3, False)
        )
        layer3 = nn.Linear(5, 3, False)
        # Start with the same parameters for remote and local
        layer3.weight = remote_layer3.module_rref.to_here().weight

        layer4 = nn.Linear(3, 1).cuda(self.rank)
        ddp_layer4 = DistributedDataParallel(layer4, device_ids=[self.rank])

        # Run local case.
        inputs = torch.rand((10, 10))
        loss = ddp_layer4(
            layer3(ddp_layer2(layer1(inputs).cuda(self.rank)).cpu()).cuda(self.rank)
        ).sum()
        loss.backward()

        # Run remote case.
        with dist_autograd.context() as context_id:
            loss = ddp_layer4(
                remote_layer3(
                    ddp_layer2(remote_layer1(inputs).cuda(self.rank)).cpu()
                ).cuda(self.rank)
            ).sum()
            dist_autograd.backward(context_id, [loss])
            grads_dict = dist_autograd.get_gradients(context_id)
            dist.barrier()
            self.assertEqual(
                layer1.weight.grad,
                rpc.rpc_sync(
                    "worker0",
                    CommonDdpComparisonTest.get_remote_grads,
                    args=(remote_layer1.module_rref, context_id),
                ),
            )
            self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
            self.assertEqual(
                layer3.weight.grad,
                rpc.rpc_sync(
                    "worker0",
                    CommonDdpComparisonTest.get_remote_grads,
                    args=(remote_layer3.module_rref, context_id),
                ),
            )
            self.assertEqual(layer4.weight.grad, grads_dict[layer4.weight])