# Owner(s): ["oncall: distributed"]

import os
import sys

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


DIM = 500


class PipelineModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer1 = nn.Linear(DIM, DIM)
        self.layer2 = nn.Linear(DIM, DIM)
        self.layer3 = nn.Linear(DIM, DIM)
        self.layer4 = nn.Linear(DIM, DIM)
        self.relu = nn.ReLU()

    def forward(self, batch):
        x = self.relu(self.layer1(batch))
        x = self.relu(self.layer2(x))
        x = self.relu(self.layer3(x))
        x = self.relu(self.layer4(x))
        return x


class TestPipeline(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    def save_with_pipeline(self, pipeline_dir: str) -> None:
        with torch.device("meta"):
            model = PipelineModel()

        # Simulate pipeline parallelism: each rank owns exactly one layer.
        pipeline_modules = [model.layer1, model.layer2, model.layer3, model.layer4]

        # Materialize only this rank's submodule.
        submodule = pipeline_modules[self.rank]
        submodule.to_empty(device=torch.device("cuda"))
        # submodule.reset_parameters()
        optim = torch.optim.Adam(submodule.parameters(), lr=1e-3)

        # Ignore the training as we don't have real pipeline parallelism.

        # Save the state_dict; each rank contributes only its own shard.
        model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim)
        saved_state_dict = {"model": model_state_dict, "optim": optim_state_dict}
        dcp.save(
            state_dict=saved_state_dict,
            storage_writer=dcp.FileSystemWriter(pipeline_dir),
        )

    def load_with_fsdp(self, pipeline_dir: str) -> None:
        model = FSDP(PipelineModel().cuda())
        optim = torch.optim.Adam(model.parameters(), lr=1e-3)

        # Load the pipeline-sharded checkpoint into the FSDP-sharded model;
        # DCP reshards across the two parallelism layouts.
        model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim)
        dcp.load(
            {"model": model_state_dict, "optim": optim_state_dict},
            storage_reader=dcp.FileSystemReader(pipeline_dir),
        )
        set_state_dict(
            model,
            optimizers=optim,
            model_state_dict=model_state_dict,
            optim_state_dict=optim_state_dict,
        )

    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_pipeline(self) -> None:
        self.assertTrue(os.path.exists(self.temp_dir))
        pipeline_dir = os.path.join(self.temp_dir, "pipeline")
        if self.rank == 0:
            os.mkdir(pipeline_dir)
        # Flush the newly created directory to disk so that, after the
        # barrier, all ranks can see it.
        os.sync()
        dist.barrier()
        self.assertTrue(os.path.exists(pipeline_dir))
        self.save_with_pipeline(pipeline_dir)
        self.load_with_fsdp(pipeline_dir)


if __name__ == "__main__":
    run_tests()