# Owner(s): ["oncall: distributed"]

import functools
import os
import sys
import warnings
from collections import namedtuple
from contextlib import nullcontext
from copy import deepcopy
from itertools import chain
from typing import Any, Tuple

import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.fsdp import (
    CPUOffload,
    FlatParameter,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp._flat_param import _FSDP_USE_UNSAFE_SETATTR
from torch.distributed.fsdp._runtime_utils import HOMOGENEOUS_ATTR_NAMES
from torch.distributed.fsdp.wrap import (
    always_wrap_policy,
    ModuleWrapPolicy,
    transformer_auto_wrap_policy,
)
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    _assert_module_states,
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    FSDPTestMultiThread,
    MLP,
    NestedWrappedModule,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Linear(2, 2)
        self.b = nn.Linear(2, 2)

    def forward(self, x, y):
        return self.b(self.a(x + y))


class TestFSDPMiscMultiProcess(FSDPTest):
    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    @skip_if_lt_x_gpu(2)
    @parametrize("use_index", [True, False])
    def test_fsdp_device_id(self, use_index):
        """
        Tests the FSDP ``device_id`` argument:
        - Wrapping a CPU module should move the module to the GPU matching
          ``device_id``
        - Wrapping a GPU module already on the GPU matching ``device_id``
          should not raise an error
        - Wrapping a GPU module already on GPU and passing a GPU device
          without specifying a device ID (i.e. ``torch.device("cuda")``) warns
        """
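        # ``device_id`` is exercised both as a plain device index and as a
        # ``torch.device``; both forms should place the ``FlatParameter``s on
        # the same CUDA device.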
        dev_id = (
            torch.cuda.current_device()
            if use_index
            else torch.device("cuda", torch.cuda.current_device())
        )

        def _check_device_matches(module, device_id):
            """Checks that the ``FlatParameter``s in ``module`` have device
            matching ``device_id``."""
            devices = {
                p.device for p in module.parameters() if isinstance(p, FlatParameter)
            }
            assert len(devices) > 0
            self.assertEqual(1, len(devices))
            found_device = devices.pop()
            if use_index and not isinstance(device_id, torch.device):
                device = torch.device("cuda", device_id)
            else:
                device = device_id
            self.assertEqual(found_device, device)

        # Check that FSDP parameters are moved to `device_id` for a CPU module
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_NEVER,
            fsdp_kwargs={"device_id": dev_id},
        )
        _check_device_matches(nested_wrapped_module, dev_id)
        # Check that specifying `device_id` for a GPU module already on that
        # device does not raise an error
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs={"device_id": dev_id},
        )
        _check_device_matches(nested_wrapped_module, dev_id)
        # Check that passing in `torch.device("cuda")` for a GPU module warns
        regex = "does not have an explicit index"
        context = self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex=regex
        )
        with context:
            nested_wrapped_module = NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_BEFORE,
                fsdp_kwargs={"device_id": torch.device("cuda")},
            )
        _check_device_matches(
            nested_wrapped_module, torch.device("cuda", torch.cuda.current_device())
        )

    @skip_if_lt_x_gpu(2)
    def test_fsdp_zero2_eval_with_prefetch(self):
        # Tests FSDP eval with SHARD_GRAD_OP and forward_prefetch, comparing
        # against DDP

        class Mnist(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1)
                self.conv2 = nn.Conv2d(32, 64, 3, 1)
                self.dropout1 = nn.Dropout(0.25)
                self.dropout2 = nn.Dropout(0.5)
                self.fc1 = nn.Linear(9216, 128)
                self.fc2 = nn.Linear(128, 10)
                self.ln = nn.LayerNorm(9216)

            def forward(self, x, y):
                x = self.conv1(x)
                x = torch.nn.functional.relu(x)
                x = self.conv2(x)
                x = torch.nn.functional.relu(x)
                x = torch.nn.functional.max_pool2d(x, 2)
                x = self.dropout1(x)
                x = torch.flatten(x, 1)
                x = self.ln(x)
                x = self.fc1(x)
                x = torch.nn.functional.relu(x)
                x = self.dropout2(x)
                x = self.fc2(x)
                output = torch.nn.functional.log_softmax(x, dim=1)
                loss = torch.nn.functional.cross_entropy(output, y)
                return loss

        model = Mnist().cuda()
        model1 = Mnist().cuda()
        model1.load_state_dict(model.state_dict())
        fsdp_model = FSDP(
            model,
            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
            forward_prefetch=True,
            use_orig_params=True,
            auto_wrap_policy=ModuleWrapPolicy([nn.Linear, nn.Conv2d]),
        )
        ddp_model = torch.nn.parallel.DistributedDataParallel(
            model1,
        )

        fsdp_opt = torch.optim.SGD(fsdp_model.parameters(), lr=1e-4)
        ddp_opt = torch.optim.SGD(ddp_model.parameters(), lr=1e-4)

        seed = self.rank + 20231010
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        losses = []
        grads = []
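        # Train both models on the same data each iteration and check that the
        # losses and input gradients match between FSDP and DDP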
        for i in range(5):
            x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
            y = torch.randint(low=0, high=9, size=(8,), device="cuda")
            for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
                seed = self.rank + i
                torch.manual_seed(seed)
                torch.cuda.manual_seed(seed)
                loss = model(x, y).sum()
                losses.append(loss)
                loss.backward()
                opt.step()
                grads.append(x.grad)
                opt.zero_grad()
            assert torch.allclose(losses[0], losses[1])
            assert torch.allclose(grads[0], grads[1])
            losses.clear()
            grads.clear()

        with torch.no_grad():
            fsdp_model.eval()
            ddp_model.eval()
            for _ in range(5):
                x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
                y = torch.randint(low=0, high=9, size=(8,), device="cuda")
                fsdp_loss = fsdp_model(x, y)
                ddp_loss = ddp_model(x, y)
                assert torch.allclose(fsdp_loss, ddp_loss)

        fsdp_model.train()
        ddp_model.train()
        for i in range(5):
            x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
            y = torch.randint(low=0, high=9, size=(8,), device="cuda")
            for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
                seed = self.rank + i
                torch.manual_seed(seed)
                torch.cuda.manual_seed(seed)
                loss = model(x, y).sum()
                losses.append(loss)
                loss.backward()
                opt.step()
                grads.append(x.grad)
                opt.zero_grad()
            assert torch.allclose(losses[0], losses[1])
            assert torch.allclose(grads[0], grads[1])
            losses.clear()
            grads.clear()

    @skip_if_lt_x_gpu(2)
    @parametrize("use_second_layer", [True, False])
    @parametrize("sharding_strategy", [ShardingStrategy.NO_SHARD, None])
    def test_fsdp_module_no_compute_grad(self, use_second_layer, sharding_strategy):
        # When use_second_layer=True, b is involved in forward computation but does
        # not receive grad in backward. Otherwise, b is not involved in forward
        # computation.

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Linear(10, 10)
                self.b = nn.Linear(10, 10)

            def forward(self, x, y):
                out1 = self.a(x)
                if use_second_layer:
                    out2 = self.b(y)
                    return out1, out2
                else:
                    return out1

        fsdp = FSDP(
            MyModel().cuda(),
            sharding_strategy=sharding_strategy,
            auto_wrap_policy=always_wrap_policy,
        )
        x = torch.randn(10, 10, device="cuda")
        y = torch.randn(10, 10, device="cuda")
        for i in range(4):
            if use_second_layer:
                a, b = fsdp(x, y)
            else:
                a = fsdp(x, y)
            loss = a.sum()
            loss.backward()

            # self.a receives a gradient; self.b does not
            a_grad = fsdp.module.a._handle.flat_param.grad
            b_grad = fsdp.module.b._handle.flat_param.grad
            self.assertIsNotNone(a_grad)
            self.assertIsNone(b_grad)

    @skip_if_lt_x_gpu(2)
    def test_fsdp_not_all_outputs_used_in_loss(self):
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ]
            },
            self._test_fsdp_not_all_outputs_used_in_loss,
        )

    def _test_fsdp_not_all_outputs_used_in_loss(
        self, sharding_strategy: ShardingStrategy
    ):
        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin1 = nn.Linear(4, 4)
                self.lin2 = nn.Linear(4, 4)

            def forward(self, x):
                a = self.lin1(x)
                b = self.lin2(x)
                return (a, b)

        def _check_resharded(fsdp_module):
            handle = fsdp_module._handle
            if not handle:
                return
            param = handle.flat_param
            if handle.uses_sharded_strategy:
                full_param = param._full_param_padded
                self.assertEqual(full_param.storage().size(), 0)

            self.assertEqual(param.data_ptr(), param._local_shard.data_ptr())

        def _check_equal(local, fsdp):
            with FSDP.summon_full_params(fsdp):
                for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                    torch.testing.assert_close(p1, p2)

        fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy)
        m = MyModule().cuda()
        m_local = deepcopy(m)
        local_m = m_local
        prev_params = [p.clone() for p in m_local.parameters()]

        m.lin1 = fsdp_ctor(m.lin1)
        m = fsdp_ctor(m)
        _check_equal(m_local, m)

        opt = torch.optim.SGD(m.parameters(), lr=1e-3)
        opt_local = torch.optim.SGD(local_m.parameters(), lr=1e-3)

        for i in range(6):
            t = torch.ones(4, device="cuda")
            a, b = m(t)
            local_a, local_b = local_m(t)
            if i < 2:
                # Use both outputs in the loss computation. In later iterations,
                # `b` goes unused, and we check that the gradients still match
                # local training.
                loss = (a @ b).sum()
                loss_local = (local_a @ local_b).sum()
            else:
                loss = a.sum()
                loss_local = local_a.sum()

            loss.backward()
            loss_local.backward()
            _check_resharded(m)
            opt.step()
            opt_local.step()
            _check_equal(m_local, m)
            # Ensure at least some change from the previous params; otherwise,
            # the check above would be vacuously true.
            self.assertTrue(
                any(
                    not torch.equal(p1, p2)
                    for p1, p2 in zip(prev_params, m_local.parameters())
                )
            )
            prev_params = [p.clone() for p in local_m.parameters()]
            opt.zero_grad()
            opt_local.zero_grad()

        dist.barrier()

    @skip_if_lt_x_gpu(2)
    def test_fsdp_optim_overlap_no_use_orig_params_error(self):
        fsdp_overlap = FSDP(
            MyModel().cuda(),
            auto_wrap_policy=always_wrap_policy,
            use_orig_params=False,
        )
        optim_cls = torch.optim.SGD
        optim_kwargs = {"lr": 0.03}
        _apply_optimizer_in_backward(
            optimizer_class=optim_cls,
            params=fsdp_overlap.parameters(),
            optimizer_kwargs=optim_kwargs,
            register_hook=False,
        )

        inp = torch.randn(10, 10, device="cuda")
        with self.assertRaisesRegex(
            RuntimeError, "only supported with use_orig_params=True"
        ):
            fsdp_overlap(inp, inp)

    @skip_if_lt_x_gpu(2)
    def test_fsdp_optimizer_overlap(self):
        torch.manual_seed(0)
        for cpu_offload in [True, False]:
            offload = CPUOffload(offload_params=cpu_offload)
            model = MyModel().cuda()
            model_overlap = deepcopy(model)
            fsdp = FSDP(
                model.cuda(),
                auto_wrap_policy=always_wrap_policy,
                use_orig_params=True,
                cpu_offload=offload,
            )
            fsdp_overlap = FSDP(
                model_overlap.cuda(),
                auto_wrap_policy=always_wrap_policy,
                use_orig_params=True,
                cpu_offload=offload,
            )
            optim_cls = torch.optim.SGD
            optim_kwargs = {"lr": 0.03}
            _apply_optimizer_in_backward(
                optimizer_class=optim_cls,
                params=fsdp_overlap.parameters(),
                optimizer_kwargs=optim_kwargs,
                register_hook=False,
            )
            for p in fsdp_overlap.parameters():
                assert hasattr(p, "_in_backward_optimizers")
            optim = optim_cls(fsdp.parameters(), **optim_kwargs)

            # Verify params initially equal
            for p1, p2 in zip(fsdp.parameters(), fsdp_overlap.parameters()):
                self.assertEqual(p1, p2)

            with FSDP.summon_full_params(fsdp_overlap):
                fsdp_overlap_prev_params = [
                    (n, p.clone()) for n, p in fsdp_overlap.named_parameters()
                ]

            for i in range(6):
                inp = torch.randn(2, 2, device="cuda")
                with torch.no_grad():
                    inp_clone = inp.clone()
                fsdp(inp, inp).sum().backward()
                fsdp_overlap(inp_clone, inp_clone).sum().backward()

                optim.step()
                optim.zero_grad()

                # Overlapped optimizer FSDP module should have sharded_grad as None.
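                # (The in-backward optimizer is expected to have already stepped
                # and freed the sharded gradient during the backward pass.)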
                for fsdp_unit in FSDP.fsdp_modules(fsdp_overlap):
                    handle = fsdp_unit._handle
                    if handle:
                        handle_grad = handle.sharded_grad
                        self.assertEqual(
                            None,
                            handle_grad,
                            "Overlapped FSDP sharded_grad is not None!",
                        )

                # Note: FSDP without optimizer overlap won't set sharded_grad to
                # None until the next pre-forward since it needs to run
                # FSDP-specific logic that picks up that set_to_none=True has
                # been called (or that the gradients have otherwise been set to
                # None).

                # Verify that the parameters differ from the previous iteration
                with FSDP.summon_full_params(fsdp_overlap, with_grads=True):
                    for (n, p), (n_prev, p_prev) in zip(
                        fsdp_overlap.named_parameters(), fsdp_overlap_prev_params
                    ):
                        self.assertNotEqual(
                            p,
                            p_prev,
                            f"{n_prev} Params at iter {i} same as previous iter!",
                        )

                # Verify that the overlapped and non-overlapped models are the same
                with FSDP.summon_full_params(fsdp_overlap):
                    with FSDP.summon_full_params(fsdp):
                        for (n_overlap, p_overlap), (n, p) in zip(
                            fsdp_overlap.named_parameters(), fsdp.named_parameters()
                        ):
                            self.assertEqual(n_overlap, n)
                            self.assertEqual(
                                p,
                                p_overlap,
                                f"Rank {self.rank}: Params not equal at iteration {i}: {n_overlap} - {p} vs {p_overlap}",
                            )
                            self.assertEqual(
                                None, p.grad, f"Expected param {n} grad to be None"
                            )
                            self.assertEqual(
                                None,
                                p_overlap.grad,
                                f"Expected param {n_overlap} grad to be None",
                            )

                fsdp_overlap_prev_params = [
                    (n, p.clone()) for n, p in fsdp_overlap.named_parameters()
                ]

    @skip_if_lt_x_gpu(2)
    def test_fsdp_cpu_training(self):
        """Tests FSDP training on CPU."""
        gloo_pg = dist.new_group(backend="gloo")
        for ss in [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
            ShardingStrategy.HYBRID_SHARD,
            ShardingStrategy._HYBRID_SHARD_ZERO2,
        ]:
            torch.manual_seed(42)
            model = MyModel()
            ref_model = DDP(deepcopy(model), process_group=gloo_pg)
            model = FSDP(
                model,
                auto_wrap_policy=always_wrap_policy,
                process_group=gloo_pg,
                device_id=torch.device("cpu"),
            )
            ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
            optim = torch.optim.Adam(model.parameters(), lr=1e-2)
            torch.manual_seed(42 + self.rank)
            inp = torch.randn(2, 2)
            for _ in range(10):
                losses = []
                for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                    loss = _model(inp, inp).sum()
                    losses.append(loss)
                    loss.backward()
                    _optim.step()
                    _optim.zero_grad()
                self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    def test_fsdp_cpu_init_stays_on_cpu(self):
        # Move me to an MT test once the warning logging and backward collective
        # issue is resolved.
        """Tests that passing a CPU module to FSDP keeps the wrapped module on
        CPU after FSDP initialization (albeit after logging a warning) and that
        FSDP moves CPU input to the GPU before the forward."""
        torch.cuda.set_device(self.rank)
        regex = "passed-in `module` is on CPU"
        context = self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex=regex
        )
        with context:
            nested_wrapped_module = NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_NEVER,
            )
            fsdp_model = FSDP(nested_wrapped_module, self.process_group)
        devices = {p.device for p in fsdp_model.parameters()}
        self.assertEqual(1, len(devices))
        self.assertEqual(torch.device("cpu"), devices.pop())
        fsdp_model = fsdp_model.cuda()
        # Ensure that forward + backward can be performed after moving to CUDA.
        # The CPU input also tests that the input is correctly moved to the
        # appropriate CUDA device.
        inp = fsdp_model.module.get_input(device=torch.device("cpu"))
        fsdp_model(*inp).sum().backward()

    @skip_if_lt_x_gpu(2)
    def test_cpu_init_with_sync_module_states(self):
        """
        Tests that passing ``sync_module_states=True`` raises an error for
        a CPU module since the synchronization requires GPU communication,
        while additionally passing ``device_id`` does not raise an error, even
        when the model has CPU buffers.
        """

        def init_nested_wrapped_module():
            return NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.NO_FSDP,
                CUDAInitMode.CUDA_NEVER,
            )

        with self.assertRaisesRegex(
            ValueError,
            "The module has CPU parameters or buffers when `sync_module_states=True`",
        ):
            FSDP(
                init_nested_wrapped_module(),
                self.process_group,
                sync_module_states=True,
            )

        # Check that `device_id` with `sync_module_states=True` works
        nested_wrapped_module = init_nested_wrapped_module()
        nested_wrapped_module.buf = nn.Buffer(
            torch.ones((2, 2), device="cpu") * self.rank
        )
        nested_wrapped_module.module[0].buf = nn.Buffer(
            torch.ones((3, 2), device="cpu") * self.rank
        )
        nested_wrapped_module = FSDP(
            nested_wrapped_module,
            self.process_group,
            auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
            device_id=torch.cuda.current_device(),
            sync_module_states=True,
        )
        # Each rank's buffers should be 0s since rank 0 is the source, and they
        # should be on GPU since we specified `device_id`
        self.assertEqual(
            nested_wrapped_module.buf.device,
            torch.device("cuda", torch.cuda.current_device()),
        )
        self.assertEqual(nested_wrapped_module.buf, torch.zeros((2, 2)))
        self.assertEqual(
            nested_wrapped_module.module.module[0].buf.device,
            torch.device("cuda", torch.cuda.current_device()),
        )
        self.assertEqual(
            nested_wrapped_module.module.module[0].buf, torch.zeros((3, 2))
        )


class TestFSDPMiscMultiThread(FSDPTestMultiThread):
    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    @skip_if_lt_x_gpu(2)
    def test_fsdp_namedtuple(self):
        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(100, 100)

            def forward(self, x):
                return x

        m = MyModule().cuda()
        m = FSDP(m)
        t = torch.ones(1, device="cuda", requires_grad=True)

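        # Wrap the tensors in a namedtuple to check that FSDP still registers
        # its backward hooks on tensors stored inside container types.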
        MyOutputType = namedtuple(
            "MyOutputType", ["a", "b", "c", "d"], defaults=(t, t, t, t)
        )

        inp = MyOutputType()
        out = m(inp)
        # Ensure hooks are registered
        for x in out:
            self.assertNotEqual([], list(x._backward_hooks.values()))

        # TODO: we should check backward() and that the param is resharded
        # as well, but this is blocked by
        # https://github.com/pytorch/pytorch/issues/83107 and
        # https://github.com/pytorch/pytorch/issues/83129

    @skip_if_lt_x_gpu(2)
    def test_device_id_auto_wrap(self):
        """Tests that ``auto_wrap_policy`` propagates ``device_id`` to all
        nested FSDP instances."""
        self.run_subtests(
            {"use_callable": [False, True]},
            self._test_device_id_auto_wrap,
        )

    def _test_device_id_auto_wrap(self, use_callable: bool):
        module_classes = {TransformerEncoderLayer, TransformerDecoderLayer}
        if use_callable:
            auto_wrap_policy = functools.partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls=module_classes,
            )
        else:
            auto_wrap_policy = ModuleWrapPolicy(module_classes)
        fsdp_kwargs = {
            "auto_wrap_policy": auto_wrap_policy,
            "device_id": torch.cuda.current_device(),
        }
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
        for fsdp_module in FSDP.fsdp_modules(fsdp_model):
            self.assertEqual(
                fsdp_module.compute_device,
                torch.device("cuda", torch.cuda.current_device()),
            )

    @skip_if_lt_x_gpu(2)
    def test_fsdp_device_id_cpu_offload(self):
        """
        Tests FSDP when specifying both ``device_id`` and parameter CPU
        offloading.
        """
        self.run_subtests(
            {"use_orig_params": [False, True]},
            self._test_fsdp_device_id_cpu_offload,
        )

    def _test_fsdp_device_id_cpu_offload(self, use_orig_params: bool):
        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.seq = nn.Sequential(
                    nn.Linear(10, 10),
                    nn.Linear(10, 10),
                )
                self.lin = nn.Linear(10, 10)

            def forward(self, x):
                return self.lin(self.seq(x))

        model = MyModel()
        # Choose a wrapping policy such that there are (1) nested FSDP
        # instances and (2) the parent FSDP instance has managed parameters
        auto_wrap_policy = ModuleWrapPolicy({nn.Sequential})
        fsdp_model = FSDP(
            model,
            auto_wrap_policy=auto_wrap_policy,
            cpu_offload=CPUOffload(offload_params=True),
            device_id=torch.cuda.current_device(),
            use_orig_params=use_orig_params,
        )
        cpu_device = torch.device("cpu")
        for handle in traversal_utils._get_fsdp_handles(fsdp_model):
            self.assertEqual(handle.flat_param.device, cpu_device)

    @skip_if_lt_x_gpu(2)
    def test_module_device_mismatches_device_id(self):
        """Tests that specifying a ``device_id`` argument to FSDP for a GPU
        module that does not match the GPU device ID raises an error."""
        # TODO: override FSDP MT Thread _run to set this instead of here for
        # every test.
        torch.cuda.set_device(self.rank)
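        # With the current device set to this rank, passing `device_id=0` below
        # should error on every rank except rank 0.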
        context = (
            self.assertRaisesRegex(ValueError, f"cuda:{self.rank} vs cuda:0")
            if self.rank != 0
            else nullcontext()
        )
        with context:
            NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                # Move wrapped modules to CUDA before wrapping with FSDP
                cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
                # Should raise an error since rank 1 is given `device_id=0` when
                # the model is on cuda:1
                fsdp_kwargs={"device_id": 0},
            )

    @skip_if_lt_x_gpu(2)
    def test_cpu_gpu_module(self):
        """Tests that a module with both CPU and GPU parameters is supported if
        ``device_id`` is passed in and errors if it is not.
        """
        torch.cuda.set_device(self.rank)

        class CPUGPUModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Linear(1, 1).cuda()
                self.b = nn.Linear(1, 1)

        cpu_gpu = CPUGPUModule()
        fsdp = FSDP(cpu_gpu, device_id=torch.cuda.current_device())
        for param in fsdp.parameters():
            self.assertEqual(param.device, torch.device(torch.cuda.current_device()))

        # Without `device_id`, we hit an error
        with self.assertRaisesRegex(RuntimeError, "please pass in device_id"):
            FSDP(CPUGPUModule())

    @skip_if_lt_x_gpu(2)
    def test_fsdp_ignored_module_meta(self):
        torch.cuda.set_device(self.rank)

        class CPUGPUModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Linear(1, 1)
                self.b = nn.Linear(1, 1)

        with torch.device("meta"):
            m = CPUGPUModule()
        m = FSDP(m, device_id=self.rank, ignored_modules=[m.a], use_orig_params=True)
        meta_device = torch.device("meta")
        self.assertEqual(meta_device, next(m.a.parameters()).device)

        # Test with `param_init_fn`
        with torch.device("meta"):
            m = CPUGPUModule()
        m = FSDP(
            m,
            device_id=torch.cuda.current_device(),
            ignored_modules=[m.a],
            use_orig_params=True,
            param_init_fn=lambda m: m.to_empty(
                device=torch.cuda.current_device(), recurse=False
            ),
        )
        self.assertEqual(meta_device, next(m.a.parameters()).device)

    @skip_if_lt_x_gpu(2)
    def test_fsdp_device_id_no_move_ignored_params_and_bufs(self):
        class CPUGPUModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Linear(1, 1)
                self.b = nn.Linear(1, 1)
                self.a.buf = torch.nn.Buffer(torch.ones(1))

        m = CPUGPUModule()
        m = FSDP(m, device_id=self.rank, ignored_modules=[m.a], use_orig_params=True)
        ignored_params = m.a.parameters()
        ignored_bufs = m.a.buffers()
        for t in chain(ignored_params, ignored_bufs):
            self.assertEqual(torch.device("cpu"), t.device)

    @skip_if_lt_x_gpu(2)
    def test_multigpu_module(self):
        """
        A module on multiple GPUs wrapped in FSDP should raise an error.
        """

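        # `a` is placed on this rank's GPU and `b` on the next rank's GPU
        # (modulo world size), so the module spans two CUDA devices.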
        class MultiGPUModule(nn.Module):
            def __init__(self, rank):
                super().__init__()
                self.rank = rank
                self.a = nn.Linear(1, 1).cuda(self.rank)
                self.b = nn.Linear(1, 1).cuda((self.rank + 1) % dist.get_world_size())

        with self.assertRaisesRegex(
            RuntimeError, "FSDP only supports single device modules"
        ):
            FSDP(MultiGPUModule(self.rank))

    @skip_if_lt_x_gpu(2)
    def test_no_params(self):
        """
        Tests that ``device_id`` and CPU init work if the module has no
        parameters (they are effectively no-ops, but we ensure that FSDP does
        not assume the module has parameters during init).
        """
        # TODO: override FSDP MT Thread _run to set this instead of here for
        # every test.
        torch.cuda.set_device(self.rank)
        # Test CPU
        no_params = nn.ReLU()
        module = FSDP(no_params)
        # Test CUDA
        no_params = nn.ReLU().cuda()
        module = FSDP(no_params)
        # Test CPU + device_id
        no_params = nn.ReLU()
        module = FSDP(no_params, device_id=torch.cuda.current_device())
        # For modules with no params, a wrong device_id raises an error about
        # the inconsistency between compute_device and device_id, since
        # compute_device is computed as torch.cuda.current_device when there
        # are no params.
        no_params = nn.ReLU().cuda()
        context = (
            (
                self.assertRaisesRegex(
                    ValueError, f"Inconsistent.*cuda:{self.rank} vs cuda:0"
                )
            )
            if self.rank != 0
            else nullcontext()
        )
        with context:
            FSDP(no_params, device_id=0)

    @skip_if_lt_x_gpu(2)
    def test_fsdp_same_model_across_ranks(self):
        """
        FSDP broadcasts the model from rank 0 to ensure that it starts off with
        the same values.
        """

        class MyModel(nn.Module):
            def __init__(self, rank):
                super().__init__()
                # Seed via rank to make the model different across ranks
                torch.manual_seed(rank)
                torch.cuda.manual_seed(rank)
                self.lin = nn.Linear(10, 10, bias=False)
                self.buffer = nn.Buffer(torch.ones(1) * rank)

        m = MyModel(self.rank).cuda()
        _assert_module_states(
            m, process_group=self.process_group, assert_fn=self.assertNotEqual
        )
        # Passing `sync_module_states=True` to FSDP makes the model the same
        # during init.
        fsdp = FSDP(m, sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(
                fsdp, process_group=self.process_group, assert_fn=self.assertEqual
            )

        # `sync_module_states` also works with a CPU module with `device_id`
        # passed in
        m = MyModel(self.rank)
        _assert_module_states(
            m, process_group=self.process_group, assert_fn=self.assertNotEqual
        )
        # Passing `sync_module_states=True` to FSDP makes the model the same
        # during init.
        fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(
                fsdp, process_group=self.process_group, assert_fn=self.assertEqual
            )

    @skip_if_lt_x_gpu(2)
    def test_homogeneous_attributes(self):
        """
        Tests that passing heterogeneous values for attributes designated as
        homogeneous raises an error.
        """
        # Manually construct this list but verify against the global list of
        # homogeneous attribute names
        all_attr_name_and_values = [
            ("_use_orig_params", False, True),
            ("limit_all_gathers", False, True),
            ("_use_full_prec_in_eval", False, True),
        ]
        self.assertEqual(
            [
                attr_name_and_values[0]
                for attr_name_and_values in all_attr_name_and_values
            ],
            HOMOGENEOUS_ATTR_NAMES,
        )

        self.run_subtests(
            {"attr_name_and_values": all_attr_name_and_values},
            self._test_homogeneous_attributes,
        )

    def _test_homogeneous_attributes(self, attr_name_and_values: Tuple[str, Any, Any]):
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            {},
        )
        attr_name = attr_name_and_values[0]

        if "_use_full_prec_in_eval" == attr_name:
            model.module[1] = FSDP(model.module[1])
            os.environ["FSDP_USE_FULL_PREC_IN_EVAL"] = "1"
            fsdp_model = FSDP(model)
        else:
            fsdp_kwargs_inner = {attr_name.lstrip("_"): attr_name_and_values[1]}
            fsdp_kwargs_outer = {attr_name.lstrip("_"): attr_name_and_values[2]}
            model.module[1] = FSDP(model.module[1], **fsdp_kwargs_inner)
            fsdp_model = FSDP(model, **fsdp_kwargs_outer)

        # Run a forward to trigger lazy initialization and the error
        with self.assertRaisesRegex(
            ValueError, f"Expects one homogeneous value for {attr_name}"
        ):
            inp = fsdp_model.module.get_input(torch.device("cuda"))
            fsdp_model(*inp)

    @skip_if_lt_x_gpu(2)
    def test_fsdp_unsupported_module_cls(self):
        regex = r"FSDP will not all-gather parameters for containers that do not implement forward"
        model = nn.ModuleList([MLP(8, torch.device("cpu")) for _ in range(3)])
        with self.assertWarnsRegex(UserWarning, regex):
            FSDP(model, device_id="cuda")
        model = nn.ModuleDict(
            {"1": MLP(8, torch.device("cpu")), "2": MLP(8, torch.device("cpu"))}
        )
        with self.assertWarnsRegex(UserWarning, regex):
            FSDP(model)


class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
    @property
    def world_size(self) -> int:
        return 1

    @skip_if_lt_x_gpu(1)
    def test_world_size_1_sharding_strategy_warning(self):
        """
        Tests that FSDP issues a warning when it switches to using ``NO_SHARD``
        when the world size is 1.
        """
        warning_prefix = "FSDP is switching to use `NO_SHARD` instead of"
        # If the user already passes `NO_SHARD`, then there should not be a
        # warning
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")  # trigger all warnings
            FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.NO_SHARD)
            for warning in w:
                self.assertTrue(
                    warning.category != UserWarning
                    or not str(warning.message).startswith(warning_prefix)
                )

        # Check that a warning is issued
        warning_suffix = " since the world size is 1."
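        # Both `FULL_SHARD` (explicit or via the default `None`) and
        # `SHARD_GRAD_OP` should trigger the warning.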
        # - Pass `FULL_SHARD` or `None`
        expected_regex_full_shard = (
            warning_prefix + " " + str(ShardingStrategy.FULL_SHARD) + warning_suffix
        )
        with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
            FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.FULL_SHARD)
        with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
            FSDP(nn.Linear(3, 3).cuda())
        # - Pass `SHARD_GRAD_OP`
        expected_regex_shard_grad_op = (
            warning_prefix + " " + str(ShardingStrategy.SHARD_GRAD_OP) + warning_suffix
        )
        with self.assertWarnsRegex(UserWarning, expected_regex_shard_grad_op):
            FSDP(
                nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
            )

    @skip_if_lt_x_gpu(1)
    def test_training_device_mismatch_errors(self):
        """
        Tests that, when training starts, if the FSDP parameters are not on the
        expected device, then an informative error is raised. This applies to
        both no parameter CPU offloading and parameter CPU offloading.
        """
        # Incorrectly not moving from CPU -> GPU
        model = torch.nn.Linear(10, 10)
        fsdp_model = FSDP(model)
        inp = torch.randn((2, 10))
        with self.assertRaisesRegex(
            RuntimeError,
            "An FSDP-managed module unexpectedly has parameters on cpu. Make "
            "sure to move the module to cuda:0 before training.",
        ):
            fsdp_model(inp)

        # Incorrectly moving from CPU -> GPU
        model = torch.nn.Linear(10, 10)
        fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))
        fsdp_model.to(torch.device("cuda"))
        inp = torch.randn((2, 10))
        with self.assertRaisesRegex(
            RuntimeError,
            "An FSDP-managed module with parameter CPU offloading enabled has "
            "parameters on cuda:0. Make sure to not move the module from CPU "
            "when offloading parameters.",
        ):
            fsdp_model(inp)

    @skip_if_lt_x_gpu(2)
    def test_unsafe_setattr(self):
        """
        Tests that the environment variable for using unsafe ``setattr`` gates
        the behavior as expected.
        """
        self.run_subtests(
            {"use_orig_params": [False, True]},
            self._test_unsafe_setattr,
        )

    def _test_unsafe_setattr(self, use_orig_params: bool):
        called_setattr_override = False

        class SetattrLinear(nn.Module):
            def __init__(self, in_dim: int, out_dim: int, device: torch.device) -> None:
                super().__init__()
                self.weight = nn.Parameter(
                    torch.randn((in_dim, out_dim), device=device)
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x @ self.weight

            def __setattr__(self, name: str, value: Any) -> None:
                nonlocal called_setattr_override
                called_setattr_override = True
                return super().__setattr__(name, value)

        # Construct FSDP module without changing any environment variables and
        # run forward, which triggers both unsharded and sharded view setting
        module = SetattrLinear(5, 5, torch.device("cuda"))
        fsdp_module = FSDP(module, use_orig_params=use_orig_params)
        inp = torch.randn((8, 5), device=torch.device("cuda"))
        called_setattr_override = False
        fsdp_module(inp)
        self.assertTrue(called_setattr_override)

        # Repeat with unsafe setattr explicitly enabled
        os.environ[_FSDP_USE_UNSAFE_SETATTR] = "1"
        module = SetattrLinear(5, 5, torch.device("cuda"))
        fsdp_module = FSDP(module, use_orig_params=use_orig_params)
        called_setattr_override = False
        fsdp_module(inp)
        self.assertFalse(called_setattr_override)

        # Repeat with unsafe setattr explicitly disabled
        os.environ[_FSDP_USE_UNSAFE_SETATTR] = "0"
        module = SetattrLinear(5, 5, torch.device("cuda"))
        fsdp_module = FSDP(module, use_orig_params=use_orig_params)
        called_setattr_override = False
        fsdp_module(inp)
        self.assertTrue(called_setattr_override)


instantiate_parametrized_tests(TestFSDPMiscMultiThread)
instantiate_parametrized_tests(TestFSDPMiscMultiProcess)

if __name__ == "__main__":
    run_tests()