xref: /aosp_15_r20/external/pytorch/test/distributed/test_data_parallel.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: distributed"]
2
3import contextlib
4import functools
5import io
6from collections import OrderedDict
7from copy import deepcopy
8from itertools import product
9
10import torch
11import torch.nn.functional as F
12import torch.nn.parallel as dp
13from torch import nn
14from torch.cuda.amp import autocast
15from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
16from torch.testing._internal.common_device_type import (
17    dtypes,
18    instantiate_device_type_tests,
19    onlyCUDA,
20    skipMeta,
21)
22from torch.testing._internal.common_utils import (
23    _assertGradAndGradgradChecks,
24    dtype2prec_DONTUSE,
25    gradcheck,
26    run_tests,
27    skip_but_pass_in_sandcastle_if,
28    TestCase,
29)
30
31
32NO_NCCL = not hasattr(torch.distributed, "ProcessGroupNCCL")
33
34# batched grad doesn't support data parallel
35gradcheck = functools.partial(gradcheck, check_batched_grad=False)
36_assertGradAndGradgradChecks = functools.partial(
37    _assertGradAndGradgradChecks, check_batched_grad=False
38)
39
40
41class TestDataParallel(TestCase):
42    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
43    def test_data_parallel_buffers_requiring_grad(self):
44        class TestModule(nn.Module):
45            def __init__(self, t):
46                super().__init__()
47                self.t_rg = nn.Buffer(t)
48                self.t_not_rg = nn.Buffer(t.clone().detach())
49
50            def forward(self, x):
51                return x * self.t_rg + self.t_not_rg
52
53        m = TestModule(
54            torch.randn(100, device="cuda", requires_grad=True, dtype=torch.double)
55        )
56        self.assertTrue(m.t_rg.requires_grad)
57
58        dpm = nn.DataParallel(m, [0, 1])
59        inp = torch.randn(2, 100, device="cuda", dtype=torch.double)
60
61        def fn(t):
62            return dpm(inp)
63
64        gradcheck(fn, (m.t_rg,))
65
66    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
67    def test_data_parallel_rnn(self):
68        class TestModule(torch.nn.Module):
69            def __init__(self) -> None:
70                super().__init__()
71                self.rnn = torch.nn.LSTM(
72                    300, 1024, 1, batch_first=True, bidirectional=True
73                )
74
75            def forward(self, x):
76                self.rnn.flatten_parameters()
77                return self.rnn(x)
78
79        def step(model):
80            opt = torch.optim.SGD(model.parameters(), lr=10)
81            input = torch.ones(4, 4, 300).to(0)
82            output = model(input)
83            loss = F.mse_loss(output[0], torch.zeros_like(output[0]))
84            loss.backward()
85            opt.step()
86
87        with torch.no_grad():
88            model = TestModule().to(0)
89            model_dp = torch.nn.DataParallel(deepcopy(model))
90
91            # make sure DP does not crash when grad is disabled.
92            # See #21108
93            model_dp(torch.rand(2, 4, 300).to(0))
94
95        step(model)
96        step(model_dp)
97
98        for p1, p2 in zip(model.parameters(), model_dp.parameters()):
99            self.assertTrue(p1.allclose(p2))
100
101    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
102    def test_data_parallel_lazy_linear(self):
103        with self.assertRaisesRegex(
104            ValueError, "Attempted to use an uninitialized parameter"
105        ):
106            model_dp = torch.nn.DataParallel(torch.nn.LazyLinear(10).to(0))
107            model_dp(torch.rand(10, 10).to(0))
108
109    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
110    def test_parallel_apply(self):
111        l1 = nn.Linear(10, 5).to("cuda:0", torch.float)
112        l2 = nn.Linear(10, 5).to("cuda:1", torch.float)
113        i1 = torch.randn(2, 10, device="cuda:0", dtype=torch.float)
114        i2 = torch.randn(2, 10, device="cuda:1", dtype=torch.float)
115        expected1 = l1(i1)
116        expected2 = l2(i2)
117        modules = (l1, l2)
118        expected_outputs = (expected1, expected2)
119
120        # each input can be either a collection of positional arguments
121        #                       or an object representing the single argument
122        for inputs in [((i1,), (i2,)), (i1, i2)]:
123            outputs = dp.parallel_apply(modules, inputs, None)
124            for out, expected in zip(outputs, expected_outputs):
125                self.assertEqual(out, expected)
126
127    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
128    def test_parallel_apply_autocast(self):
129        l1 = nn.Linear(10, 5).to("cuda:0", torch.float)
130        l2 = nn.Linear(10, 5).to("cuda:1", torch.float)
131        i1 = torch.randn(2, 10, device="cuda:0", dtype=torch.float)
132        i2 = torch.randn(2, 10, device="cuda:1", dtype=torch.float)
133        with autocast():
134            expected1 = l1(i1)
135            expected2 = l2(i2)
136        modules = (l1, l2)
137        expected_outputs = (expected1, expected2)
138
139        # each input can be either a collection of positional arguments
140        #                       or an object representing the single argument
141        for inputs in [((i1,), (i2,)), (i1, i2)]:
142            with autocast():
143                outputs = dp.parallel_apply(modules, inputs, None)
144            for out, expected in zip(outputs, expected_outputs):
145                self.assertEqual(out, expected)
146
147    @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "CUDA unavailable")
148    def test_parallel_apply_passes_exception(self):
149        # we define and instantiate a module that will throw a KeyError
150        class TestModule(nn.Module):
151            def forward(self, *args):
152                return {}["wonderful"]
153
154        l1 = TestModule().to("cuda", torch.float)
155        # and check that parallel_apply passes on the exception
156        # (we can use a single device twice for this test)
157        with self.assertRaisesRegex(
158            KeyError,
159            "Caught KeyError in replica \\d "
160            "on device 0.\nOriginal Traceback"
161            "[\\s\\S]+wonderful",
162        ):
163            dp.parallel_apply(modules=(l1, l1), inputs=(None, None))
164
165    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
166    def test_data_parallel_multiple_input(self):
167        class TestModule(nn.Module):
168            def forward(self, var1, var2, float1, var3=None):
169                if var3 is None:
170                    return float1 * (var1 * var2)
171                else:
172                    return float1 * (var1 * var2 + var3)
173
174        m = TestModule()
175        var1 = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
176        var2 = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
177        var3 = torch.randn(5, 5, dtype=torch.float, requires_grad=False)
178
179        float1 = torch.randn(1).item()
180
181        expected = m(var1, var2, float1)
182        loss = expected.sum()
183        loss.backward()
184        gvar1_exp = var1.grad.clone()
185        gvar2_exp = var2.grad.clone()
186
187        def local_test(out):
188            with torch.no_grad():
189                var1.grad.fill_(0.0)
190                var2.grad.fill_(0.0)
191            loss = out.sum()
192            loss.backward()
193            self.assertEqual(out, expected)
194            self.assertEqual(gvar1_exp, var1.grad)
195            self.assertEqual(gvar2_exp, var2.grad)
196
197        out = dp.data_parallel(m, (var1, var2, float1), (0, 1))
198        local_test(out)
199
200        out = dp.data_parallel(m, (var1, var2, float1), (1, 0))
201        local_test(out)
202
203        out = dp.data_parallel(m, (var1, var2, float1), (0,))
204        local_test(out)
205
206        with torch.no_grad():
207            var1.grad.fill_(0.0)
208            var2.grad.fill_(0.0)
209        expected = m(var1, var2, float1, var3=var3)
210        loss = expected.sum()
211        loss.backward()
212        gvar1_exp = var1.grad.clone()
213        gvar2_exp = var2.grad.clone()
214
215        dpm = nn.DataParallel(TestModule())
216        out = dpm(var1, var2, float1, var3=var3)
217        local_test(out)
218
219        dpm = nn.DataParallel(TestModule(), device_ids=[0])
220        out = dpm(var1, var2, float1, var3=var3)
221        local_test(out)
222
223        kwarg_wrap = {"var3": var3}
224        out = dp.data_parallel(
225            m, (var1, var2, float1), (0, 1), module_kwargs=kwarg_wrap
226        )
227        local_test(out)
228
229        out = dp.data_parallel(m, (var1, var2, float1), (0,), module_kwargs=kwarg_wrap)
230        local_test(out)
231
232    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
233    def test_data_parallel_small_back(self):
234        l = nn.Linear(10, 5).float().cuda()
235        i = torch.randn(20, 10, dtype=torch.float, device="cuda")
236        out = dp.data_parallel(l, i, (0, 1))
237        self.assertEqual(out, l(i))
238
239    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
240    def test_data_parallel_model_device(self):
241        r"""Test device[0] check at forward time."""
242        l = nn.Linear(2, 2)
243        inp = torch.randn(2, 2)
244        inp_cuda0 = inp.cuda(0)
245        inp_cuda1 = inp.cuda(1)
246
247        error_msg = "module must have its parameters and buffers on device {}"
248
249        @contextlib.contextmanager
250        def dummy_ctx_manager():
251            yield
252
253        def test(inner_m, dp_device, inp, device_ids, should_fail):
254            if device_ids is None:
255                device_ids = list(range(torch.cuda.device_count()))
256
257            if isinstance(device_ids[0], torch.device):
258                expect_device = device_ids[0]
259            else:
260                expect_device = torch.device(f"cuda:{device_ids[0]}")
261
262            if should_fail:
263
264                def assert_correct():
265                    return self.assertRaisesRegex(
266                        RuntimeError, error_msg.format(expect_device)
267                    )
268
269            else:
270                assert_correct = dummy_ctx_manager
271
272            # test DataParallel module
273            dpm = nn.DataParallel(inner_m, device_ids)
274            if dp_device is not None:
275                dpm = dpm.to(dp_device)
276
277            with assert_correct():
278                dpm(inp)
279
280            # test functional
281            with assert_correct():
282                nn.parallel.data_parallel(inner_m.to(dp_device), inp, device_ids)
283
284        test(l.to("cpu"), None, inp, None, should_fail=True)
285        test(l.cuda(1), None, inp_cuda0, None, should_fail=True)
286        test(l.cuda(), None, inp_cuda0, [1, 0], should_fail=True)
287
288        test(l.cuda(), None, inp_cuda0, None, should_fail=False)
289        test(l.cpu(), "cuda", inp_cuda0, None, should_fail=False)
290        test(l.cuda(1), None, inp_cuda1, [1, 0], should_fail=False)
291        test(l.cpu(), "cuda:1", inp_cuda1, [1, 0], should_fail=False)
292
293        s = nn.Sequential(l.cpu())
294        test(s, None, inp, None, should_fail=True)
295        test(s, None, inp, [0, 1], should_fail=True)
296        test(s, None, inp, [1, 0], should_fail=True)
297
298        s = nn.Sequential(deepcopy(l).cpu(), l.cuda())
299        test(s, None, inp, None, should_fail=True)
300        test(s, None, inp, [0, 1], should_fail=True)
301        test(s, None, inp, [1, 0], should_fail=True)
302
303        s = nn.Sequential(l.cuda(), deepcopy(l).cuda(1))
304        test(s, None, inp, None, should_fail=True)
305        test(s, None, inp, [0, 1], should_fail=True)
306        test(s, None, inp, [1, 0], should_fail=True)
307
308        s = nn.Sequential(l.cuda(), deepcopy(l).cuda())
309        test(s, None, inp, None, should_fail=False)
310        test(s, None, inp, [0, 1], should_fail=False)
311        test(s, None, inp, [1, 0], should_fail=True)
312        test(s.cpu(), None, inp, [1, 0], should_fail=True)
313        test(s.cuda(1), None, inp, [1, 0], should_fail=False)
314
315    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
316    def test_data_parallel_model_no_refcycles(self):
317        # Python 2.7 will create reference cycles with the following
318        # Module on multiple GPUs, but Python 3 shouldn't unless
319        # there are refcycles on the PyTorch side (or the defined module)
320        import gc
321
322        class Model(nn.Module):
323            def __init__(self) -> None:
324                super().__init__()
325                self.linear = nn.Linear(1, 1)
326
327            def forward(self, x):
328                return self.linear(x)
329
330        gc.collect()
331        model = nn.DataParallel(Model().cuda())
332        data = torch.randn(1, device="cuda")
333        model(data)
334
335        refcycles = gc.collect()
336        self.assertEqual(refcycles, 0)
337
338    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
339    def test_data_parallel_no_grad(self):
340        test = self
341
342        class Layer(nn.Module):
343            def forward(self, x):
344                test.assertFalse(torch.is_grad_enabled())
345                return x
346
347        l = Layer()
348        i = torch.randn(20, 10, dtype=torch.float, device="cuda")
349        with torch.no_grad():
350            dp.data_parallel(l, i, (0, 1))
351        self.assertRaises(AssertionError, lambda: dp.data_parallel(l, i, (0, 1)))
352
353    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
354    def test_data_parallel(self):
355        l = nn.Linear(10, 5).float().cuda()
356        i = torch.randn(20, 10, dtype=torch.float, device="cuda:1")
357        l.cuda(1)
358        expected_out = l(i)
359        loss = expected_out.sum()
360        loss.backward()
361        expected_grads = []
362        for param in l.parameters():
363            expected_grads.append(param.grad.clone())
364        dev_ids_list = [(0, 1), (1, 0)]
365        for dev_id in dev_ids_list:
366            with torch.cuda.device(dev_id[0]):
367                l.cuda()
368                l.zero_grad()
369                out = dp.data_parallel(l, i, dev_id)
370                loss = out.sum()
371                loss.backward()
372                self.assertEqual(out.get_device(), dev_id[0])
373                self.assertEqual(out, expected_out)
374                for expected, param in zip(expected_grads, l.parameters()):
375                    self.assertEqual(param.grad, expected)
376
377        # Check for None device_ids
378        l = l.cuda()
379        out = dp.data_parallel(l, i)
380
381    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
382    def test_data_parallel_sparse(self):
383        l = nn.Embedding(10, 5, sparse=True).to("cuda:1")
384        i = torch.randint(10, (20, 5), device="cuda:1", dtype=torch.long)
385        expected_out = l(i)
386        loss = expected_out.sum()
387        loss.backward()
388        expected_grads = []
389        for param in l.parameters():
390            expected_grads.append(param.grad.clone())
391        dev_ids_list = [(0, 1), (1, 0)]
392        for dev_id in dev_ids_list:
393            with torch.cuda.device(dev_id[0]):
394                l.cuda()
395                l.zero_grad()
396                out = dp.data_parallel(l, i, dev_id)
397                loss = out.sum()
398                loss.backward()
399                self.assertEqual(out.get_device(), dev_id[0])
400                self.assertEqual(out, expected_out)
401                for expected, param in zip(expected_grads, l.parameters()):
402                    self.assertEqual(param.grad.coalesce(), expected.coalesce())
403
404        # Check for None device_ids
405        l = l.cuda()
406        out = dp.data_parallel(l, i)
407
408    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
409    def test_data_parallel_nested_output(self):
410        def fn(input):
411            return [
412                input,
413                (input.sin(), input.cos(), [input.add(1)]),
414                input,
415                OrderedDict(a=input, b=[input.sin()]),
416            ]
417
418        class Net(nn.Module):
419            def forward(self, input):
420                return fn(input)
421
422        i = torch.randn(2, 2).float().cuda(1)
423        gpus = range(torch.cuda.device_count())
424        output = dp.data_parallel(Net(), i, gpus)
425        self.assertEqual(output, fn(i))
426        self.assertIsInstance(output[0], torch.Tensor)
427        self.assertIsInstance(output[1], tuple)
428        self.assertIsInstance(output[1][0], torch.Tensor)
429        self.assertIsInstance(output[1][1], torch.Tensor)
430        self.assertIsInstance(output[1][2], list)
431        self.assertIsInstance(output[1][2][0], torch.Tensor)
432        self.assertIsInstance(output[2], torch.Tensor)
433        self.assertIsInstance(output[3], dict)
434        self.assertEqual(len(output[3]), 2)
435        self.assertIn("a", output[3])
436        self.assertIn("b", output[3])
437        self.assertIsInstance(output[3]["a"], torch.Tensor)
438        self.assertIsInstance(output[3]["b"], list)
439        self.assertIsInstance(output[3]["b"][0], torch.Tensor)
440
441    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
442    def test_data_parallel_nested_input(self):
443        def fn(input):
444            return input[1][0]
445
446        class Net(nn.Module):
447            def forward(self, *input):
448                return fn(input)
449
450        i = torch.randn(20, 3, dtype=torch.float, device="cuda:1")
451        input = (i.cos(), (i.sin(), i), i.sin())
452        gpus = range(torch.cuda.device_count())
453        output = dp.data_parallel(Net(), input, gpus)
454        self.assertEqual(output, fn(input))
455
456    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
457    def test_data_parallel_module_zero_inputs(self):
458        class TestModule(nn.Module):
459            def forward(self):
460                t = torch.eye(2, 3, device="cuda:0")
461                return t + (1 - t)
462
463        def test_helper(output, expected):
464            self.assertEqual(output.get_device(), 0)
465            self.assertEqual(output, expected)
466
467        expected = torch.ones(2, 3, device="cuda:0")
468        model = TestModule()
469
470        test_helper(nn.DataParallel(model, [0])(), expected)
471        test_helper(nn.DataParallel(model, [0, 1])(), expected)
472        test_helper(dp.data_parallel(model, None, [0]), expected)
473        test_helper(dp.data_parallel(model, (), [0, 1]), expected)
474
475    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
476    def test_data_parallel_device_args(self):
477        cuda0 = torch.device("cuda:0")
478        cuda1 = torch.device("cuda:1")
479
480        # test output_device
481        l = nn.Linear(10, 5).to(cuda0, torch.float)
482        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
483        out = dp.data_parallel(l, i, device_ids=(0, 1), output_device=cuda0)
484        self.assertEqual(out, l(i))
485
486        # test device_ids
487        l = nn.Linear(10, 5).to(cuda0, torch.float)
488        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
489        out = dp.data_parallel(l, i, device_ids=(cuda0, cuda1), output_device=cuda0)
490        self.assertEqual(out, l(i))
491
492    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
493    def test_data_parallel_function_deletion(self):
494        # this test case is originated from #16532
495        def gradient_penalty(net, x):
496            output = net(x)
497            loss = torch.autograd.grad(
498                outputs=output,
499                inputs=x,
500                grad_outputs=x.new_ones(output.size()),
501                create_graph=True,
502                retain_graph=True,
503            )[0].mean()
504            return loss
505
506        net = nn.Linear(4, 1).cuda()
507        dpn = nn.DataParallel(net, [0, 1])
508        x = torch.ones(2, 4, requires_grad=True).cuda()
509
510        dpn.zero_grad()
511        loss = gradient_penalty(dpn, x)
512        loss.backward()
513        grads = [p.grad for p in net.parameters()]
514        self.assertEqual(2, len(grads))
515        self.assertEqual(
516            torch.tensor([[0.25, 0.25, 0.25, 0.25]], device="cuda:0"), grads[0]
517        )
518        self.assertEqual(torch.tensor([0.0], device="cuda:0"), grads[1])
519
520    def _test_scatter(self, tensor):
521        x = tensor.detach().requires_grad_()
522        result = dp.scatter(x, (0, 1))
523        self.assertEqual(len(result), 2)
524        self.assertEqual(result[0], x[:2])
525        self.assertEqual(result[0].get_device(), 0)
526        self.assertEqual(result[1], x[2:])
527        self.assertEqual(result[1].get_device(), 1)
528        grad = result[0].detach().clone().fill_(2)
529        result[0].backward(grad)
530        self.assertEqual(x.grad[:2], grad)
531        self.assertEqual(x.grad[2:], grad.clone().zero_())
532        _assertGradAndGradgradChecks(self, lambda y: dp.scatter(y, (0, 1)), (x,))
533
534    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
535    def test_scatter_cpu(self):
536        self._test_scatter(torch.randn((4, 4), dtype=torch.double))
537
538    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
539    def test_scatter_gpu(self):
540        self._test_scatter(torch.randn((4, 4), dtype=torch.double).cuda())
541
542    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
543    @skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
544    def test_data_parallel_complex(self):
545        # We expect complex parameters to be broadcast by view_as_real, e.g. move from C to R^2
546        class Cplx(torch.nn.Module):
547            def __init__(self) -> None:
548                super().__init__()
549                self.cplx = torch.nn.Parameter(
550                    torch.zeros(1, 10, dtype=torch.cfloat).cuda()
551                )
552
553            def forward(self, x):
554                return x + self.cplx
555
556        cplx = torch.nn.DataParallel(Cplx().cuda())
557        input = torch.rand(1, 10, dtype=torch.cfloat).cuda()
558        result = cplx(input)
559        # 2 is the extra real view dimension here
560        self.assertEqual(result.size(), torch.Size([1, 10, 2]))
561        self.assertEqual(result, torch.view_as_real(input))
562
563    def _test_gather(self, output_device):
564        inputs = (
565            torch.randn(2, 4, device="cuda:0", requires_grad=True, dtype=torch.double),
566            torch.randn(2, 4, device="cuda:1", requires_grad=True, dtype=torch.double),
567        )
568        result = dp.gather(inputs, output_device)
569        self.assertEqual(result.size(), torch.Size([4, 4]))
570        self.assertEqual(result[:2], inputs[0])
571        self.assertEqual(result[2:], inputs[1])
572        if output_device != -1:
573            self.assertEqual(result.get_device(), output_device)
574        else:
575            self.assertFalse(result.is_cuda)
576        grad = torch.randn((4, 4), dtype=torch.double)
577        if output_device != -1:
578            grad = grad.cuda(output_device)
579        result.backward(grad)
580        self.assertEqual(inputs[0].grad, grad[:2])
581        self.assertEqual(inputs[1].grad, grad[2:])
582        _assertGradAndGradgradChecks(
583            self, lambda x, y: dp.gather((x, y), output_device), inputs
584        )
585
586        # test scalar inputs, should stack into a vector in this case
587        inputs = (
588            torch.randn((), device="cuda:0", requires_grad=True, dtype=torch.double),
589            torch.randn((), device="cuda:1", requires_grad=True, dtype=torch.double),
590        )
591        result = dp.gather(inputs, output_device)
592        self.assertEqual(result.size(), torch.Size([2]))
593        self.assertEqual(result[0], inputs[0])
594        self.assertEqual(result[1], inputs[1])
595        if output_device != -1:
596            self.assertEqual(result.get_device(), output_device)
597        else:
598            self.assertFalse(result.is_cuda)
599        grad = torch.randn(2, dtype=torch.double)
600        if output_device != -1:
601            grad = grad.cuda(output_device)
602        result.backward(grad)
603        self.assertEqual(inputs[0].grad, grad[0])
604        self.assertEqual(inputs[1].grad, grad[1])
605        _assertGradAndGradgradChecks(
606            self, lambda x, y: dp.gather((x, y), output_device), inputs
607        )
608
609    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
610    def test_gather_cpu(self):
611        self._test_gather(-1)
612
613    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
614    def test_gather_gpu(self):
615        self._test_gather(0)
616
617    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
618    def test_gather_different_len_dicts(self):
619        inputs = (
620            {"a": torch.randn(1, 2, requires_grad=True, device="cuda:0")},
621            {
622                "b": torch.randn(1, 2, requires_grad=True, device="cuda:1"),
623                "a": torch.randn(1, 2, requires_grad=True, device="cuda:1"),
624            },
625        )
626        with self.assertRaises(ValueError):
627            _ = dp.gather(inputs, target_device=0)
628
629    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
630    def test_replicate(self):
631        module = nn.Linear(10, 5).float().cuda()
632        input = torch.randn(2, 10, dtype=torch.float, device="cuda")
633        expected_output = module(input)
634        for devices in [(0, 1), [0, 1]]:
635            replicas = dp.replicate(module, devices)
636            for i, replica in enumerate(replicas):
637                for p in replica.parameters():
638                    self.assertEqual(p.get_device(), i)
639                replica_input = input.cuda(i)
640                self.assertEqual(replica(replica_input), expected_output)
641
642    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
643    def test_replicate_buffers(self):
644        net = nn.Module()
645        net.bn = nn.BatchNorm2d(10)
646        net.cuda()
647        for devices in [(0, 1), [0, 1]]:
648            replicas = dp.replicate(net, devices)
649            for i, replica in enumerate(replicas):
650                self.assertEqual(
651                    replica.bn.running_mean.get_device(),
652                    i,
653                    msg="buffer on wrong device",
654                )
655                self.assertEqual(
656                    replica.bn.running_var.get_device(), i, msg="buffer on wrong device"
657                )
658                self.assertEqual(
659                    replica.bn.num_batches_tracked.get_device(),
660                    i,
661                    msg="buffer on wrong device",
662                )
663
664    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
665    def test_zero_grad(self):
666        # zero_grad should warn about using gradients inside forward
667
668        class Net(torch.nn.Module):
669            def __init__(self, testcase):
670                super().__init__()
671                self._testcase = testcase
672
673            def forward(self, x):
674                with self._testcase.assertWarnsRegex(
675                    UserWarning,
676                    r"Calling \.zero_grad\(\) from a module created with nn\.DataParallel\(\) has no effect.",
677                ):
678                    self.zero_grad()
679                return x
680
681        module = Net(self).cuda()
682        dpm = dp.DataParallel(module)
683        dpm(torch.rand(4, 3, 6, 5))
684
685    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
686    def test_autocast(self):
687        class Model(torch.nn.Linear):
688            def __init__(self) -> None:
689                super().__init__(8, 8)
690
691            @torch.cuda.amp.autocast()
692            def forward(self, input):
693                return super().forward(input)
694
695        model = dp.DataParallel(Model().cuda().to(dtype=torch.float32))
696        input = torch.randn((8, 8), dtype=torch.float32, device="cuda")
697        self.assertTrue(model(input).dtype is torch.float16)
698
699    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
700    def test_save_replica_module(self):
701        # DataParallel replicas can be saved (gh-37182)
702        module = torch.nn.Linear(8, 8).cuda()
703        dpm = torch.nn.parallel.replicate(module, devices=[0, 1], detach=False)
704        data = io.BytesIO()
705        torch.save(dpm, data)
706        dpm = torch.nn.parallel.replicate(module, devices=[0, 1], detach=True)
707        torch.save(dpm, data)
708
709    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
710    def test_strided_grad_layout(self):
711        class ConvNet(nn.Module):
712            def __init__(self, layouts, dtype_list):
713                super().__init__()
714                self.dtypes = dtype_list
715                self.conv0 = torch.nn.Conv2d(8, 16, (2, 2)).to(
716                    memory_format=layouts[0], dtype=dtype_list[0]
717                )
718                self.conv1 = torch.nn.Conv2d(16, 32, (2, 2)).to(
719                    memory_format=layouts[1], dtype=dtype_list[1]
720                )
721                self.conv2 = torch.nn.Conv2d(32, 16, (2, 2)).to(
722                    memory_format=layouts[2], dtype=dtype_list[2]
723                )
724                self.conv3 = torch.nn.Conv2d(16, 8, (2, 2)).to(
725                    memory_format=layouts[3], dtype=dtype_list[3]
726                )
727
728            def forward(self, x):
729                x = x.to(self.dtypes[0])
730                x = self.conv0(x).to(self.dtypes[1])
731                x = self.conv1(x).to(self.dtypes[2])
732                x = self.conv2(x).to(self.dtypes[3])
733                x = self.conv3(x)
734                return x
735
736        layer_formats = (
737            [torch.contiguous_format] * 4,
738            [torch.channels_last] * 2 + [torch.contiguous_format] * 2,
739            [torch.channels_last] * 4,
740        )
741        layer_dtypes = (
742            [torch.float] * 4,
743            [torch.float] * 2 + [torch.half] * 2,
744            [torch.half] * 4,
745        )
746
747        ndevs = torch.cuda.device_count()
748        input = torch.randn(ndevs * 8, 8, 8, 8, device="cuda:0", dtype=torch.float)
749        target = torch.randn(ndevs * 8, 8, 4, 4, device="cuda:0", dtype=torch.float)
750        device_ids = list(range(ndevs))
751
752        with torch.backends.cudnn.flags(
753            enabled=True, deterministic=True, benchmark=False
754        ):
755            for formats, dtype_list in product(layer_formats, layer_dtypes):
756                model_msg = f"formats = {formats} dtypes = {dtypes}"
757                try:
758                    m = ConvNet(formats, dtype_list).cuda(device="cuda:0")
759                    m_dp = dp.DataParallel(deepcopy(m), device_ids=device_ids)
760                    opt = torch.optim.SGD(m.parameters(), lr=0.1)
761                    opt_dp = torch.optim.SGD(m_dp.parameters(), lr=0.1)
762                    has_half = any(p.dtype is torch.half for p in m.parameters())
763                    tol = 1.0e-3 if has_half else 1.0e-5
764                except BaseException:
765                    # Prints case-specific debugging info to narrow down failing case.
766                    print(
767                        "Caught exception during model creation for " + model_msg,
768                        flush=True,
769                    )
770                    raise
771                # 2 iters:  First iter creates grads, second iter tries zeroed grads.
772                for it in range(2):
773                    iter_msg = f"iter = {it} " + model_msg
774                    named_msg = iter_msg
775                    try:
776                        F.mse_loss(m(input).float(), target).backward()
777                        F.mse_loss(m_dp(input).float(), target).backward()
778                        for i, ((layer_name, m_child), m_dp_child) in enumerate(
779                            zip(m.named_children(), m_dp.module.children())
780                        ):
781                            named_msg = layer_name + ".weight " + iter_msg
782                            self.assertTrue(
783                                m_child.weight.grad.is_contiguous(
784                                    memory_format=formats[i]
785                                ),
786                                named_msg,
787                            )
788                            self.assertTrue(
789                                m_dp_child.weight.grad.is_contiguous(
790                                    memory_format=formats[i]
791                                ),
792                                named_msg,
793                            )
794                            for j, ((param_name, p), p_dp) in enumerate(
795                                zip(m_child.named_parameters(), m_dp_child.parameters())
796                            ):
797                                named_msg = (
798                                    layer_name + "." + param_name + " " + iter_msg
799                                )
800                                self.assertEqual(p.grad, p_dp.grad, rtol=tol, atol=tol)
801                        opt.step()
802                        opt_dp.step()
803                        opt.zero_grad()
804                        opt_dp.zero_grad()
805                    except BaseException:
806                        # Makes sure we still get info if an error occurred somewhere other than the asserts.
807                        print(
808                            "Caught exception during iterations at " + named_msg,
809                            flush=True,
810                        )
811                        raise
812
813    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
814    def test_parameter_list_dict_replica(self):
815        class MyMod(torch.nn.Module):
816            def __init__(self, data, check_fn):
817                super().__init__()
818                self.data = data
819                self.check_fn = check_fn
820
821            def forward(self, inp):
822                self.check_fn(self)
823                return inp
824
825        p1 = torch.nn.Parameter(torch.rand(10))
826        p2 = torch.nn.Parameter(torch.rand(10))
827        key0 = 0
828        key1 = 1
829
830        def check_fn(self_):
831            self.assertEqual(p1, self_.data[key0])
832            self.assertEqual(p2, self_.data[key1])
833            self.assertTrue(self_.data[key0].requires_grad)
834            self.assertTrue(self_.data[key1].requires_grad)
835            self.assertIsNotNone(self_.data[key0].grad_fn)
836            self.assertIsNotNone(self_.data[key1].grad_fn)
837
838        module = MyMod(torch.nn.ParameterList([p1, p2]), check_fn).cuda()
839        model = dp.DataParallel(module)
840        input = torch.randn((8, 8), device="cuda")
841
842        # Runs the check_fn
843        model(input)
844
845        key0 = "0"
846        key1 = "1"
847        module = MyMod(torch.nn.ParameterDict({"0": p1, "1": p2}), check_fn).cuda()
848        model = dp.DataParallel(module)
849        input = torch.randn((8, 8), device="cuda")
850
851        # Runs the check_fn
852        model(input)
853
854
855class TestDataParallelDeviceType(TestCase):
856    @onlyCUDA
857    @skipMeta
858    @dtypes(torch.float, torch.double, torch.half)
859    def test_data_parallel_module(self, device, dtype):
860        l = nn.Linear(10, 5).to(device, dtype)
861        i = torch.randn(20, 10, device=device, dtype=dtype)
862        expected_out = l(i)
863        net = nn.DataParallel(l)
864        out = net(i)
865        self.assertEqual(out.get_device(), 0)
866        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
867
868    @onlyCUDA
869    @skipMeta
870    @dtypes(torch.float, torch.double, torch.half)
871    def test_data_parallel_module_kwargs_only(self, device, dtype):
872        class Net(nn.Module):
873            def __init__(self) -> None:
874                super().__init__()
875                self.l = l
876
877            def forward(self, input):
878                return self.l(input)
879
880        l = nn.Linear(10, 5).to(device, dtype)
881        i = torch.randn(20, 10, device=device, dtype=dtype)
882        expected_out = l(i)
883        n = nn.DataParallel(Net())
884        out = n(input=i)
885        self.assertEqual(out.get_device(), 0)
886        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
887
888    @onlyCUDA
889    @skipMeta
890    @dtypes(torch.float, torch.double, torch.half)
891    def test_data_parallel_module_kwargs_only_empty_list(self, device, dtype):
892        class Net(nn.Module):
893            def __init__(self) -> None:
894                super().__init__()
895                self.l = l
896
897            def forward(self, input):
898                return self.l(input["data"])
899
900        l = nn.Linear(10, 5).to(device, dtype)
901        i = torch.randn(20, 10, device=device, dtype=dtype)
902        expected_out = l(i)
903        n = nn.DataParallel(Net())
904        out = n(input={"data": i, "unused": []})
905        self.assertEqual(out.get_device(), 0)
906        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
907
908    @onlyCUDA
909    @skipMeta
910    @dtypes(torch.float, torch.double, torch.half)
911    def test_data_parallel_module_kwargs_only_empty_dict(self, device, dtype):
912        class Net(nn.Module):
913            def __init__(self) -> None:
914                super().__init__()
915                self.l = l
916
917            def forward(self, input):
918                return self.l(input["data"])
919
920        l = nn.Linear(10, 5).to(device, dtype)
921        i = torch.randn(20, 10, device=device, dtype=dtype)
922        expected_out = l(i)
923        n = nn.DataParallel(Net())
924        out = n(input={"data": i, "unused": {}})
925        self.assertEqual(out.get_device(), 0)
926        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
927
928    @onlyCUDA
929    @skipMeta
930    @dtypes(torch.float, torch.double, torch.half)
931    def test_data_parallel_module_kwargs_only_empty_tuple(self, device, dtype):
932        class Net(nn.Module):
933            def __init__(self) -> None:
934                super().__init__()
935                self.l = l
936
937            def forward(self, input):
938                return self.l(input["data"])
939
940        l = nn.Linear(10, 5).to(device, dtype)
941        i = torch.randn(20, 10, device=device, dtype=dtype)
942        expected_out = l(i)
943        n = nn.DataParallel(Net())
944        out = n(input={"data": i, "unused": ()})
945        self.assertEqual(out.get_device(), 0)
946        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)
947
948
949instantiate_device_type_tests(TestDataParallelDeviceType, globals())
950
951if __name__ == "__main__":
952    TestCase._default_dtype_check_enabled = True
953    run_tests()
954