# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import os
import sys
import tempfile

from model_registry import ExampleCode, ModelWithKwargs, MultiMLP

import torch
import torch.distributed as dist
from torch.distributed.pipelining import (
    build_stage,
    pipeline,
    PipelineStage,
    ScheduleGPipe,
)
from torch.distributed.pipelining._utils import PipeliningShapeError
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
    MultiProcContinousTest,
    requires_nccl,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skip_but_pass_in_sandcastle_if,
)
from torch.utils._pytree import tree_map_only


d_hid = 512
batch_size = 256
chunks = 4

torch.manual_seed(0)


def get_dtype_change_hook(new_dtype):
    """A simple forward hook for simulating mixed precision."""

    def dtype_change_hook(module, input, output):
        def f(x):
            return x.to(new_dtype)

        return tree_map_only(torch.Tensor, f, output)

    return dtype_change_hook


def get_flatten_hook():
    """A simple forward hook for simulating a wrong model output shape."""

    def flatten_hook(module, input, output):
        def f(x):
            return x.flatten()

        return tree_map_only(torch.Tensor, f, output)

    return flatten_hook


class StageTest(MultiProcContinousTest):
    @classmethod
    def backend_str(cls) -> str:
        # Testing with NCCL backend
        return "nccl"

    @classmethod
    def setUpClass(cls):
        """
        Class-scope test fixture. Run once for the entire test class, before any test starts.
        Set up the device.
        """
        super().setUpClass()
        dev_id = cls.rank % torch.cuda.device_count()
        cls.device = torch.device(f"cuda:{dev_id}")

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ModelClass", [ExampleCode, MultiMLP])
    def test_tracer(self, ModelClass):
        mod = ModelClass(d_hid)
        mod.to(self.device)

        x = torch.randn(batch_size, d_hid, device=self.device)
        x_mb = x.chunk(chunks)[0]

        split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            split_spec=split_spec,
        )

        stage = pipe.build_stage(
            self.rank,
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleGPipe(stage, chunks)

        # Run
        def _run_step(x):
            if self.rank == 0:
                return schedule.step(x)
            else:
                return schedule.step()

        out = _run_step(x)
        # Last rank checks result
        if self.rank == self.world_size - 1:
            ref_out = mod(x)
            torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=5e-2)

        # Test qualname mapping
        submod_keys = stage.submod.state_dict().keys()
        # Confirm keys are consistent with original model
        old_keys = mod.state_dict().keys()
        assert all(k in old_keys for k in submod_keys)

        if self.rank == 0:
            # Intended to run this code on all ranks, but if rank 0 throws,
            # it won't perform the send that unblocks rank 1.

            # TODO(whc) can't test this until fixing args/kwargs issue
            # with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
            #     _run_step(torch.randn(batch_size + 1, d_hid, device=self.device))

            with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
                _run_step(x.to(torch.int32))

            # The output of the stage's MLP layer will be flattened by this hook;
            # the stage should raise a shape error.
            handle = stage.submod.register_forward_hook(get_flatten_hook())
            with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
                _run_step(x)
            handle.remove()

            stage.submod.register_forward_hook(get_dtype_change_hook(torch.bfloat16))
            with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
                _run_step(x)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ModelClass", [ModelWithKwargs])
    def test_tracer_kwargs(self, ModelClass):
        mod = ModelClass(d_hid)
        mod.to(self.device)

        x = torch.randn(batch_size, d_hid, device=self.device)
        y = torch.randn(batch_size, d_hid, device=self.device)

        x_mb = x.chunk(chunks)[0]
        y_mb = y.chunk(chunks)[0]

        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            mb_kwargs={"y": y_mb},
        )

        stage_mod = pipe.get_stage_module(self.rank)

        # Test build_stage
        stage = build_stage(
            stage_mod,
            self.rank,
            pipe.info(),
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleGPipe(stage, chunks)

        # Run
        def _run_step(x):
            if self.rank == 0:
                return schedule.step(x, y=y)
            else:
                return schedule.step()

        out = _run_step(x)
        # Last rank checks result
        if self.rank == self.world_size - 1:
            ref_out = mod(x, y=y)
            torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=5e-2)

        # Test qualname mapping
        submod_keys = stage.submod.state_dict().keys()
        # Confirm keys are consistent with original model
        old_keys = mod.state_dict().keys()
        assert all(k in old_keys for k in submod_keys)

        if self.rank == 0:
            with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
                _run_step(torch.randn(batch_size + 1, d_hid, device=self.device))

            with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
                _run_step(x.to(torch.int32))

            # The output of the stage's MLP layer will be flattened by this hook;
            # the stage should raise a shape error.
            handle = stage.submod.register_forward_hook(get_flatten_hook())
            with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
                _run_step(x)
            handle.remove()

            stage.submod.register_forward_hook(get_dtype_change_hook(torch.bfloat16))
            with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
                _run_step(x)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_manual(self):
        full_mod = MultiMLP(d_hid, n_layers=self.world_size)
        full_mod.to(self.device)
        stage_mod = full_mod.get_submodule(f"layers.{self.rank}")

        x = torch.randn(batch_size, d_hid, device=self.device)

        stage = PipelineStage(
            stage_mod,
            self.rank,
            self.world_size,
            self.device,
            input_args=x.chunk(chunks)[0],
        )

        # Attach to a schedule
        schedule = ScheduleGPipe(stage, chunks)

        # Run
        def _run_step(x):
            if self.rank == 0:
                return schedule.step(x)
            else:
                return schedule.step()

        out = _run_step(x)
        # Last rank checks result
        if self.rank == self.world_size - 1:
            ref_out = full_mod(x)
            torch.testing.assert_close(out, ref_out)

        if self.rank == 0:
            with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
                _run_step(torch.randn(batch_size + 1, d_hid, device=self.device))

            with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
                _run_step(x.to(torch.int32))

            # The output of the stage's MLP layer will be flattened by this hook;
            # the stage should raise a shape error.
            handle = stage_mod.register_forward_hook(get_flatten_hook())
            with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
                _run_step(x)
            handle.remove()

            stage_mod.register_forward_hook(get_dtype_change_hook(torch.bfloat16))
            with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"):
                _run_step(x)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_custom_dw_with_fb_schedule(self):
        """Tests that the separate weight-grad function 'dw_runner' gets run under a schedule that is only aware of F/B."""
        full_mod = MultiMLP(d_hid, n_layers=self.world_size)
        full_mod.to(self.device)
        stage_mod = full_mod.get_submodule(f"layers.{self.rank}")

        x = torch.randn(batch_size, d_hid, device=self.device)
        target = torch.randn(batch_size, d_hid, device=self.device)

        class CustomState:
            def __init__(self) -> None:
                self.i = 0

            def dw_builder(self):
                """This simulates a function attached to a model with a custom backward.
                Each call to the builder gives a new dw_runner that has some updated state to compute the latest dw.
                """

                def dw_runner():
                    # This inner function would be called by PipelineStage during `backward_weight_one_chunk`
                    print(f"dw called {self.i}th time")
                    self.i += 1

                return dw_runner

        cs = CustomState()

        stage = PipelineStage(
            stage_mod,
            self.rank,
            self.world_size,
            self.device,
            input_args=x.chunk(chunks)[0],
            dw_builder=cs.dw_builder,
        )

        # Attach to a schedule
        schedule = ScheduleGPipe(
            stage, chunks, loss_fn=torch.nn.MSELoss(reduction="sum")
        )

        # Run
        def _run_step(x):
            if self.rank == 0:
                return schedule.step(x)
            elif self.rank == self.world_size - 1:
                return schedule.step(target=target)
            else:
                return schedule.step()

        out = _run_step(x)

        self.assertEqual(cs.i, chunks)

        # Last rank checks result
        if self.rank == self.world_size - 1:
            ref_out = full_mod(x)
            torch.testing.assert_close(out, ref_out)

        if self.rank == 0:
            with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
                _run_step(torch.randn(batch_size + 1, d_hid, device=self.device))

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_custom_dw_errors(self):
        """Tests that the expected errors are raised."""
        full_mod = MultiMLP(d_hid, n_layers=self.world_size)
        full_mod.to(self.device)
        stage_mod = full_mod.get_submodule(f"layers.{self.rank}")

        x = torch.randn(batch_size, d_hid, device=self.device)
        target = torch.randn(batch_size, d_hid, device=self.device)

        stage_with_dw_builder = PipelineStage(
            stage_mod,
            self.rank,
            self.world_size,
            self.device,
            input_args=x.chunk(chunks)[0],
            dw_builder=lambda: None,
        )
        with self.assertRaisesRegex(AssertionError, "backward_one_chunk"):
            stage_with_dw_builder.backward_weight_one_chunk(bwd_chunk_id=0)


instantiate_parametrized_tests(StageTest)

if __name__ == "__main__":
    # Check if GPU and NCCL are available
    if not (
        dist.is_available()
        and dist.is_nccl_available()
        and torch.cuda.device_count() > 1
    ):
        print(
            "c10d NCCL not available or not enough GPUs, skipping tests",
            file=sys.stderr,
        )
        sys.exit(0)

    rank = int(os.getenv("RANK", -1))
    world_size = int(os.getenv("WORLD_SIZE", 2))

    if rank != -1:
        # Launched with torchrun or other multi-proc launchers. Directly run the test.
        StageTest.run_rank(rank, world_size)
    else:
        # Launched as a single process. Spawn subprocesses to run the tests.
        # Also need a rendezvous file for `init_process_group` purposes.
        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
        torch.multiprocessing.spawn(
            StageTest.run_rank,
            nprocs=world_size,
            args=(world_size, rdvz_file),
        )
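
# Example invocations (a sketch, assuming at least 2 CUDA GPUs are visible):
#   torchrun --nproc-per-node=2 <this file>   # the launcher sets RANK/WORLD_SIZE, so each process runs the tests directly
#   python <this file>                        # RANK is unset, so the tests are spawned as local subprocesses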