# Owner(s): ["oncall: distributed"]
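#
# Exercises composing pipeline parallelism (manually constructed PipelineStages
# driven by the various pipeline schedules) with data parallelism (DDP or
# per-layer FSDP) on a 2x2 ("dp", "pp") device mesh. An 8-layer MLP is split
# into pipeline stages, each stage is wrapped with the chosen DP flavor, a
# schedule is stepped over 8 microbatches, and the resulting gradients are
# checked against a non-parallel reference model.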
import copy
import os

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp.fully_shard import (
    fully_shard,
    MixedPrecisionPolicy,
)
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    PipelineScheduleSingle,
    Schedule1F1B,
    ScheduleFlexibleInterleaved1F1B,
    ScheduleGPipe,
    ScheduleInterleaved1F1B,
    ScheduleInterleavedZeroBubble,
    ScheduleLoopedBFS,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skip_but_pass_in_sandcastle_if,
)


# MLP Layer
class MLPModule(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x


class ComposabilityTest(MultiProcessTestCase):
    @classmethod
    def backend_str(cls) -> str:
        # Testing with NCCL backend
        return "nccl"

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self):
        return 4

    @property
    def device(self):
        return self.rank

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
    @parametrize("dp_type", ["DDP", "FSDP"])
    @parametrize(
        "ScheduleClass",
        [
            ScheduleGPipe,
            Schedule1F1B,
            ScheduleInterleaved1F1B,
            ScheduleLoopedBFS,
            ScheduleFlexibleInterleaved1F1B,
            ScheduleInterleavedZeroBubble,
        ],
    )
    @parametrize("use_new_runtime", [False, True])
    def test_manual_with_data_parallel(self, dp_type, ScheduleClass, use_new_runtime):
        device = torch.device("cuda", self.device)
        torch.cuda.set_device(self.device)
        store = torch.distributed.FileStore(self.file_name, self.world_size)
        torch.distributed.init_process_group(
            backend="nccl",
            store=store,
            rank=self.rank,
            world_size=self.world_size,
            device_id=device,
        )
        device_mesh = init_device_mesh(
            "cuda", mesh_shape=(2, 2), mesh_dim_names=("dp", "pp")
        )
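        # With world_size == 4, the 2x2 mesh gives a "dp" group of size 2 and a
        # "pp" group of size 2, so every rank is simultaneously one data
        # parallel replica and one pipeline rank.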
        pp_group = device_mesh["pp"].get_group()
        dp_mesh = device_mesh["dp"]

        # create "entire model"
        total_layers = 8
        dim = 10
        full_model = nn.ModuleList([MLPModule(dim) for _ in range(total_layers)])
        ref_model = nn.Sequential(*copy.deepcopy(full_model))
        ref_model.to(self.device)

        # Prepare inputs
        num_microbatches = 8
        inputs = [
            torch.rand((num_microbatches, dim), device=self.device)
            for _ in range(dp_mesh.size())
        ]
        input = inputs[dp_mesh.get_local_rank()]
        input_mb = [[input[i].reshape((1, dim))] for i in range(num_microbatches)]
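        # Each rank takes one (num_microbatches, dim) input and splits it into
        # num_microbatches microbatches of shape (1, dim); each microbatch is
        # wrapped in a list because the schedule expects per-microbatch
        # positional args for the first stage.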

        # dummy loss needed just to force backward to run in the schedule step
        def loss_fn(y, target):
            return y.sum()

        # Get stage module i from the entire model
        def get_stage_module(stage_idx, num_stages):
            # divide the model (8 layers) by the number of stages
            layers_per_stage = total_layers // num_stages
            assert layers_per_stage * num_stages == total_layers
            # return offset so validation code can match partial layer back to orig model
            offset = stage_idx * layers_per_stage
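            # e.g. with num_stages=4: layers_per_stage=2, so stage 1 owns
            # full_model[2:4] and reports offset=2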
            partial_model = nn.Sequential(
                *full_model[offset : (stage_idx + 1) * layers_per_stage]
            )
            partial_model.to(self.device)
            return partial_model, offset

        # Apply DP to stage module
        def apply_dp(partial_model, dp_type):
            if dp_type == "FSDP":
                # apply FSDP
                mp_policy = MixedPrecisionPolicy(
                    # TODO(whc) need to fix PP + FSDP-mixed-precision
                    # tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
                    # param_dtype=torch.bfloat16, reduce_dtype=torch.float32
                    param_dtype=torch.float32,
                    reduce_dtype=torch.float32,
                )
                fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
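                # Shard each MLP layer individually, then wrap the whole stage
                # module so any remaining parameters are covered.
                # reshard_after_forward=False presumably avoids re-gathering
                # parameters between the interleaved microbatch forwards and
                # backwards that pipelining produces.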
                for layer in partial_model.children():
                    fully_shard(
                        layer,
                        **fsdp_config,
                        reshard_after_forward=False,
                    )
                dp_model = fully_shard(partial_model, **fsdp_config)
            elif dp_type == "DDP":
                dp_model = DDP(partial_model, process_group=dp_mesh.get_group())
            else:
                raise RuntimeError(f"unsupported dp type {dp_type}")
            return dp_model

        # Create pipeline stage
        def build_stage(stage_idx, num_stages):
            partial_model, offset = get_stage_module(stage_idx, num_stages)
            dp_model = apply_dp(partial_model, dp_type)
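            # PipelineStage wraps the already-DP-wrapped module; input_args is
            # a representative microbatch used to infer the shapes of the
            # activations the stage sends and receives.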
            stage = PipelineStage(
                dp_model,
                stage_idx,
                num_stages,
                self.device,
                group=pp_group,
                input_args=input_mb[0],
            )
            return stage, offset

        # Attach to a schedule
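        # Single-stage schedules (e.g. GPipe, 1F1B) drive exactly one
        # PipelineStage per pp rank; looped/interleaved schedules need
        # n_virtual stages per rank.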
        if issubclass(ScheduleClass, PipelineScheduleSingle):
            if use_new_runtime:
                # PipelineScheduleSingle classes can't be tested with the new
                # runtime; return early, destroying the process group so this
                # test instance still cleans up correctly.
                torch.distributed.destroy_process_group()
                return
            pipeline_stage, offset = build_stage(pp_group.rank(), pp_group.size())
            partial_models = [pipeline_stage.submod]
            offsets = [offset]
            pipeline_schedule = ScheduleClass(
                pipeline_stage,
                n_microbatches=num_microbatches,
                loss_fn=loss_fn,
            )
        else:
            n_virtual = 2
            num_stages = pp_group.size() * n_virtual
            stages = []
            offsets = []
            for i in range(n_virtual):
                stage, offset = build_stage(pp_group.rank() + n_virtual * i, num_stages)
                stages.append(stage)
                offsets.append(offset)
            partial_models = [pipeline_stage.submod for pipeline_stage in stages]
            pipeline_schedule = ScheduleClass(
                stages,
                n_microbatches=num_microbatches,
                loss_fn=loss_fn,
            )

        # Run
        pipeline_schedule._step_microbatches(arg_mbs=input_mb, target_mbs=input_mb)
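        # _step_microbatches runs the whole schedule on pre-split microbatches;
        # the inputs double as (unused) loss targets since loss_fn ignores them.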

        # The ref model runs on 2 different inputs, accumulating grads across them.
        # This ensures that we detect if the FSDP reduce becomes a no-op.
        # (In the FSDP case, we use one of these inputs on each DP rank.)
        (ref_model(inputs[0]).sum()).backward()
        (ref_model(inputs[1]).sum()).backward()

        # simulate the built-in gradient averaging done by FSDP/DDP
        for p in ref_model.parameters():
            p.grad /= dp_mesh.size()

        # Validate that the parameters we hold locally match the corresponding
        # part of the reference model (FSDP grads are all-gathered via
        # .full_tensor() to simplify the comparison).
        ref_parameters = dict(ref_model.named_parameters())
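        # Stage-local parameter names (e.g. "0.net1.weight") are shifted by
        # `offset` so the layer indices line up with the full reference model.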
        if dp_type == "FSDP":
            for partial_model, offset in zip(partial_models, offsets):
                for name, p in partial_model.named_parameters():
                    parts = name.split(".")
                    parts[0] = str(int(parts[0]) + offset)
                    name = ".".join(parts)
                    ref_p = ref_parameters[name]
                    self.assertTrue(isinstance(p.grad, DTensor))
                    torch.testing.assert_close(
                        ref_p.grad, p.grad.full_tensor(), rtol=1e-5, atol=5e-5
                    )
        elif dp_type == "DDP":
            for partial_model, offset in zip(partial_models, offsets):
                for name, p in partial_model.named_parameters():
                    parts = name.split(".")[1:]  # remove the "module." prefix
                    parts[0] = str(int(parts[0]) + offset)
                    name = ".".join(parts)
                    ref_p = ref_parameters[name]
                    torch.testing.assert_close(ref_p.grad, p.grad, rtol=1e-5, atol=5e-5)

        torch.distributed.destroy_process_group()


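# Running this file directly executes every parametrization of the test above;
# MultiProcessTestCase spawns one process per rank, and the test requires a
# host with at least 4 CUDA GPUs.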
instantiate_parametrized_tests(ComposabilityTest)

if __name__ == "__main__":
    run_tests()