# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch import distributed as dist
from torch.distributed.checkpoint import (
    FileSystemReader,
    FileSystemWriter,
    load_state_dict,
    save_state_dict,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import enable_wrap, wrap
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, SkipModel
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


_DISTRIBUTED_STATE_DICT_IMPLS = {
    StateDictType.LOCAL_STATE_DICT,
    StateDictType.SHARDED_STATE_DICT,
}


class TestDistributedCheckpoint(FSDPTest):
    @property
    def world_size(self):
        return 2

    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    @parametrize("state_dict_type", _DISTRIBUTED_STATE_DICT_IMPLS)
    def test_distributed_checkpoint(self, state_dict_type) -> None:
        with enable_wrap(wrapper_cls=FSDP):
            torch.manual_seed(100)
            model = wrap(SkipModel(double_nest=True))
            torch.manual_seed(200)
            new_model = wrap(SkipModel(double_nest=True))

        # The two models are seeded differently, so their parameters should
        # not match before the checkpoint is loaded.
        with FullyShardedDataParallel.summon_full_params(
            model
        ), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertNotEqual(params, new_params)

        writer = FileSystemWriter(self.temp_dir)
        reader = FileSystemReader(self.temp_dir)

        # Save `model`'s sharded state_dict to the temporary directory.
        with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
            new_model, state_dict_type
        ):
            state_dict = model.state_dict()

            save_state_dict(state_dict, writer)

        # Load the checkpoint into `new_model`'s state_dict in place, then
        # apply it to `new_model`.
        with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
            new_model, state_dict_type
        ):
            state_dict = new_model.state_dict()
            load_state_dict(state_dict, reader)
            new_model.load_state_dict(state_dict)

        # After the save/load round trip, both models should hold identical
        # parameters.
        with FullyShardedDataParallel.summon_full_params(
            model
        ), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertEqual(params, new_params)

        # TODO: add resharding test case.


instantiate_parametrized_tests(TestDistributedCheckpoint)

if __name__ == "__main__":
    run_tests()