# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import os
import sys
import tempfile

from model_registry import ExampleCode, ModelWithKwargs, MultiMLP

import torch
import torch.distributed as dist
from torch.distributed.pipelining import (
    build_stage,
    pipeline,
    PipelineStage,
    ScheduleGPipe,
)
from torch.distributed.pipelining._utils import PipeliningShapeError
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
    MultiProcContinousTest,
    requires_nccl,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skip_but_pass_in_sandcastle_if,
)
from torch.utils._pytree import tree_map_only


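# Pipeline hyperparameters for these tests: model hidden size, global batch size,
# and the number of microbatches each schedule step splits the batch into.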
d_hid = 512
batch_size = 256
chunks = 4

torch.manual_seed(0)


def get_dtype_change_hook(new_dtype):
    """A simple hook for simulating mixed precision"""

    def dtype_change_hook(module, input, output):
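        # A forward hook's non-None return value replaces the module's output,
        # so casting here simulates a downstream dtype mismatch.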
        def f(x):
            return x.to(new_dtype)

        return tree_map_only(torch.Tensor, f, output)

    return dtype_change_hook


def get_flatten_hook():
    """A simple hook for simulating wrong model output shape"""

    def flatten_hook(module, input, output):
        def f(x):
            return x.flatten()

        return tree_map_only(torch.Tensor, f, output)

    return flatten_hook


class StageTest(MultiProcContinousTest):
    @classmethod
    def backend_str(cls) -> str:
        # Testing with the NCCL backend
        return "nccl"

    @classmethod
    def setUpClass(cls):
        """
        Class-scope test fixture. Runs once for the entire test class, before any test starts.
        Sets up the device.
        """
        super().setUpClass()
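        # Map each rank to a local CUDA device round-robin across the visible GPUs.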
        dev_id = cls.rank % torch.cuda.device_count()
        cls.device = torch.device(f"cuda:{dev_id}")

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ModelClass", [ExampleCode, MultiMLP])
    def test_tracer(self, ModelClass):
        mod = ModelClass(d_hid)
        mod.to(self.device)

        x = torch.randn(batch_size, d_hid, device=self.device)
        x_mb = x.chunk(chunks)[0]

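        # pipeline() traces the model with an example microbatch and splits it into
        # stages at the boundaries given by split_spec (when the model defines one).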
        split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            split_spec=split_spec,
        )

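        # Each rank materializes only its own stage module, placed on its device.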
        stage = pipe.build_stage(
            self.rank,
            self.device,
        )

        # Attach to a schedule
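        # ScheduleGPipe splits the input into `chunks` microbatches and runs them GPipe-style.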
        schedule = ScheduleGPipe(stage, chunks)

        # Run
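        # Only the first stage feeds in data; later stages receive activations from their peers.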
        def _run_step(x):
            if self.rank == 0:
                return schedule.step(x)
            else:
                return schedule.step()

        out = _run_step(x)
        # Last rank checks result
        if self.rank == self.world_size - 1:
            ref_out = mod(x)
            torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=5e-2)

        # Test qualname mapping
        submod_keys = stage.submod.state_dict().keys()
        # Confirm keys are consistent with original model
        old_keys = mod.state_dict().keys()
        assert all(k in old_keys for k in submod_keys)

        if self.rank == 0:
            # We would like to run this block on all ranks, but if rank 0 throws,
            # it never performs the send that unblocks rank 1.

            # TODO(whc) can't test this until fixing args/kwargs issue
            # with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
            #     _run_step(torch.randn(batch_size + 1, d_hid, device=self.device))

            with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
                _run_step(x.to(torch.int32))

            # The output of the stage's MLP layer will be flattened by this hook; the stage should err
            handle = stage.submod.register_forward_hook(get_flatten_hook())
            with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
                _run_step(x)
            handle.remove()

            stage.submod.register_forward_hook(get_dtype_change_hook(torch.bfloat16))
            with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
                _run_step(x)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ModelClass", [ModelWithKwargs])
    def test_tracer_kwargs(self, ModelClass):
        mod = ModelClass(d_hid)
        mod.to(self.device)

        x = torch.randn(batch_size, d_hid, device=self.device)
        y = torch.randn(batch_size, d_hid, device=self.device)

        x_mb = x.chunk(chunks)[0]
        y_mb = y.chunk(chunks)[0]

        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            mb_kwargs={"y": y_mb},
        )

        stage_mod = pipe.get_stage_module(self.rank)

        # Test build_stage
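        # build_stage() wraps a stage module from the pipe together with pipe.info()
        # (pipeline metadata such as the number of stages) into a runnable PipelineStage.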
        stage = build_stage(
            stage_mod,
            self.rank,
            pipe.info(),
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleGPipe(stage, chunks)

        # Run
        def _run_step(x):
            if self.rank == 0:
                return schedule.step(x, y=y)
            else:
                return schedule.step()

        # Last rank checks result
        out = _run_step(x)
        if self.rank == self.world_size - 1:
            ref_out = mod(x, y=y)
            torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=5e-2)

        # Test qualname mapping
        submod_keys = stage.submod.state_dict().keys()
        # Confirm keys are consistent with original model
        old_keys = mod.state_dict().keys()
        assert all(k in old_keys for k in submod_keys)

        if self.rank == 0:
            with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
                _run_step(torch.randn(batch_size + 1, d_hid, device=self.device))

            with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
                _run_step(x.to(torch.int32))

            # The output of the stage's MLP layer will be flattened by this hook; the stage should err
            handle = stage.submod.register_forward_hook(get_flatten_hook())
            with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
                _run_step(x)
            handle.remove()

            stage.submod.register_forward_hook(get_dtype_change_hook(torch.bfloat16))
            with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
                _run_step(x)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_manual(self):
        full_mod = MultiMLP(d_hid, n_layers=self.world_size)
        full_mod.to(self.device)
        stage_mod = full_mod.get_submodule(f"layers.{self.rank}")

        x = torch.randn(batch_size, d_hid, device=self.device)

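        # Manual stage construction: wrap this rank's submodule directly (no tracing);
        # input_args supplies an example microbatch so the stage can infer comm shapes.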
        stage = PipelineStage(
            stage_mod,
            self.rank,
            self.world_size,
            self.device,
            input_args=x.chunk(chunks)[0],
        )

        # Attach to a schedule
        schedule = ScheduleGPipe(stage, chunks)

        # Run
        def _run_step(x):
            if self.rank == 0:
                return schedule.step(x)
            else:
                return schedule.step()

        out = _run_step(x)
        # Last rank checks result
        if self.rank == self.world_size - 1:
            ref_out = full_mod(x)
            torch.testing.assert_close(out, ref_out)

        if self.rank == 0:
            with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
                _run_step(torch.randn(batch_size + 1, d_hid, device=self.device))

            with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
                _run_step(x.to(torch.int32))

            # The output of the stage's MLP layer will be flattened by this hook; the stage should err
            handle = stage_mod.register_forward_hook(get_flatten_hook())
            with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
                _run_step(x)
            handle.remove()

            stage_mod.register_forward_hook(get_dtype_change_hook(torch.bfloat16))
            with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
                _run_step(x)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_custom_dw_with_fb_schedule(self):
        """Tests that the separate weight-grad function 'dw_runner' gets run under a schedule that is only aware of F/B."""
        full_mod = MultiMLP(d_hid, n_layers=self.world_size)
        full_mod.to(self.device)
        stage_mod = full_mod.get_submodule(f"layers.{self.rank}")

        x = torch.randn(batch_size, d_hid, device=self.device)
        target = torch.randn(batch_size, d_hid, device=self.device)

        class CustomState:
            def __init__(self) -> None:
                self.i = 0

            def dw_builder(self):
                """This simulates a function attached to a model with a custom backward.
                Each call to the builder gives a new dw_runner that carries updated state to compute the latest dw.
                """

                def dw_runner():
                    # This inner function would be called by PipelineStage during `backward_weight_one_chunk`
                    print(f"dw called {self.i}th time")
                    self.i += 1

                return dw_runner

        cs = CustomState()

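        # Passing dw_builder tells the stage to defer the weight-gradient work to the
        # returned dw_runner, even under a schedule that only issues full F/B steps.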
        stage = PipelineStage(
            stage_mod,
            self.rank,
            self.world_size,
            self.device,
            input_args=x.chunk(chunks)[0],
            dw_builder=cs.dw_builder,
        )

        # Attach to a schedule
        schedule = ScheduleGPipe(
            stage, chunks, loss_fn=torch.nn.MSELoss(reduction="sum")
        )

        # Run
        def _run_step(x):
            if self.rank == 0:
                return schedule.step(x)
            elif self.rank == self.world_size - 1:
                return schedule.step(target=target)
            else:
                return schedule.step()

        out = _run_step(x)

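        # dw_runner should have been invoked once per microbatch.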
        self.assertEqual(cs.i, chunks)

        # Last rank checks result
        if self.rank == self.world_size - 1:
            ref_out = full_mod(x)
            torch.testing.assert_close(out, ref_out)

        if self.rank == 0:
            with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
                _run_step(torch.randn(batch_size + 1, d_hid, device=self.device))

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_custom_dw_errors(self):
        """Tests that the expected errors are raised"""
        full_mod = MultiMLP(d_hid, n_layers=self.world_size)
        full_mod.to(self.device)
        stage_mod = full_mod.get_submodule(f"layers.{self.rank}")

        x = torch.randn(batch_size, d_hid, device=self.device)
        target = torch.randn(batch_size, d_hid, device=self.device)

        stage_with_dw_builder = PipelineStage(
            stage_mod,
            self.rank,
            self.world_size,
            self.device,
            input_args=x.chunk(chunks)[0],
            dw_builder=lambda: None,
        )
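        # Requesting the deferred weight grad before running backward_one_chunk should assert.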
        with self.assertRaisesRegex(AssertionError, "backward_one_chunk"):
            stage_with_dw_builder.backward_weight_one_chunk(bwd_chunk_id=0)


instantiate_parametrized_tests(StageTest)

if __name__ == "__main__":
    # Check whether GPUs and NCCL are available
    if not (
        dist.is_available()
        and dist.is_nccl_available()
        and torch.cuda.device_count() > 1
    ):
        print(
            "c10d NCCL not available or not enough GPUs, skipping tests",
            file=sys.stderr,
        )
        sys.exit(0)

    rank = int(os.getenv("RANK", -1))
    world_size = int(os.getenv("WORLD_SIZE", 2))

    if rank != -1:
        # Launched with torchrun or another multi-proc launcher. Run the test directly.
        StageTest.run_rank(rank, world_size)
    else:
        # Launched as a single process. Spawn subprocesses to run the tests.
        # We also need a rendezvous file for `init_process_group` purposes.
        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
        torch.multiprocessing.spawn(
            StageTest.run_rank,
            nprocs=world_size,
            args=(world_size, rdvz_file),
        )