# Owner(s): ["oncall: distributed"]

import os
import sys

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


DIM = 500


class PipelineModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer1 = nn.Linear(DIM, DIM)
        self.layer2 = nn.Linear(DIM, DIM)
        self.layer3 = nn.Linear(DIM, DIM)
        self.layer4 = nn.Linear(DIM, DIM)
        self.relu = nn.ReLU()

    def forward(self, batch):
        x = self.relu(self.layer1(batch))
        x = self.relu(self.layer2(x))
        x = self.relu(self.layer3(x))
        x = self.relu(self.layer4(x))
        return x


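# TestPipeline saves a checkpoint from a simulated pipeline-parallel setup,
# where each rank owns exactly one layer of PipelineModel, and then reloads
# that checkpoint into an FSDP-wrapped replica of the full model.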
class TestPipeline(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    def save_with_pipeline(self, pipeline_dir: str) -> None:
        with torch.device("meta"):
            model = PipelineModel()

        pipeline_modules = [model.layer1, model.layer2, model.layer3, model.layer4]

        # Materialize only this rank's pipeline stage on the GPU.
        submodule = pipeline_modules[self.rank]
        submodule.to_empty(device=torch.device("cuda"))
        # submodule.reset_parameters()
        optim = torch.optim.Adam(submodule.parameters(), lr=1e-3)

        # Skip the training step since there is no real pipeline parallelism here.

        # Save the model and optimizer state_dicts.
        model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim)
        saved_state_dict = {"model": model_state_dict, "optim": optim_state_dict}
        dcp.save(
            state_dict=saved_state_dict,
            storage_writer=dcp.FileSystemWriter(pipeline_dir),
        )

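    # Reload the checkpoint into a full FSDP-wrapped model; DCP reshards the
    # per-rank pipeline checkpoint to fit the FSDP sharding layout on load.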
    def load_with_fsdp(self, pipeline_dir: str) -> None:
        model = FSDP(PipelineModel().cuda())
        optim = torch.optim.Adam(model.parameters(), lr=1e-3)

        # Load the checkpoint
        model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim)
        dcp.load(
            {"model": model_state_dict, "optim": optim_state_dict},
            storage_reader=dcp.FileSystemReader(pipeline_dir),
        )
        set_state_dict(
            model,
            optimizers=optim,
            model_state_dict=model_state_dict,
            optim_state_dict=optim_state_dict,
        )

    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_pipeline(self) -> None:
        self.assertTrue(os.path.exists(self.temp_dir))
        pipeline_dir = os.path.join(self.temp_dir, "pipeline")
        # Rank 0 creates the checkpoint directory; sync and barrier so every
        # rank sees it before saving.
        if self.rank == 0:
            os.mkdir(pipeline_dir)
        os.sync()
        dist.barrier()
        self.assertTrue(os.path.exists(pipeline_dir))
        self.save_with_pipeline(pipeline_dir)
        self.load_with_fsdp(pipeline_dir)


if __name__ == "__main__":
    run_tests()