# Owner(s): ["oncall: distributed"]

import contextlib
import functools
import io
from collections import OrderedDict
from copy import deepcopy
from itertools import product

import torch
import torch.nn.functional as F
import torch.nn.parallel as dp
from torch import nn
from torch.cuda.amp import autocast
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCUDA,
    skipMeta,
)
from torch.testing._internal.common_utils import (
    _assertGradAndGradgradChecks,
    dtype2prec_DONTUSE,
    gradcheck,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TestCase,
)


NO_NCCL = not hasattr(torch.distributed, "ProcessGroupNCCL")

# batched grad doesn't support data parallel
gradcheck = functools.partial(gradcheck, check_batched_grad=False)
_assertGradAndGradgradChecks = functools.partial(
    _assertGradAndGradgradChecks, check_batched_grad=False
)


class TestDataParallel(TestCase):
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_buffers_requiring_grad(self):
        class TestModule(nn.Module):
            def __init__(self, t):
                super().__init__()
                self.t_rg = nn.Buffer(t)
                self.t_not_rg = nn.Buffer(t.clone().detach())

            def forward(self, x):
                return x * self.t_rg + self.t_not_rg

        m = TestModule(
            torch.randn(100, device="cuda", requires_grad=True, dtype=torch.double)
        )
        self.assertTrue(m.t_rg.requires_grad)

        dpm = nn.DataParallel(m, [0, 1])
        inp = torch.randn(2, 100, device="cuda", dtype=torch.double)

        def fn(t):
            # t is m.t_rg itself; the DataParallel forward reads the buffer from
            # the module, so the argument is not used directly here
            return dpm(inp)

        gradcheck(fn, (m.t_rg,))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_rnn(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.rnn = torch.nn.LSTM(
                    300, 1024, 1, batch_first=True, bidirectional=True
                )

            def forward(self, x):
                self.rnn.flatten_parameters()
                return self.rnn(x)

        def step(model):
            opt = torch.optim.SGD(model.parameters(), lr=10)
            input = torch.ones(4, 4, 300).to(0)
            output = model(input)
            loss = F.mse_loss(output[0], torch.zeros_like(output[0]))
            loss.backward()
            opt.step()

        with torch.no_grad():
            model = TestModule().to(0)
            model_dp = torch.nn.DataParallel(deepcopy(model))

            # make sure DP does not crash when grad is disabled.
            # See #21108
            model_dp(torch.rand(2, 4, 300).to(0))

        step(model)
        step(model_dp)

        for p1, p2 in zip(model.parameters(), model_dp.parameters()):
            self.assertTrue(p1.allclose(p2))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_lazy_linear(self):
        with self.assertRaisesRegex(
            ValueError, "Attempted to use an uninitialized parameter"
        ):
            model_dp = torch.nn.DataParallel(torch.nn.LazyLinear(10).to(0))
            model_dp(torch.rand(10, 10).to(0))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_parallel_apply(self):
        l1 = nn.Linear(10, 5).to("cuda:0", torch.float)
        l2 = nn.Linear(10, 5).to("cuda:1", torch.float)
        i1 = torch.randn(2, 10, device="cuda:0", dtype=torch.float)
        i2 = torch.randn(2, 10, device="cuda:1", dtype=torch.float)
        expected1 = l1(i1)
        expected2 = l2(i2)
        modules = (l1, l2)
        expected_outputs = (expected1, expected2)

        # each input can be either a collection of positional arguments
        # or an object representing the single argument
        for inputs in [((i1,), (i2,)), (i1, i2)]:
            outputs = dp.parallel_apply(modules, inputs, None)
            for out, expected in zip(outputs, expected_outputs):
                self.assertEqual(out, expected)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_parallel_apply_autocast(self):
        l1 = nn.Linear(10, 5).to("cuda:0", torch.float)
        l2 = nn.Linear(10, 5).to("cuda:1", torch.float)
        i1 = torch.randn(2, 10, device="cuda:0", dtype=torch.float)
        i2 = torch.randn(2, 10, device="cuda:1", dtype=torch.float)
        with autocast():
            expected1 = l1(i1)
            expected2 = l2(i2)
        modules = (l1, l2)
        expected_outputs = (expected1, expected2)

        # each input can be either a collection of positional arguments
        # or an object representing the single argument
        for inputs in [((i1,), (i2,)), (i1, i2)]:
            with autocast():
                outputs = dp.parallel_apply(modules, inputs, None)
            for out, expected in zip(outputs, expected_outputs):
                self.assertEqual(out, expected)

    @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "CUDA unavailable")
    def test_parallel_apply_passes_exception(self):
        # we define and instantiate a module that will throw a KeyError
        class TestModule(nn.Module):
            def forward(self, *args):
                return {}["wonderful"]

        l1 = TestModule().to("cuda", torch.float)
        # and check that parallel_apply passes on the exception
        # (we can use a single device twice for this test)
        with self.assertRaisesRegex(
            KeyError,
            "Caught KeyError in replica \\d "
            "on device 0.\nOriginal Traceback"
            "[\\s\\S]+wonderful",
        ):
            dp.parallel_apply(modules=(l1, l1), inputs=(None, None))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_multiple_input(self):
        class TestModule(nn.Module):
            def forward(self, var1, var2, float1, var3=None):
                if var3 is None:
                    return float1 * (var1 * var2)
                else:
                    return float1 * (var1 * var2 + var3)

        m = TestModule()
        var1 = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
        var2 = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
        var3 = torch.randn(5, 5, dtype=torch.float, requires_grad=False)

        float1 = torch.randn(1).item()

        expected = m(var1, var2, float1)
        loss = expected.sum()
        loss.backward()
        gvar1_exp = var1.grad.clone()
        gvar2_exp = var2.grad.clone()

        def local_test(out):
            with torch.no_grad():
                var1.grad.fill_(0.0)
                var2.grad.fill_(0.0)
            loss = out.sum()
            loss.backward()
            self.assertEqual(out, expected)
            self.assertEqual(gvar1_exp, var1.grad)
            self.assertEqual(gvar2_exp, var2.grad)

        out = dp.data_parallel(m, (var1, var2, float1), (0, 1))
        local_test(out)

        out = dp.data_parallel(m, (var1, var2, float1), (1, 0))
        local_test(out)

        out = dp.data_parallel(m, (var1, var2, float1), (0,))
        local_test(out)

        with torch.no_grad():
            var1.grad.fill_(0.0)
            var2.grad.fill_(0.0)
        expected = m(var1, var2, float1, var3=var3)
        loss = expected.sum()
        loss.backward()
        gvar1_exp = var1.grad.clone()
        gvar2_exp = var2.grad.clone()

        dpm = nn.DataParallel(TestModule())
        out = dpm(var1, var2, float1, var3=var3)
        local_test(out)

        dpm = nn.DataParallel(TestModule(), device_ids=[0])
        out = dpm(var1, var2, float1, var3=var3)
        local_test(out)

        kwarg_wrap = {"var3": var3}
        out = dp.data_parallel(
            m, (var1, var2, float1), (0, 1), module_kwargs=kwarg_wrap
        )
        local_test(out)

        out = dp.data_parallel(m, (var1, var2, float1), (0,), module_kwargs=kwarg_wrap)
        local_test(out)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_small_back(self):
        l = nn.Linear(10, 5).float().cuda()
        i = torch.randn(20, 10, dtype=torch.float, device="cuda")
        out = dp.data_parallel(l, i, (0, 1))
        self.assertEqual(out, l(i))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_model_device(self):
        r"""Test device[0] check at forward time."""
        l = nn.Linear(2, 2)
        inp = torch.randn(2, 2)
        inp_cuda0 = inp.cuda(0)
        inp_cuda1 = inp.cuda(1)

        error_msg = "module must have its parameters and buffers on device {}"

        @contextlib.contextmanager
        def dummy_ctx_manager():
            yield

        def test(inner_m, dp_device, inp, device_ids, should_fail):
            if device_ids is None:
                device_ids = list(range(torch.cuda.device_count()))

            if isinstance(device_ids[0], torch.device):
                expect_device = device_ids[0]
            else:
                expect_device = torch.device(f"cuda:{device_ids[0]}")

            if should_fail:

                def assert_correct():
                    return self.assertRaisesRegex(
                        RuntimeError, error_msg.format(expect_device)
                    )

            else:
                assert_correct = dummy_ctx_manager

            # test DataParallel module
            dpm = nn.DataParallel(inner_m, device_ids)
            if dp_device is not None:
                dpm = dpm.to(dp_device)

            with assert_correct():
                dpm(inp)

            # test functional
            with assert_correct():
                nn.parallel.data_parallel(inner_m.to(dp_device), inp, device_ids)

        test(l.to("cpu"), None, inp, None, should_fail=True)
        test(l.cuda(1), None, inp_cuda0, None, should_fail=True)
        test(l.cuda(), None, inp_cuda0, [1, 0], should_fail=True)

        test(l.cuda(), None, inp_cuda0, None, should_fail=False)
        test(l.cpu(), "cuda", inp_cuda0, None, should_fail=False)
        test(l.cuda(1), None, inp_cuda1, [1, 0], should_fail=False)
        test(l.cpu(), "cuda:1", inp_cuda1, [1, 0], should_fail=False)

        s = nn.Sequential(l.cpu())
        test(s, None, inp, None, should_fail=True)
        test(s, None, inp, [0, 1], should_fail=True)
        test(s, None, inp, [1, 0], should_fail=True)

        s = nn.Sequential(deepcopy(l).cpu(), l.cuda())
        test(s, None, inp, None, should_fail=True)
        test(s, None, inp, [0, 1], should_fail=True)
        test(s, None, inp, [1, 0], should_fail=True)

        s = nn.Sequential(l.cuda(), deepcopy(l).cuda(1))
        test(s, None, inp, None, should_fail=True)
        test(s, None, inp, [0, 1], should_fail=True)
        test(s, None, inp, [1, 0], should_fail=True)

        s = nn.Sequential(l.cuda(), deepcopy(l).cuda())
        test(s, None, inp, None, should_fail=False)
        test(s, None, inp, [0, 1], should_fail=False)
        test(s, None, inp, [1, 0], should_fail=True)
        test(s.cpu(), None, inp, [1, 0], should_fail=True)
        test(s.cuda(1), None, inp, [1, 0], should_fail=False)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_model_no_refcycles(self):
        # Python 2.7 will create reference cycles with the following
        # Module on multiple GPUs, but Python 3 shouldn't unless
        # there are refcycles on the PyTorch side (or the defined module)
        import gc

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(1, 1)

            def forward(self, x):
                return self.linear(x)

        gc.collect()
        model = nn.DataParallel(Model().cuda())
        data = torch.randn(1, device="cuda")
        model(data)

        refcycles = gc.collect()
        self.assertEqual(refcycles, 0)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_no_grad(self):
        test = self

        class Layer(nn.Module):
            def forward(self, x):
                test.assertFalse(torch.is_grad_enabled())
                return x

        l = Layer()
        i = torch.randn(20, 10, dtype=torch.float, device="cuda")
        with torch.no_grad():
            dp.data_parallel(l, i, (0, 1))
        self.assertRaises(AssertionError, lambda: dp.data_parallel(l, i, (0, 1)))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel(self):
        l = nn.Linear(10, 5).float().cuda()
        i = torch.randn(20, 10, dtype=torch.float, device="cuda:1")
        l.cuda(1)
        expected_out = l(i)
        loss = expected_out.sum()
        loss.backward()
        expected_grads = []
        for param in l.parameters():
            expected_grads.append(param.grad.clone())
        dev_ids_list = [(0, 1), (1, 0)]
        for dev_id in dev_ids_list:
            with torch.cuda.device(dev_id[0]):
                l.cuda()
                l.zero_grad()
                out = dp.data_parallel(l, i, dev_id)
                loss = out.sum()
                loss.backward()
                self.assertEqual(out.get_device(), dev_id[0])
                self.assertEqual(out, expected_out)
                for expected, param in zip(expected_grads, l.parameters()):
                    self.assertEqual(param.grad, expected)

        # Check for None device_ids
        l = l.cuda()
        out = dp.data_parallel(l, i)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_sparse(self):
        l = nn.Embedding(10, 5, sparse=True).to("cuda:1")
        i = torch.randint(10, (20, 5), device="cuda:1", dtype=torch.long)
        expected_out = l(i)
        loss = expected_out.sum()
        loss.backward()
        expected_grads = []
        for param in l.parameters():
            expected_grads.append(param.grad.clone())
        dev_ids_list = [(0, 1), (1, 0)]
        for dev_id in dev_ids_list:
            with torch.cuda.device(dev_id[0]):
                l.cuda()
                l.zero_grad()
                out = dp.data_parallel(l, i, dev_id)
                loss = out.sum()
                loss.backward()
                self.assertEqual(out.get_device(), dev_id[0])
                self.assertEqual(out, expected_out)
                for expected, param in zip(expected_grads, l.parameters()):
                    self.assertEqual(param.grad.coalesce(), expected.coalesce())

        # Check for None device_ids
        l = l.cuda()
        out = dp.data_parallel(l, i)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_nested_output(self):
        def fn(input):
            return [
                input,
                (input.sin(), input.cos(), [input.add(1)]),
                input,
                OrderedDict(a=input, b=[input.sin()]),
            ]

        class Net(nn.Module):
            def forward(self, input):
                return fn(input)

        i = torch.randn(2, 2).float().cuda(1)
        gpus = range(torch.cuda.device_count())
        output = dp.data_parallel(Net(), i, gpus)
        self.assertEqual(output, fn(i))
        self.assertIsInstance(output[0], torch.Tensor)
        self.assertIsInstance(output[1], tuple)
        self.assertIsInstance(output[1][0], torch.Tensor)
        self.assertIsInstance(output[1][1], torch.Tensor)
        self.assertIsInstance(output[1][2], list)
        self.assertIsInstance(output[1][2][0], torch.Tensor)
        self.assertIsInstance(output[2], torch.Tensor)
        self.assertIsInstance(output[3], dict)
        self.assertEqual(len(output[3]), 2)
        self.assertIn("a", output[3])
        self.assertIn("b", output[3])
        self.assertIsInstance(output[3]["a"], torch.Tensor)
        self.assertIsInstance(output[3]["b"], list)
        self.assertIsInstance(output[3]["b"][0], torch.Tensor)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_nested_input(self):
        def fn(input):
            return input[1][0]

        class Net(nn.Module):
            def forward(self, *input):
                return fn(input)

        i = torch.randn(20, 3, dtype=torch.float, device="cuda:1")
        input = (i.cos(), (i.sin(), i), i.sin())
        gpus = range(torch.cuda.device_count())
        output = dp.data_parallel(Net(), input, gpus)
        self.assertEqual(output, fn(input))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_module_zero_inputs(self):
        class TestModule(nn.Module):
            def forward(self):
                t = torch.eye(2, 3, device="cuda:0")
                return t + (1 - t)

        def test_helper(output, expected):
            self.assertEqual(output.get_device(), 0)
            self.assertEqual(output, expected)

        expected = torch.ones(2, 3, device="cuda:0")
        model = TestModule()

        test_helper(nn.DataParallel(model, [0])(), expected)
        test_helper(nn.DataParallel(model, [0, 1])(), expected)
        test_helper(dp.data_parallel(model, None, [0]), expected)
        test_helper(dp.data_parallel(model, (), [0, 1]), expected)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_device_args(self):
        cuda0 = torch.device("cuda:0")
        cuda1 = torch.device("cuda:1")

        # test output_device
        l = nn.Linear(10, 5).to(cuda0, torch.float)
        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
        out = dp.data_parallel(l, i, device_ids=(0, 1), output_device=cuda0)
        self.assertEqual(out, l(i))

        # test device_ids
        l = nn.Linear(10, 5).to(cuda0, torch.float)
        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
        out = dp.data_parallel(l, i, device_ids=(cuda0, cuda1), output_device=cuda0)
        self.assertEqual(out, l(i))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_function_deletion(self):
        # this test case originated from #16532
        def gradient_penalty(net, x):
            output = net(x)
            loss = torch.autograd.grad(
                outputs=output,
                inputs=x,
                grad_outputs=x.new_ones(output.size()),
                create_graph=True,
                retain_graph=True,
            )[0].mean()
            return loss

        net = nn.Linear(4, 1).cuda()
        dpn = nn.DataParallel(net, [0, 1])
        x = torch.ones(2, 4, requires_grad=True).cuda()

        dpn.zero_grad()
        loss = gradient_penalty(dpn, x)
        loss.backward()
        grads = [p.grad for p in net.parameters()]
        self.assertEqual(2, len(grads))
        self.assertEqual(
            torch.tensor([[0.25, 0.25, 0.25, 0.25]], device="cuda:0"), grads[0]
        )
        self.assertEqual(torch.tensor([0.0], device="cuda:0"), grads[1])

    def _test_scatter(self, tensor):
        x = tensor.detach().requires_grad_()
        result = dp.scatter(x, (0, 1))
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0], x[:2])
        self.assertEqual(result[0].get_device(), 0)
        self.assertEqual(result[1], x[2:])
        self.assertEqual(result[1].get_device(), 1)
        grad = result[0].detach().clone().fill_(2)
        result[0].backward(grad)
        self.assertEqual(x.grad[:2], grad)
        self.assertEqual(x.grad[2:], grad.clone().zero_())
        _assertGradAndGradgradChecks(self, lambda y: dp.scatter(y, (0, 1)), (x,))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_scatter_cpu(self):
        self._test_scatter(torch.randn((4, 4), dtype=torch.double))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_scatter_gpu(self):
        self._test_scatter(torch.randn((4, 4), dtype=torch.double).cuda())

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
    @skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
    def test_data_parallel_complex(self):
        # We expect complex parameters to be broadcast by view_as_real, e.g. move from C to R^2
        class Cplx(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.cplx = torch.nn.Parameter(
                    torch.zeros(1, 10, dtype=torch.cfloat).cuda()
                )

            def forward(self, x):
                return x + self.cplx

        cplx = torch.nn.DataParallel(Cplx().cuda())
        input = torch.rand(1, 10, dtype=torch.cfloat).cuda()
        result = cplx(input)
        # 2 is the extra real view dimension here
        self.assertEqual(result.size(), torch.Size([1, 10, 2]))
        self.assertEqual(result, torch.view_as_real(input))

    def _test_gather(self, output_device):
        inputs = (
            torch.randn(2, 4, device="cuda:0", requires_grad=True, dtype=torch.double),
            torch.randn(2, 4, device="cuda:1", requires_grad=True, dtype=torch.double),
        )
        result = dp.gather(inputs, output_device)
        self.assertEqual(result.size(), torch.Size([4, 4]))
        self.assertEqual(result[:2], inputs[0])
        self.assertEqual(result[2:], inputs[1])
        if output_device != -1:
            self.assertEqual(result.get_device(), output_device)
        else:
            self.assertFalse(result.is_cuda)
        grad = torch.randn((4, 4), dtype=torch.double)
        if output_device != -1:
            grad = grad.cuda(output_device)
        result.backward(grad)
        self.assertEqual(inputs[0].grad, grad[:2])
        self.assertEqual(inputs[1].grad, grad[2:])
        _assertGradAndGradgradChecks(
            self, lambda x, y: dp.gather((x, y), output_device), inputs
        )

        # test scalar inputs, should stack into a vector in this case
        inputs = (
            torch.randn((), device="cuda:0", requires_grad=True, dtype=torch.double),
            torch.randn((), device="cuda:1", requires_grad=True, dtype=torch.double),
        )
        result = dp.gather(inputs, output_device)
        self.assertEqual(result.size(), torch.Size([2]))
        self.assertEqual(result[0], inputs[0])
        self.assertEqual(result[1], inputs[1])
        if output_device != -1:
            self.assertEqual(result.get_device(), output_device)
        else:
            self.assertFalse(result.is_cuda)
        grad = torch.randn(2, dtype=torch.double)
        if output_device != -1:
            grad = grad.cuda(output_device)
        result.backward(grad)
        self.assertEqual(inputs[0].grad, grad[0])
        self.assertEqual(inputs[1].grad, grad[1])
        _assertGradAndGradgradChecks(
            self, lambda x, y: dp.gather((x, y), output_device), inputs
        )

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_gather_cpu(self):
        self._test_gather(-1)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_gather_gpu(self):
        self._test_gather(0)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_gather_different_len_dicts(self):
        inputs = (
            {"a": torch.randn(1, 2, requires_grad=True, device="cuda:0")},
            {
                "b": torch.randn(1, 2, requires_grad=True, device="cuda:1"),
                "a": torch.randn(1, 2, requires_grad=True, device="cuda:1"),
            },
        )
        with self.assertRaises(ValueError):
            _ = dp.gather(inputs, target_device=0)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_replicate(self):
        module = nn.Linear(10, 5).float().cuda()
        input = torch.randn(2, 10, dtype=torch.float, device="cuda")
        expected_output = module(input)
        for devices in [(0, 1), [0, 1]]:
            replicas = dp.replicate(module, devices)
            for i, replica in enumerate(replicas):
                for p in replica.parameters():
                    self.assertEqual(p.get_device(), i)
                replica_input = input.cuda(i)
                self.assertEqual(replica(replica_input), expected_output)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_replicate_buffers(self):
        net = nn.Module()
        net.bn = nn.BatchNorm2d(10)
        net.cuda()
        for devices in [(0, 1), [0, 1]]:
            replicas = dp.replicate(net, devices)
            for i, replica in enumerate(replicas):
                self.assertEqual(
                    replica.bn.running_mean.get_device(),
                    i,
                    msg="buffer on wrong device",
                )
                self.assertEqual(
                    replica.bn.running_var.get_device(), i, msg="buffer on wrong device"
                )
                self.assertEqual(
                    replica.bn.num_batches_tracked.get_device(),
                    i,
                    msg="buffer on wrong device",
                )

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_zero_grad(self):
        # zero_grad should warn about using gradients inside forward

        class Net(torch.nn.Module):
            def __init__(self, testcase):
                super().__init__()
                self._testcase = testcase

            def forward(self, x):
                with self._testcase.assertWarnsRegex(
                    UserWarning,
                    r"Calling \.zero_grad\(\) from a module created with nn\.DataParallel\(\) has no effect.",
                ):
                    self.zero_grad()
                return x

        module = Net(self).cuda()
        dpm = dp.DataParallel(module)
        dpm(torch.rand(4, 3, 6, 5))

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_autocast(self):
        class Model(torch.nn.Linear):
            def __init__(self) -> None:
                super().__init__(8, 8)

            @torch.cuda.amp.autocast()
            def forward(self, input):
                return super().forward(input)

        model = dp.DataParallel(Model().cuda().to(dtype=torch.float32))
        input = torch.randn((8, 8), dtype=torch.float32, device="cuda")
        self.assertTrue(model(input).dtype is torch.float16)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_save_replica_module(self):
        # DataParallel replicas can be saved (gh-37182)
        module = torch.nn.Linear(8, 8).cuda()
        dpm = torch.nn.parallel.replicate(module, devices=[0, 1], detach=False)
        data = io.BytesIO()
        torch.save(dpm, data)
        dpm = torch.nn.parallel.replicate(module, devices=[0, 1], detach=True)
        torch.save(dpm, data)

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_strided_grad_layout(self):
        class ConvNet(nn.Module):
            def __init__(self, layouts, dtype_list):
                super().__init__()
                self.dtypes = dtype_list
                self.conv0 = torch.nn.Conv2d(8, 16, (2, 2)).to(
                    memory_format=layouts[0], dtype=dtype_list[0]
                )
                self.conv1 = torch.nn.Conv2d(16, 32, (2, 2)).to(
                    memory_format=layouts[1], dtype=dtype_list[1]
                )
                self.conv2 = torch.nn.Conv2d(32, 16, (2, 2)).to(
                    memory_format=layouts[2], dtype=dtype_list[2]
                )
                self.conv3 = torch.nn.Conv2d(16, 8, (2, 2)).to(
                    memory_format=layouts[3], dtype=dtype_list[3]
                )

            def forward(self, x):
                x = x.to(self.dtypes[0])
                x = self.conv0(x).to(self.dtypes[1])
                x = self.conv1(x).to(self.dtypes[2])
                x = self.conv2(x).to(self.dtypes[3])
                x = self.conv3(x)
                return x

        layer_formats = (
            [torch.contiguous_format] * 4,
            [torch.channels_last] * 2 + [torch.contiguous_format] * 2,
            [torch.channels_last] * 4,
        )
        layer_dtypes = (
            [torch.float] * 4,
            [torch.float] * 2 + [torch.half] * 2,
            [torch.half] * 4,
        )

        ndevs = torch.cuda.device_count()
        input = torch.randn(ndevs * 8, 8, 8, 8, device="cuda:0", dtype=torch.float)
        target = torch.randn(ndevs * 8, 8, 4, 4, device="cuda:0", dtype=torch.float)
        device_ids = list(range(ndevs))

        with torch.backends.cudnn.flags(
            enabled=True, deterministic=True, benchmark=False
        ):
            for formats, dtype_list in product(layer_formats, layer_dtypes):
                model_msg = f"formats = {formats} dtypes = {dtype_list}"
                try:
                    m = ConvNet(formats, dtype_list).cuda(device="cuda:0")
                    m_dp = dp.DataParallel(deepcopy(m), device_ids=device_ids)
                    opt = torch.optim.SGD(m.parameters(), lr=0.1)
                    opt_dp = torch.optim.SGD(m_dp.parameters(), lr=0.1)
                    has_half = any(p.dtype is torch.half for p in m.parameters())
                    tol = 1.0e-3 if has_half else 1.0e-5
                except BaseException:
                    # Prints case-specific debugging info to narrow down failing case.
                    print(
                        "Caught exception during model creation for " + model_msg,
                        flush=True,
                    )
                    raise
                # 2 iters: First iter creates grads, second iter tries zeroed grads.
                for it in range(2):
                    iter_msg = f"iter = {it} " + model_msg
                    named_msg = iter_msg
                    try:
                        F.mse_loss(m(input).float(), target).backward()
                        F.mse_loss(m_dp(input).float(), target).backward()
                        for i, ((layer_name, m_child), m_dp_child) in enumerate(
                            zip(m.named_children(), m_dp.module.children())
                        ):
                            named_msg = layer_name + ".weight " + iter_msg
                            self.assertTrue(
                                m_child.weight.grad.is_contiguous(
                                    memory_format=formats[i]
                                ),
                                named_msg,
                            )
                            self.assertTrue(
                                m_dp_child.weight.grad.is_contiguous(
                                    memory_format=formats[i]
                                ),
                                named_msg,
                            )
                            for j, ((param_name, p), p_dp) in enumerate(
                                zip(m_child.named_parameters(), m_dp_child.parameters())
                            ):
                                named_msg = (
                                    layer_name + "." + param_name + " " + iter_msg
                                )
                                self.assertEqual(p.grad, p_dp.grad, rtol=tol, atol=tol)
                        opt.step()
                        opt_dp.step()
                        opt.zero_grad()
                        opt_dp.zero_grad()
                    except BaseException:
                        # Makes sure we still get info if an error occurred somewhere other than the asserts.
                        print(
                            "Caught exception during iterations at " + named_msg,
                            flush=True,
                        )
                        raise

    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_parameter_list_dict_replica(self):
        class MyMod(torch.nn.Module):
            def __init__(self, data, check_fn):
                super().__init__()
                self.data = data
                self.check_fn = check_fn

            def forward(self, inp):
                self.check_fn(self)
                return inp

        p1 = torch.nn.Parameter(torch.rand(10))
        p2 = torch.nn.Parameter(torch.rand(10))
        key0 = 0
        key1 = 1

        def check_fn(self_):
            self.assertEqual(p1, self_.data[key0])
            self.assertEqual(p2, self_.data[key1])
            self.assertTrue(self_.data[key0].requires_grad)
            self.assertTrue(self_.data[key1].requires_grad)
            self.assertIsNotNone(self_.data[key0].grad_fn)
            self.assertIsNotNone(self_.data[key1].grad_fn)

        module = MyMod(torch.nn.ParameterList([p1, p2]), check_fn).cuda()
        model = dp.DataParallel(module)
        input = torch.randn((8, 8), device="cuda")

        # Runs the check_fn
        model(input)

        key0 = "0"
        key1 = "1"
        module = MyMod(torch.nn.ParameterDict({"0": p1, "1": p2}), check_fn).cuda()
        model = dp.DataParallel(module)
        input = torch.randn((8, 8), device="cuda")

        # Runs the check_fn
        model(input)


class TestDataParallelDeviceType(TestCase):
    @onlyCUDA
    @skipMeta
    @dtypes(torch.float, torch.double, torch.half)
    def test_data_parallel_module(self, device, dtype):
        l = nn.Linear(10, 5).to(device, dtype)
        i = torch.randn(20, 10, device=device, dtype=dtype)
        expected_out = l(i)
        net = nn.DataParallel(l)
        out = net(i)
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)

    @onlyCUDA
    @skipMeta
    @dtypes(torch.float, torch.double, torch.half)
    def test_data_parallel_module_kwargs_only(self, device, dtype):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = l

            def forward(self, input):
                return self.l(input)

        l = nn.Linear(10, 5).to(device, dtype)
        i = torch.randn(20, 10, device=device, dtype=dtype)
        expected_out = l(i)
        n = nn.DataParallel(Net())
        out = n(input=i)
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)

    @onlyCUDA
    @skipMeta
    @dtypes(torch.float, torch.double, torch.half)
    def test_data_parallel_module_kwargs_only_empty_list(self, device, dtype):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = l

            def forward(self, input):
                return self.l(input["data"])

        l = nn.Linear(10, 5).to(device, dtype)
        i = torch.randn(20, 10, device=device, dtype=dtype)
        expected_out = l(i)
        n = nn.DataParallel(Net())
        out = n(input={"data": i, "unused": []})
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)

    @onlyCUDA
    @skipMeta
    @dtypes(torch.float, torch.double, torch.half)
    def test_data_parallel_module_kwargs_only_empty_dict(self, device, dtype):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = l

            def forward(self, input):
                return self.l(input["data"])

        l = nn.Linear(10, 5).to(device, dtype)
        i = torch.randn(20, 10, device=device, dtype=dtype)
        expected_out = l(i)
        n = nn.DataParallel(Net())
        out = n(input={"data": i, "unused": {}})
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)

    @onlyCUDA
    @skipMeta
    @dtypes(torch.float, torch.double, torch.half)
    def test_data_parallel_module_kwargs_only_empty_tuple(self, device, dtype):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = l

            def forward(self, input):
                return self.l(input["data"])

        l = nn.Linear(10, 5).to(device, dtype)
        i = torch.randn(20, 10, device=device, dtype=dtype)
        expected_out = l(i)
        n = nn.DataParallel(Net())
        out = n(input={"data": i, "unused": ()})
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, atol=dtype2prec_DONTUSE[dtype], rtol=0)


instantiate_device_type_tests(TestDataParallelDeviceType, globals())

if __name__ == "__main__":
    TestCase._default_dtype_check_enabled = True
    run_tests()