# Owner(s): ["oncall: distributed"]

import contextlib
import itertools
import os
import sys
from functools import partial
from itertools import product
from typing import Any, Dict, List

import torch
import torch.cuda.nccl as nccl
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed as dist
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import ModuleWrapPolicy, size_based_auto_wrap_policy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.modules.batchnorm import _BatchNorm
from torch.optim.swa_utils import AveragedModel
from torch.testing._internal.common_distributed import (
    SaveForwardInputsModel,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    subtest_name,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_DEV_DBG_ASAN,
)


try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False

skipIfNoTorchVision = skip_but_pass_in_sandcastle_if(
    not HAS_TORCHVISION, "no torchvision"
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

# Various mixed precision configs to test under.
default_mp = MixedPrecision(
    param_dtype=torch.float16,
    buffer_dtype=torch.float16,
    reduce_dtype=torch.float16,
)

# Params and buffers are not cast; comm only happens
# in the reduced precision.
mp_only_reduce = MixedPrecision(reduce_dtype=torch.float16)

# Only parameters and buffers are cast (thus comm should happen in the param_dtype precision)
mp_only_param_and_buf = MixedPrecision(
    param_dtype=torch.float16, buffer_dtype=torch.float16
)

# Nothing is cast (thus params, comm, grads, and buffers should all be in full precision)
mp_no_mixed_precision = MixedPrecision()

nccl_supports_bf16 = dist.is_nccl_available() and nccl.version() >= (2, 10)

mp_configs = [default_mp, mp_only_reduce, mp_only_param_and_buf, mp_no_mixed_precision]
if nccl_supports_bf16:
    mp_diff_buffer_and_reduce = MixedPrecision(
        param_dtype=torch.float16,
        buffer_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )
    mp_configs.extend([mp_diff_buffer_and_reduce])
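
# Quick reference for how these configs are exercised below (a summary, not an
# exhaustive spec): ``param_dtype`` is the dtype params are cast to for
# forward/backward compute, ``reduce_dtype`` is the dtype used for gradient
# reduction (reduce-scatter), and ``buffer_dtype`` is the dtype buffers are
# cast to. A typical (hypothetical) usage outside of this test would be:
#
#   fsdp_model = FSDP(MyModel().cuda(), mixed_precision=default_mp)
#   fsdp_model(inp).sum().backward()  # gradients reduce-scatter in reduce_dtype
#
# where ``MyModel`` and ``inp`` are placeholders.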

# Buffer original dtype, which can differ from model params.
_BUFFER_ORIG_DTYPE = torch.float64

params = "mp_config,cpu_offload,full_precision_param_dtype,enable_sharded_grad_scaler"
cpu_offload_config = [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
full_precision_param_dtype_config = [torch.float32, torch.float64]
enable_sharded_grad_scaler = ["enable_sharded_grad_scaler", None]

configs = list(
    product(
        mp_configs,
        cpu_offload_config,
        full_precision_param_dtype_config,
        enable_sharded_grad_scaler,
    )
)

test_name_mapping = {
    str(CPUOffload(offload_params=True)): "offload_true",
    str(CPUOffload(offload_params=False)): "offload_false",
    str(default_mp): "mp_fp16",
    str(mp_only_reduce): "mp_only_reduce",
    str(mp_only_param_and_buf): "mp_only_param_and_buf",
    str(mp_no_mixed_precision): "mp_no_mp",
    str(torch.float32): "fp32",
    str(torch.float64): "fp64",
    "enable_sharded_grad_scaler": "enable_sharded_grad_scaler",
}

if nccl_supports_bf16:
    test_name_mapping.update(
        {
            str(mp_diff_buffer_and_reduce): "mp_diff_buffer_reduce",
        }
    )

subtest_name = partial(subtest_name, test_name_mapping)

_CURRENT_FULL_PRECISION_PARAM_DTYPE = None


@contextlib.contextmanager
def patch_reduce_scatter(new_reduce_scatter, full_precision_param_dtype):
    """
    Patches ``dist.reduce_scatter_tensor`` with ``new_reduce_scatter`` and
    restores upon exiting. Used for validation of mixed precision.
    """
    orig_reduce_scatter = dist.reduce_scatter_tensor
    dist.reduce_scatter_tensor = new_reduce_scatter
    global _CURRENT_FULL_PRECISION_PARAM_DTYPE
    _CURRENT_FULL_PRECISION_PARAM_DTYPE = full_precision_param_dtype
    try:
        yield
    finally:
        dist.reduce_scatter_tensor = orig_reduce_scatter
        _CURRENT_FULL_PRECISION_PARAM_DTYPE = None


class LinearMixedPrecision(nn.Module):
    """
    A linear module with extra checks for mixed precision training.
    """

    def __init__(self, param_dtype, buffer_name="buffer", run_checks=True):
        super().__init__()
        self.lin = nn.Linear(10, 10, bias=False).to(param_dtype)
        # Use a configurable buffer name to avoid all submodules sharing the
        # same buffer name, which may hide prefixed vs. unprefixed name bugs
        self.buffer_name = buffer_name
        self.register_buffer(buffer_name, torch.randn((1, 2), dtype=_BUFFER_ORIG_DTYPE))
        self._orig_param_type = param_dtype
        self._orig_buffer_dtype = _BUFFER_ORIG_DTYPE
        self.run_checks = run_checks

    def forward(self, tup):
        inp, cls, fsdp, mp_config, full_precision_param_dtype = tup
        if self.run_checks:
            # Param and input should be the mixed precision type
            expected_param_type = (
                mp_config.param_dtype
                if mp_config.param_dtype is not None
                else self._orig_param_type
            )
            expected_buffer_type = (
                mp_config.buffer_dtype
                if mp_config.buffer_dtype is not None
                else self._orig_buffer_dtype
            )
            cls.assertEqual(inp.dtype, expected_param_type)
            # Buffer should be in the specified precision as well.
            cls.assertEqual(getattr(self, self.buffer_name).dtype, expected_buffer_type)

            # In FSDP, self.params should point to the right type.
            num_active_fsdp = 0
            for fsdp_module in FSDP.fsdp_modules(fsdp):
                fsdp_managed_params = fsdp_module.params
                # Single param assumption
                cls.assertEqual(1, len(fsdp_managed_params))
                for param in fsdp_managed_params:
                    # An FSDP unit is currently active if it is not using the
                    # param's local shard. This supports both the FULL_SHARD and
                    # SHARD_GRAD_OP cases. In FULL_SHARD, we have the additional
                    # property that param._full_param_padded has not been freed.
                    param_is_sharded = (
                        fsdp_module.sharding_strategy != ShardingStrategy.NO_SHARD
                        and fsdp_module.world_size > 1
                    )
                    is_fsdp_unit_active = (
                        param_is_sharded
                        and param.data.data_ptr() != param._local_shard.data_ptr()
                    )
                    if is_fsdp_unit_active:
                        num_active_fsdp += 1
                        # This FSDP unit is active; verify that the param points
                        # to the mixed precision type.
                        cls.assertEqual(param.dtype, expected_param_type)
                        # _unshard should have also freed the fp16 shard.
                        # The shard is never allocated if param_dtype mixed
                        # precision is not enabled.
                        if mp_config.param_dtype is not None:
                            cls.assertEqual(0, param._mp_shard.untyped_storage().size())
                        else:
                            cls.assertFalse(hasattr(param, "_mp_shard"))
                    elif param_is_sharded:
                        # This FSDP unit is not active as the full param has been
                        # freed or not yet allocated. Ensure the param points to
                        # the full precision param.
                        cls.assertEqual(param.dtype, full_precision_param_dtype)
            # We should have gotten at least one active FSDP unit for sharded
            # (world size > 1) cases. For cases where the param is not sharded
            # (i.e. world_size == 1), it is hard to check whether an FSDP unit is
            # active since we always point to the local shard, so we rely on the
            # forward pass self.lin(inp) working and inp being in reduced
            # precision to implicitly validate that the param is indeed in
            # reduced precision.
            if cls.world_size > 1:
                cls.assertGreater(num_active_fsdp, 0)

        return (self.lin(inp), cls, fsdp, mp_config, full_precision_param_dtype)


class TestFSDPMixedPrecision(FSDPTest):
    @property
    def world_size(self):
        raise ValueError("To be implemented by child classes")

    def _get_simple_nested_model(
        self, param_dtype, run_checks, *fsdp_args, **fsdp_kwargs
    ):
        model = FSDP(
            nn.Sequential(
                FSDP(
                    LinearMixedPrecision(
                        param_dtype, buffer_name="buffer0", run_checks=run_checks
                    ).cuda(),
                    *fsdp_args,
                    **fsdp_kwargs,
                ),
                LinearMixedPrecision(
                    param_dtype, buffer_name="buffer1", run_checks=run_checks
                ).cuda(),
            ),
            *fsdp_args,
            **fsdp_kwargs,
        )
        return model

    def _get_simple_nested_model_composable(
        self, param_dtype, run_checks, *fsdp_args, **fsdp_kwargs
    ):
        model = nn.Sequential(
            LinearMixedPrecision(
                param_dtype, buffer_name="buffer0", run_checks=run_checks
            ).cuda(),
            LinearMixedPrecision(
                param_dtype, buffer_name="buffer1", run_checks=run_checks
            ).cuda(),
        )
        fully_shard(model[0], *fsdp_args, **fsdp_kwargs)
        fully_shard(model, *fsdp_args, **fsdp_kwargs)
        return model

    def _get_simple_model(self, param_dtype, *fsdp_args, **fsdp_kwargs):
        model = FSDP(
            LinearMixedPrecision(param_dtype).cuda(), *fsdp_args, **fsdp_kwargs
        )
        return model

    def _validate_no_mp_shard(self, fsdp_model):
        """
        Validates that there is no mixed precision _mp_shard allocated
        when it is not expected to be.
        """
        fsdp_units = FSDP.fsdp_modules(fsdp_model)
        for fsdp in fsdp_units:
            for param in fsdp.params:
                self.assertFalse(hasattr(param, "_mp_shard"))

    def _validate_mp_shard_freed(self, fsdp_model):
        """
        Ensures that the mixed precision shard is freed for all FSDP units.
        """
        fsdp_units = FSDP.fsdp_modules(fsdp_model)
        for fsdp in fsdp_units:
            for param in fsdp.params:
                self.assertEqual(0, param._mp_shard.untyped_storage().size())

    def _reduce_scatter_validate_mp(
        self, orig_reduce_scatter, mp_config, should_run_low_prec, *args, **kwargs
    ):
        """
        Runs reduce-scatter but verifies mixed precision settings beforehand.
        This tests that mixed precision is working as expected during the
        backward pass. In particular, it ensures that the gradients were cast to
        the right type and that comm. is going to happen in the right type.
        """
        tensors = []
        for x in args:
            if isinstance(x, torch.Tensor):
                tensors.append(x)
        for x in kwargs.values():
            if isinstance(x, torch.Tensor):
                tensors.append(x)

        # reduce_dtype has higher priority than param_dtype because mixed_precision
        # supports overriding param_dtype with reduce_dtype to control the
        # reduction precision. In the case where reduce_dtype == param_dtype,
        # this tests that gradients are in the expected precision as well.
        # If reduce_dtype is not specified (is None), we comm. in param_dtype
        # if that is specified, otherwise in the full precision dtype.
        if should_run_low_prec:
            expected_dtype = (
                mp_config.reduce_dtype
                if mp_config.reduce_dtype is not None
                else (
                    mp_config.param_dtype
                    if mp_config.param_dtype is not None
                    else _CURRENT_FULL_PRECISION_PARAM_DTYPE
                )
            )
        else:
            expected_dtype = _CURRENT_FULL_PRECISION_PARAM_DTYPE

        for t in tensors:
            self.assertEqual(
                expected_dtype,
                t.dtype,
                f"Expected to reduce in {expected_dtype} but got tensors in {t.dtype}",
            )

        return orig_reduce_scatter(*args, **kwargs)
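
    # A quick summary of the dtype resolution implemented above (matching the
    # branches in ``_reduce_scatter_validate_mp``):
    #   MixedPrecision(param_dtype=fp16)                    -> reduce in fp16
    #   MixedPrecision(param_dtype=fp16, reduce_dtype=fp32) -> reduce in fp32
    #   MixedPrecision()                                    -> reduce in the full
    #                                                          precision param dtype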

    def _test_grads_reduced_precision(
        self, offload_params: bool, use_orig_params: bool
    ):
        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin1 = nn.Linear(10, 10)
                self.lin2 = nn.Linear(10, 10)

            def forward(self, x):
                return self.lin2(self.lin1(x))

        m = MyModel().cuda()
        mp = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
            keep_low_precision_grads=True,
        )
        fsdp_kwargs = {
            "mixed_precision": mp,
            "cpu_offload": CPUOffload(offload_params=offload_params),
            "use_orig_params": use_orig_params,
        }
        m.lin1 = FSDP(m.lin1, **fsdp_kwargs)
        m = FSDP(m, **fsdp_kwargs)
        for _ in range(6):
            inp = torch.ones(1, 10)
            m(inp).sum().backward()
            for param in m.parameters():
                if param.grad is not None:
                    self.assertEqual(torch.float16, param.grad.dtype)

        dist.barrier()

    def _run_test_mixed_precision_e2e(
        self,
        mp_config,
        cpu_offload,
        backward_prefetch,
        forward_prefetch,
        full_precision_param_dtype,
        sharding_strategy,
        enable_sharded_grad_scaler,
    ):
        torch.cuda.set_device(self.rank)
        fsdp_models = [
            self._get_simple_model(
                param_dtype=full_precision_param_dtype,
                sharding_strategy=sharding_strategy,
                cpu_offload=cpu_offload,
                mixed_precision=mp_config,
                backward_prefetch=backward_prefetch,
                forward_prefetch=forward_prefetch,
            ),
            self._get_simple_nested_model(
                param_dtype=full_precision_param_dtype,
                run_checks=True,
                sharding_strategy=sharding_strategy,
                cpu_offload=cpu_offload,
                mixed_precision=mp_config,
                backward_prefetch=backward_prefetch,
                forward_prefetch=forward_prefetch,
            ),
        ]
        for model in fsdp_models:
            if not cpu_offload.offload_params:
                model.cuda()

            # Patch reduce_scatter to add validation for mixed precision types.
            orig_reduce_scatter = dist.reduce_scatter_tensor
            test_reduce_scatter = partial(
                self._reduce_scatter_validate_mp,
                orig_reduce_scatter,
                mp_config,
                True,
            )
            with patch_reduce_scatter(test_reduce_scatter, full_precision_param_dtype):
                scaler = ShardedGradScaler(enabled=enable_sharded_grad_scaler)
                optim = torch.optim.Adam(model.parameters())

                for _ in range(3):
                    inp = torch.randn(
                        3, 10, device="cuda", dtype=full_precision_param_dtype
                    )
                    # The forward pass of LinearMixedPrecision checks the casting
                    # of inputs, params, and buffers.
                    act, *_ = model(
                        (inp, self, model, mp_config, full_precision_param_dtype)
                    )
                    # Buffers should be cast.
                    for buf in model.buffers():
                        if mp_config.buffer_dtype is not None:
                            self.assertEqual(buf.dtype, mp_config.buffer_dtype)
                        else:
                            self.assertEqual(buf.dtype, _BUFFER_ORIG_DTYPE)
                    # p._mp_shard should be freed.
                    if mp_config.param_dtype is not None:
                        self._validate_mp_shard_freed(model)
                    else:
                        # We never should have allocated an _mp_shard.
                        self._validate_no_mp_shard(model)

                    loss = act.sum()
                    loss = scaler.scale(loss)
                    if mp_config.param_dtype is not None:
                        self.assertEqual(loss.dtype, mp_config.param_dtype)
                    else:
                        self.assertEqual(loss.dtype, full_precision_param_dtype)
                    # Will run the patched reduce-scatter that validates mixed
                    # precision types in backward.
                    loss.backward()
                    # Buffers stay cast even after backward.
                    for buf in model.buffers():
                        if mp_config.buffer_dtype is not None:
                            self.assertEqual(buf.dtype, mp_config.buffer_dtype)
                        else:
                            self.assertEqual(buf.dtype, _BUFFER_ORIG_DTYPE)
                    # p._mp_shard should be freed.
                    if mp_config.param_dtype is not None:
                        self._validate_mp_shard_freed(model)
                    else:
                        self._validate_no_mp_shard(model)

                    # Ensure params and grads are in full precision,
                    # as after fwd/backward we maintain full precision shards.
                    for param in model.parameters():
                        self.assertEqual(param.dtype, full_precision_param_dtype)
                        if param.grad is not None:
                            self.assertEqual(
                                param.grad.dtype, full_precision_param_dtype
                            )

                    # Unscale the gradients and step
                    scaler.step(optim)
                    # Update the scale factor
                    scaler.update()

                    # Params from summon_full_params should be in full precision.
                    with model.summon_full_params(model):
                        # summon_full_params is not expected to allocate
                        # a mixed precision shard.
                        if mp_config.param_dtype is not None:
                            self._validate_mp_shard_freed(model)
                        else:
                            self._validate_no_mp_shard(model)
                        params = list(model.parameters())
                        for p in params:
                            self.assertEqual(p.dtype, full_precision_param_dtype)

                        # Note that buffers are cast only once and only restored
                        # to the original buffer dtype in state_dict, so
                        # summon_full_params is not expected to restore buffer
                        # types to their original.
                        named_buffers = dict(model.named_buffers())
                        for v in named_buffers.values():
                            if mp_config.buffer_dtype is not None:
                                self.assertEqual(v.dtype, mp_config.buffer_dtype)
                            else:
                                self.assertEqual(v.dtype, _BUFFER_ORIG_DTYPE)

                    # state_dict should be in full precision
                    state_dict = {k: v.clone() for k, v in model.state_dict().items()}
                    for name, tensor in state_dict.items():
                        # Parameters and buffers are checkpointed in their
                        # original dtypes, which may be different.
                        if name in named_buffers.keys():
                            self.assertEqual(tensor.dtype, _BUFFER_ORIG_DTYPE)
                        else:
                            self.assertEqual(
                                tensor.dtype,
                                full_precision_param_dtype,
                                f"{name}: {tensor.dtype} vs {full_precision_param_dtype}",
                            )

                    # After state_dict, the buffers' dtype should have been
                    # restored to the mixed precision one.
                    for buf in model.buffers():
                        if mp_config.buffer_dtype is not None:
                            self.assertEqual(buf.dtype, mp_config.buffer_dtype)
                        else:
                            self.assertEqual(buf.dtype, _BUFFER_ORIG_DTYPE)


class TestFSDPMixedPrecisionSharded(TestFSDPMixedPrecision):
    @property
    def world_size(self):
        return 2

    def _get_subtest_config(self) -> Dict[str, List[Any]]:
        """Returns a subtest configuration that subtests prefetching settings
        together."""
        return {
            "forward_prefetch": [False, True],
            "backward_prefetch": [
                None,
                BackwardPrefetch.BACKWARD_PRE,
                BackwardPrefetch.BACKWARD_POST,
            ],
        }

    @skip_if_lt_x_gpu(2)
    def test_mixed_precision_no_reshard_after_forward(self):
        # Note that we don't exercise all possible configs so as to not
        # increase test TTS too much.
        mp = default_mp if not nccl_supports_bf16 else mp_diff_buffer_and_reduce
        self._run_test_mixed_precision_e2e(
            mp_config=mp,
            cpu_offload=CPUOffload(offload_params=True),
            backward_prefetch=None,
            forward_prefetch=False,
            full_precision_param_dtype=torch.float64,
            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
            enable_sharded_grad_scaler=False,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize(params, configs, subtest_name)
    def test_mixed_precision_e2e_full_shard(
        self,
        mp_config,
        cpu_offload,
        full_precision_param_dtype,
        enable_sharded_grad_scaler,
    ):
        self.run_subtests(
            self._get_subtest_config(),
            self._run_test_mixed_precision_e2e,
            mp_config=mp_config,
            cpu_offload=cpu_offload,
            full_precision_param_dtype=full_precision_param_dtype,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
        )

    def _test_mixed_precision_embedding_table(self, mp_config):
        # Basic test to ensure int inputs are not cast, which would break
        # modules such as embedding tables.
        param_dtype = mp_config.param_dtype or torch.float32
        orig_reduce_scatter = dist.reduce_scatter_tensor
        test_reduce_scatter = partial(
            self._reduce_scatter_validate_mp,
            orig_reduce_scatter,
            mp_config,
            True,
        )
        with patch_reduce_scatter(test_reduce_scatter, param_dtype):
            # TODO: `test_mp_embedding_reduce()` fails if we do not wrap the
            # entire `TransformerWithSharedParams` with a single top-level FSDP
            model = TransformerWithSharedParams.init(
                self.process_group,
                FSDPInitMode.NO_FSDP,
                CUDAInitMode.CUDA_BEFORE,
                {"mixed_precision": mp_config},
            )
            fsdp_model = FSDP(model, mixed_precision=mp_config)
            optim = torch.optim.SGD(fsdp_model.parameters(), lr=0.1)
            for _ in range(6):
                inp = fsdp_model.module.get_input(torch.device("cuda"))
                # This would fail if we cast integer module inputs such as those
                # for embedding tables.
                output = fsdp_model(*inp)
                loss = fsdp_model.module.get_loss(inp, output).cuda()
                self.assertEqual(loss.dtype, param_dtype)
                fsdp_model.module.run_backward(loss)
                optim.step()

    @skip_if_lt_x_gpu(2)
    def test_mp_embedding_reduce(self):
        self._test_mixed_precision_embedding_table(
            mp_config=MixedPrecision(reduce_dtype=torch.float16)
        )

    @skip_if_lt_x_gpu(2)
    def test_mp_embedding_only_params_and_bufs(self):
        self._test_mixed_precision_embedding_table(
            mp_config=MixedPrecision(
                param_dtype=torch.float16,
                buffer_dtype=torch.float16,
            )
        )

    @skip_if_lt_x_gpu(2)
    def test_mp_embedding_default(self):
        default_mp_config = MixedPrecision(
            param_dtype=torch.float16,
            buffer_dtype=torch.float16,
            reduce_dtype=torch.float16,
        )
        self._test_mixed_precision_embedding_table(mp_config=default_mp_config)

    @skip_if_lt_x_gpu(2)
    def test_mp_embedding_params_and_reduce_diff(self):
        params_and_reduce_different = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.float16,
        )
        self._test_mixed_precision_embedding_table(
            mp_config=params_and_reduce_different
        )

    @skip_if_lt_x_gpu(2)
    @skipIfNoTorchVision
    def test_mixed_precision_resnet(self):
        """
        End-to-end test to ensure mixed precision + auto_wrap works
        for a ResNet model.
        """
        resnet_model = torchvision.models.resnet50().cuda()
        resnet_model = nn.SyncBatchNorm.convert_sync_batchnorm(
            resnet_model, process_group=dist.distributed_c10d._get_default_group()
        )
        n_bn = sum(
            1 if isinstance(x, _BatchNorm) else 0 for x in resnet_model.modules()
        )
        inp = torch.ones(1, 3, 1000, 1000, device="cuda")
        mp_config = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
        fsdp = FSDP(
            resnet_model,
            auto_wrap_policy=size_based_auto_wrap_policy,
            mixed_precision=mp_config,
        )
        # Batchnorm units should be wrapped individually. Validate this by
        # ensuring that the number of FSDP units that are batchnorm equals the
        # number of batchnorm units in the original ResNet model.
        fsdp_bn = 0
        for module in fsdp.fsdp_modules(fsdp):
            wrapped_module = module.module
            if isinstance(wrapped_module, _BatchNorm):
                fsdp_bn += 1

        self.assertEqual(fsdp_bn, n_bn)
        # Would throw a type mismatch issue without mixed precision autowrapping.
        loss = fsdp(inp).sum()
        loss.backward()

    @skip_if_lt_x_gpu(2)
    def test_grads_reduced_precision(self):
        self.run_subtests(
            {
                "offload_params": [False, True],
                "use_orig_params": [False, True],
            },
            self._test_grads_reduced_precision,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("convert_sync_bn", [True, False])
    def test_mp_batchnorm(self, convert_sync_bn):
        class BatchNormNet(nn.Module):
            def __init__(self, affine=True):
                super().__init__()
                self.fc1 = nn.Linear(2, 40, bias=False)
                self.bn = nn.BatchNorm1d(4, affine=affine)
                self.fc2 = nn.Linear(40, 4, bias=False)
                self.ln = nn.LayerNorm(4)
                self.fc3 = nn.Linear(4, 4, bias=False)

            def forward(self, x):
                x = torch.reshape(self.fc1(x), (-1, 4, 10))
                x = self.bn(x)
                x = torch.reshape(x, (-1, 40))
                x = self.fc2(x)
                x = self.ln(x)
                x = self.fc3(x)
                return F.softmax(x, dim=1)

        def never_wrap_policy(*args, **kwargs):
            return False

        net = BatchNormNet().cuda()
        if convert_sync_bn:
            net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
        # FSDP detects that mixed precision + batchnorm will cause issues
        # and thus wraps batchnorm in a distinct FSDP unit that does not
        # use mixed precision.
        mp_config = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
            _module_classes_to_ignore=[_BatchNorm, nn.LayerNorm],
        )
        with self.assertWarnsRegex(
            expected_warning=UserWarning,
            expected_regex="These modules will be wrapped as separate FSDP",
        ):
            model = FSDP(
                net,
                mixed_precision=mp_config,
                auto_wrap_policy=never_wrap_policy,
            )

        no_mp = MixedPrecision()
        for mod in [model.ln, model.bn]:
            self.assertTrue(isinstance(mod, FSDP))
            self.assertEqual(no_mp, mod.mixed_precision)
        # The policy should not have wrapped any other submodules.
        for mod in [model.fc1, model.fc2, model.fc3]:
            self.assertFalse(isinstance(mod, FSDP))

        # Overall mixed precision is still enabled
        self.assertEqual(mp_config, model.mixed_precision)

        inp = torch.randn((1, 2), device="cuda")
        # Without the FSDP batchnorm mixed precision fix, this would result in
        # RuntimeError: Expected counts to have type Half but got Float
        # for syncBN.
        model(inp).sum().backward()

    @skip_if_lt_x_gpu(2)
    def test_eval_root_cast_inputs(self):
        """
        In a case where the root module does not manage FSDP parameters,
        ensure that we do not cast forward inputs, which could potentially
        cause a dtype mismatch. Check that FSDP_USE_FULL_PREC_IN_EVAL controls
        this.
        """

        low_prec_dtype = torch.float16

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Linear(5, 5)

            def forward(self, x, expect_use_full_prec_in_eval):
                if expect_use_full_prec_in_eval:
                    assert x.dtype == torch.float32, f"Expected fp32, got {x.dtype}"
                else:
                    assert (
                        x.dtype == low_prec_dtype
                    ), f"Expected {low_prec_dtype}, got {x.dtype}"
                return self.a(x)

        mp_config = MixedPrecision(
            param_dtype=low_prec_dtype,
            reduce_dtype=low_prec_dtype,
            buffer_dtype=low_prec_dtype,
        )

        for use_full_prec_in_eval in [True, False]:
            os.environ["FSDP_USE_FULL_PREC_IN_EVAL"] = (
                "1" if use_full_prec_in_eval else "0"
            )
            m = MyModel().cuda()
            m.a = FSDP(m.a, mixed_precision=mp_config)
            model = FSDP(m, mixed_precision=mp_config)
            model.eval()
            inp = torch.randn(5, 5)
            model(inp, use_full_prec_in_eval).sum().backward()

    @skip_if_lt_x_gpu(2)
    def test_full_precision_in_eval(self):
        """
        Tests that eval runs in full precision if FSDP_USE_FULL_PREC_IN_EVAL is set.
        """
        for (
            use_composable,
            cast_forward_inputs,
            use_full_prec_in_eval,
        ) in itertools.product([True, False], [True, False], [True, False]):
            mp_config = MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
                cast_forward_inputs=cast_forward_inputs,
            )
            os.environ["FSDP_USE_FULL_PREC_IN_EVAL"] = (
                "1" if use_full_prec_in_eval else "0"
            )
            model = TransformerWithSharedParams.init(
                self.process_group,
                FSDPInitMode.NO_FSDP if use_composable else FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_BEFORE,
                {"mixed_precision": mp_config},
            )
            if use_composable:
                auto_wrap_policy = ModuleWrapPolicy(
                    {
                        TransformerEncoderLayer,
                        TransformerDecoderLayer,
                    }
                )
                fully_shard(model, policy=auto_wrap_policy, mixed_precision=mp_config)
            module_accessor = model if use_composable else model.module
            inp = module_accessor.get_input(torch.device("cuda"))
            output = model(*inp)
            loss = module_accessor.get_loss(inp, output).cuda()
            # Loss should be in fp16
            self.assertEqual(torch.float16, loss.dtype)
            module_accessor.run_backward(loss)
            # Grads should be in fp32 since we upcast them
            for p in model.parameters():
                if p.grad is not None:
                    self.assertEqual(torch.float32, p.grad.dtype)

            # Now in eval mode, the loss should be in fp32 if
            # use_full_prec_in_eval is set.
            model.eval()
            inp = module_accessor.get_input(torch.device("cuda"))
            output = model(*inp)
            loss = module_accessor.get_loss(inp, output).cuda()
            expected_dtype = torch.float32 if use_full_prec_in_eval else torch.float16
            self.assertEqual(expected_dtype, loss.dtype)

    @skip_if_lt_x_gpu(2)
    def test_full_precision_in_eval_buffers(self):
        """
        Tests that when model.eval() and FSDP_USE_FULL_PREC_IN_EVAL is set,
        buffers are in the full precision.
        """
        for (
            use_composable,
            cast_forward_inputs,
            use_full_prec_in_eval,
        ) in itertools.product([True, False], [True, False], [True, False]):
            os.environ["FSDP_USE_FULL_PREC_IN_EVAL"] = (
                "1" if use_full_prec_in_eval else "0"
            )
            mp_config = MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
                cast_forward_inputs=cast_forward_inputs,
            )
            model_getter = (
                self._get_simple_nested_model_composable
                if use_composable
                else self._get_simple_nested_model
            )
            fsdp_model = model_getter(
                param_dtype=torch.float32,
                run_checks=False,
                mixed_precision=mp_config,
            )

            inp = torch.randn(3, 10, device="cuda")
            fsdp_model((inp, self, fsdp_model, mp_config, torch.float32))
            for buf in fsdp_model.buffers():
                self.assertEqual(torch.float16, buf.dtype)

            # model.eval() + a forward pass should make the buffers full
            # precision again. Add pre-forward hooks to verify this.
            def verify_eval_buffer_dtype(module, input):
                expected_dtype = (
                    _BUFFER_ORIG_DTYPE if use_full_prec_in_eval else torch.float16
                )
                for buf in module.buffers():
                    self.assertEqual(expected_dtype, buf.dtype)

            def _get_underlying_module(m):
                return m.module if isinstance(m, FSDP) else m

            hook_handles = []
            hook_handles.append(
                _get_underlying_module(fsdp_model[0]).register_forward_pre_hook(
                    verify_eval_buffer_dtype
                )
            )
            hook_handles.append(
                _get_underlying_module(fsdp_model[1]).register_forward_pre_hook(
                    verify_eval_buffer_dtype
                )
            )

            fsdp_model.eval()
            fsdp_model((inp, self, fsdp_model, mp_config, torch.float32))
            for hook_handle in hook_handles:
                hook_handle.remove()

            expected_dtype = (
                _BUFFER_ORIG_DTYPE if use_full_prec_in_eval else torch.float16
            )
            for buf in fsdp_model.buffers():
                self.assertEqual(expected_dtype, buf.dtype)

            # model.train() + a forward pass again should make the buffers fp16.
            fsdp_model.train()
            fsdp_model((inp, self, fsdp_model, mp_config, torch.float32))
            for buf in fsdp_model.buffers():
                self.assertEqual(torch.float16, buf.dtype)
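
    # The next test targets the communication dtype rather than the loss or
    # buffer dtypes: with FSDP_USE_FULL_PREC_IN_EVAL set, the patched
    # reduce-scatter below asserts that the eval-mode backward reduces
    # gradients in full precision instead of ``reduce_dtype``.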

    @skip_if_lt_x_gpu(2)
    def test_full_precision_in_eval_comm(self):
        for (
            use_composable,
            cast_forward_inputs,
            use_full_prec_in_eval,
        ) in itertools.product([True, False], [True, False], [True, False]):
            os.environ["FSDP_USE_FULL_PREC_IN_EVAL"] = (
                "1" if use_full_prec_in_eval else "0"
            )
            mp_config = MixedPrecision(
                param_dtype=torch.float32,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float32,
                cast_forward_inputs=cast_forward_inputs,
                # Also cast the batchnorm reduction (only in this test) to make
                # validation easier.
                _module_classes_to_ignore=[],
            )
            model = TransformerWithSharedParams.init(
                self.process_group,
                FSDPInitMode.NO_FSDP if use_composable else FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_BEFORE,
                {"mixed_precision": mp_config},
            )
            if use_composable:
                auto_wrap_policy = ModuleWrapPolicy(
                    {
                        TransformerEncoderLayer,
                        TransformerDecoderLayer,
                    }
                )
                fully_shard(model, policy=auto_wrap_policy, mixed_precision=mp_config)
            model_accessor = model if use_composable else model.module
            # Patch reduce_scatter to add validation for mixed precision types.
            orig_reduce_scatter = dist.reduce_scatter_tensor
            test_reduce_scatter = partial(
                self._reduce_scatter_validate_mp,
                orig_reduce_scatter,
                mp_config,
                not use_full_prec_in_eval,
            )
            model.eval()
            with patch_reduce_scatter(test_reduce_scatter, torch.float32):
                inp = model_accessor.get_input(torch.device("cuda"))
                output = model(*inp)
                loss = model_accessor.get_loss(inp, output).cuda()
                model_accessor.run_backward(loss)

    @skip_if_lt_x_gpu(2)
    def test_input_grads_with_param_mixed_precision(self):
        """
        Tests that input tensors that require gradients do get their gradients
        even after being cast to a low precision (when parameter mixed
        precision is enabled).
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
                "use_orig_params": [False, True],
            },
            self._test_input_grads_with_param_mixed_precision,
        )

    def _test_input_grads_with_param_mixed_precision(
        self,
        sharding_strategy: ShardingStrategy,
        use_orig_params: bool,
    ):
        model = nn.Linear(1024, 1024, bias=False)
        mixed_precision = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.float32,
        )
        fsdp_model = FSDP(
            model,
            sharding_strategy=sharding_strategy,
            mixed_precision=mixed_precision,
            device_id=torch.cuda.current_device(),
            use_orig_params=use_orig_params,
        )
        # Use an input with a dtype not equal to the mixed precision
        # `param_dtype` so that it gets cast
        x_float = torch.randn(
            (32, 1024),
            device="cuda",
            dtype=torch.float32,
            requires_grad=True,
        )
        fsdp_model(x_float).sum().backward()
        self.assertTrue(x_float.grad is not None)
        # Check that `x_float` preserves its dtype, meaning that the gradient
        # propagated via `ToCopyBackward0`
        self.assertEqual(x_float.grad.dtype, torch.float32)


class TestFSDPMixedPrecisionUnsharded(TestFSDPMixedPrecision):
    """
    Smaller test suite for the unsharded-parameter (i.e. world_size == 1) case.
    """

    @property
    def world_size(self):
        return 1

    @skip_if_lt_x_gpu(1)
    def test_grads_reduced_precision(self):
        self.run_subtests(
            {"offload_params": [False, True], "use_orig_params": [False, True]},
            self._test_grads_reduced_precision,
        )

    @skip_if_lt_x_gpu(1)
    def test_mixed_precision_no_reshard_after_forward(self):
        # Note that we don't exercise all possible configs so as to not
        # increase test TTS too much.
        mp = default_mp if not nccl_supports_bf16 else mp_diff_buffer_and_reduce
        self._run_test_mixed_precision_e2e(
            mp_config=mp,
            cpu_offload=CPUOffload(offload_params=True),
            backward_prefetch=None,
            forward_prefetch=False,
            full_precision_param_dtype=torch.float64,
            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
            enable_sharded_grad_scaler=False,
        )

    @skip_if_lt_x_gpu(1)
    def test_mixed_precision_e2e_full_shard(self):
        mp = default_mp if not nccl_supports_bf16 else mp_diff_buffer_and_reduce
        self._run_test_mixed_precision_e2e(
            mp_config=mp,
            cpu_offload=CPUOffload(offload_params=True),
            backward_prefetch=None,
            forward_prefetch=False,
            full_precision_param_dtype=torch.float64,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            enable_sharded_grad_scaler=False,
        )


instantiate_parametrized_tests(TestFSDPMixedPrecisionSharded)


class IgnoredModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.l = nn.Linear(100, 100)

    def forward(self, x):
        return self.l(x)


class ModelWithIgnoredModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.l1 = nn.Linear(100, 100)
        self.ignored = IgnoredModule()
        self.l2 = nn.Linear(100, 100)

    def forward(self, x):
        return self.l2(self.ignored(self.l1(x)))
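

# A dtype mismatch is expected in the test below: FSDP mixed precision casts
# the managed parameters and activations to fp16, while the ignored module's
# fp32 linear is left untouched, so its matmul sees mismatched operand dtypes.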
class TestFSDPMixedPrecisionIgnoredModules(FSDPTest):
    @property
    def world_size(self):
        return 1

    @skip_if_lt_x_gpu(1)
    def test_mixed_precision_with_ignored_module(self):
        model = ModelWithIgnoredModule().cuda()
        float16 = MixedPrecision(param_dtype=torch.float16)
        model = FSDP(
            model,
            ignored_modules=[model.ignored],
            mixed_precision=float16,
        )

        x = torch.ones(2, 100, device=torch.cuda.current_device())

        with self.assertRaisesRegex(RuntimeError, "must have the same dtype"):
            model(x).sum().backward()


class TestFSDPDifferentSubmodulePrecision(FSDPTest):
    @property
    def world_size(self):
        return 2

    @skip_if_lt_x_gpu(2)
    def test_float16_on_one_submodule(self):
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)

        model = SaveForwardInputsModel(
            forward_inputs,
            cast_forward_inputs=False,
        ).cuda()
        c1, c2 = model.c1, model.c2
        x = torch.zeros(2, 100, device="cuda")

        # float16 on one submodule and float32 on everything else
        model.c2 = FSDP(model.c2, mixed_precision=float16)
        fsdp = FSDP(model)

        fsdp(x).sum().backward()

        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[c1].dtype, torch.float32)
        self.assertEqual(forward_inputs[c2].dtype, torch.float16)

    @skip_if_lt_x_gpu(2)
    def test_float16_on_one_submodule_skip_inputs(self):
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=False)

        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=True
        ).cuda()
        c1, c2 = model.c1, model.c2
        x = torch.zeros(2, 100, device="cuda")

        # float16 on one submodule and float32 on everything else
        model.c2 = FSDP(model.c2, mixed_precision=float16)
        fsdp = FSDP(model)

        fsdp(x).sum().backward()

        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[c1].dtype, torch.float32)
        self.assertEqual(forward_inputs[c2].dtype, torch.float32)

    @skip_if_lt_x_gpu(2)
    def test_float16_on_one_submodule_skip_inputs_error(self):
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=False)

        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=False
        ).cuda()
        c1, c2 = model.c1, model.c2
        x = torch.zeros(2, 100, device="cuda")

        # float16 on one submodule and float32 on everything else
        model.c2 = FSDP(model.c2, mixed_precision=float16)
        fsdp = FSDP(model)

        with self.assertRaisesRegex(
            RuntimeError, "mat1 and mat2 must have the same dtype"
        ):
            fsdp(x).sum().backward()

    @skip_if_lt_x_gpu(2)
    def test_submodules_with_different_precisions_error(self):
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
        float32 = MixedPrecision(param_dtype=torch.float32, cast_forward_inputs=True)

        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=False
        ).cuda()
        x = torch.zeros(2, 100, device="cuda")

        # For submodules with different precisions, the current design does not
        # support the case where the root FSDP instance wraps a submodule that
        # is not the first one executed. For that submodule, there is no way to
        # cast its inputs (i.e. the previous submodule's outputs); instead, the
        # root module's inputs are cast upfront before entering the root
        # module's forward.
        model.c1 = FSDP(model.c1, mixed_precision=float16)
        fsdp = FSDP(model, mixed_precision=float32)
        with self.assertRaisesRegex(
            RuntimeError, "mat1 and mat2 must have the same dtype"
        ):
            fsdp(x).sum().backward()

    @skip_if_lt_x_gpu(2)
    def test_submodules_with_different_precisions(self):
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
        float32 = MixedPrecision(param_dtype=torch.float32, cast_forward_inputs=True)

        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=False
        ).cuda()
        c1, c2 = model.c1, model.c2
        x = torch.zeros(2, 100, device="cuda")

        model.c2 = FSDP(model.c2, mixed_precision=float16)
        fsdp = FSDP(model, mixed_precision=float32)

        fsdp(x).sum().backward()

        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[c1].dtype, torch.float32)
        self.assertEqual(forward_inputs[c2].dtype, torch.float16)

    @skip_if_lt_x_gpu(2)
    def test_submodules_with_external_inputs(self):
        class ToyModule(nn.Module):
            def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
                super().__init__()
                self.l = nn.Linear(100, 100)
                self.forward_inputs = forward_inputs

            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                self.forward_inputs["l2_input_x"] = x
                self.forward_inputs["l2_input_y"] = y
                return self.l(x)

        class ToyModel(nn.Module):
            def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
                super().__init__()
                self.l1 = nn.Linear(100, 100)
                self.l2 = ToyModule(forward_inputs)
                self.forward_inputs = forward_inputs

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.forward_inputs["model_input_x"] = x
                y = torch.ones(2, 100, device="cuda", dtype=torch.float32)
                return self.l2(self.l1(x), y)

        forward_inputs: Dict[str, torch.Tensor] = {}

        float16 = MixedPrecision(param_dtype=torch.float16)
        model = ToyModel(forward_inputs).cuda()
        x = torch.zeros(2, 100, device="cuda", dtype=torch.float32)
        model.l2 = FSDP(model.l2, mixed_precision=float16)
        fsdp = FSDP(model, mixed_precision=float16)

        fsdp(x).sum().backward()

        # Inputs are cast in the root module by default; inputs of submodules
        # are not explicitly cast, so the external input ``y`` of module
        # ``self.l2`` is not cast.
        self.assertEqual(forward_inputs["model_input_x"].dtype, torch.float16)
        self.assertEqual(forward_inputs["l2_input_x"].dtype, torch.float16)
        self.assertEqual(forward_inputs["l2_input_y"].dtype, torch.float32)


class TestFSDPTrainEval(FSDPTest):
    @property
    def world_size(self):
        return 2

    @skip_if_lt_x_gpu(2)
    def test_train_ema_eval_flow(self):
        """
        Tests a train -> EMA update -> eval flow with mixed precision enabled.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    # We mainly want to test `SHARD_GRAD_OP` since it surfaced
                    # the original bug of not using the right EMA parameters
                    # for eval, but we also test the others for completeness
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.NO_SHARD,
                ]
            },
            self._test_train_ema_eval_flow,
        )

    def _test_train_ema_eval_flow(self, sharding_strategy: ShardingStrategy):
        class TransformerWithEMA(nn.Module):
            def __init__(self, device: torch.device):
                super().__init__()
                self.module = nn.Transformer(device=device)
                self.ema_module = AveragedModel(
                    nn.Transformer(device=device),
                    multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(),
                    use_buffers=True,
                )

            def forward(self, *args, **kwargs):
                # Use the main copy for training and the EMA copy for eval
                if self.training:
                    return self.module(*args, **kwargs)
                return self.ema_module(*args, **kwargs)

        device = torch.device("cuda")
        model = TransformerWithEMA(device=device)
        policy = ModuleWrapPolicy(
            {nn.Transformer, nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
        )
        mixed_precision = MixedPrecision(param_dtype=torch.float16)
        fsdp_model = FSDP(
            model,
            auto_wrap_policy=policy,
            mixed_precision=mixed_precision,
            sharding_strategy=sharding_strategy,
        )
        optim = torch.optim.Adam(fsdp_model.module.parameters(), lr=1e-2)
        if self.rank == 0:
            print(fsdp_model)
        torch.manual_seed(1 + self.rank)
        eval_src = torch.randn((8, 1, 512), device=device)
        eval_tgt = torch.randn((16, 1, 512), device=device)
        eval_out_sums: List[torch.Tensor] = []
        # An iteration consists of a training forward/backward/optimizer step,
        # updating the EMA copy with the main copy, and an eval forward pass.
        for _ in range(3):
            fsdp_model.train()
            train_src = torch.randn((8, 4, 512), device=device)
            train_tgt = torch.randn((16, 4, 512), device=device)
            train_out = fsdp_model(train_src, train_tgt)
            train_out.sum().backward()
            optim.step()
            optim.zero_grad()
            with FSDP.summon_full_params(fsdp_model):
                fsdp_model.ema_module.update_parameters(fsdp_model.module)
            fsdp_model.eval()
            with torch.no_grad():
                eval_out = fsdp_model(eval_src, eval_tgt)
                eval_out_sums.append(eval_out.sum())
        # Check that the eval outputs differ from iteration to iteration as a
        # proxy for eval using the correct EMA parameters
        for i in range(len(eval_out_sums) - 1):
            self.assertNotEqual(eval_out_sums[i], eval_out_sums[i + 1])
        self.assertNotEqual(eval_out_sums[0], eval_out_sums[-1])


if __name__ == "__main__":
    run_tests()