# Owner(s): ["oncall: distributed"]
import copy
import os

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp.fully_shard import (
    fully_shard,
    MixedPrecisionPolicy,
)
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    PipelineScheduleSingle,
    Schedule1F1B,
    ScheduleFlexibleInterleaved1F1B,
    ScheduleGPipe,
    ScheduleInterleaved1F1B,
    ScheduleInterleavedZeroBubble,
    ScheduleLoopedBFS,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skip_but_pass_in_sandcastle_if,
)


# MLP Layer
class MLPModule(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x


class ComposabilityTest(MultiProcessTestCase):
    @classmethod
    def backend_str(cls) -> str:
        # Testing with NCCL backend
        return "nccl"

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self):
        return 4

    @property
    def device(self):
        return self.rank

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
    @parametrize("dp_type", ["DDP", "FSDP"])
    @parametrize(
        "ScheduleClass",
        [
            ScheduleGPipe,
            Schedule1F1B,
            ScheduleInterleaved1F1B,
            ScheduleLoopedBFS,
            ScheduleFlexibleInterleaved1F1B,
            ScheduleInterleavedZeroBubble,
        ],
    )
    @parametrize("use_new_runtime", [False, True])
    def test_manual_with_data_parallel(self, dp_type, ScheduleClass, use_new_runtime):
        device = torch.device("cuda", self.device)
        torch.cuda.set_device(self.device)
        store = torch.distributed.FileStore(self.file_name, self.world_size)
        torch.distributed.init_process_group(
            backend="nccl",
            store=store,
            rank=self.rank,
            world_size=self.world_size,
            device_id=device,
        )
        device_mesh = init_device_mesh(
            "cuda", mesh_shape=(2, 2), mesh_dim_names=("dp", "pp")
        )
        pp_group = device_mesh["pp"].get_group()
        dp_mesh = device_mesh["dp"]

        # create "entire model"
        total_layers = 8
        dim = 10
        full_model = nn.ModuleList([MLPModule(dim) for _ in range(total_layers)])
        ref_model = nn.Sequential(*copy.deepcopy(full_model))
        ref_model.to(self.device)

        # Prepare inputs
        num_microbatches = 8
        inputs = [
            torch.rand((num_microbatches, dim), device=self.device)
            for _ in range(dp_mesh.size())
        ]
        input = inputs[dp_mesh.get_local_rank()]
        input_mb = [[input[i].reshape((1, dim))] for i in range(num_microbatches)]

        # dummy loss needed just to force backwards to run in schedule step
        def loss_fn(y, target):
            return y.sum()

        # Get stage module i from the entire model
        def get_stage_module(stage_idx, num_stages):
            # divide the model (8 layers) by the number of stages
            layers_per_stage = total_layers // num_stages
            assert layers_per_stage * num_stages == total_layers
            # return offset so validation code can match partial layer back to orig model
            offset = stage_idx * layers_per_stage
            partial_model = nn.Sequential(
                *full_model[offset : (stage_idx + 1) * layers_per_stage]
            )
            partial_model.to(self.device)
            return partial_model, offset

        # Apply DP to stage module
        def apply_dp(partial_model, dp_type):
            if dp_type == "FSDP":
                # apply FSDP
                mp_policy = MixedPrecisionPolicy(
                    # TODO(whc) need to fix PP + FSDP-mixed-precision
                    # tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
                    # param_dtype=torch.bfloat16, reduce_dtype=torch.float32
                    param_dtype=torch.float32,
                    reduce_dtype=torch.float32,
                )
                fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
                for layer in partial_model.children():
                    fully_shard(
                        layer,
                        **fsdp_config,
                        reshard_after_forward=False,
                    )
                dp_model = fully_shard(partial_model, **fsdp_config)
            elif dp_type == "DDP":
                dp_model = DDP(partial_model, process_group=dp_mesh.get_group())
            else:
                raise RuntimeError(f"unsupported dp type {dp_type}")
            return dp_model

        # Create pipeline stage
        def build_stage(stage_idx, num_stages):
            partial_model, offset = get_stage_module(stage_idx, num_stages)
            dp_model = apply_dp(partial_model, dp_type)
            stage = PipelineStage(
                dp_model,
                stage_idx,
                num_stages,
                self.device,
                group=pp_group,
                input_args=input_mb[0],
            )
            return stage, offset

        # Attach to a schedule
        if issubclass(ScheduleClass, PipelineScheduleSingle):
            if use_new_runtime:
                # Can't test PipelineScheduleSingle classes using new runtime
                # return should still clean up this test instance correctly
                torch.distributed.destroy_process_group()
                return
            pipeline_stage, offset = build_stage(pp_group.rank(), pp_group.size())
            partial_models = [pipeline_stage.submod]
            offsets = [offset]
            pipeline_schedule = ScheduleClass(
                pipeline_stage,
                n_microbatches=num_microbatches,
                loss_fn=loss_fn,
            )
        else:
            n_virtual = 2
            num_stages = pp_group.size() * n_virtual
            stages = []
            offsets = []
            for i in range(n_virtual):
                stage, offset = build_stage(pp_group.rank() + n_virtual * i, num_stages)
                stages.append(stage)
                offsets.append(offset)
            partial_models = [pipeline_stage.submod for pipeline_stage in stages]
            pipeline_schedule = ScheduleClass(
                stages,
                n_microbatches=num_microbatches,
                loss_fn=loss_fn,
            )

        # Run
        pipeline_schedule._step_microbatches(arg_mbs=input_mb, target_mbs=input_mb)

        # Ref model runs on 2 different inputs, accumulating grads across them.
        # this ensures that we detect if the FSDP reduce becomes a no-op.
        # (in fsdp case, we use one of these inputs on each DP rank)
        (ref_model(inputs[0]).sum()).backward()
        (ref_model(inputs[1]).sum()).backward()

        # simulate the built-in averaging done by FSDP
        for p in ref_model.parameters():
            p.grad /= dp_mesh.size()

        # Validate that whichever weights we have locally match that part of our local/full ref model
        # (we force FSDP's grads to be all-gathered (.full_tensor) to make it simpler)
        ref_parameters = dict(ref_model.named_parameters())
        if dp_type == "FSDP":
            for partial_model, offset in zip(partial_models, offsets):
                for name, p in partial_model.named_parameters():
                    parts = name.split(".")
                    parts[0] = str(int(parts[0]) + offset)
                    name = ".".join(parts)
                    ref_p = ref_parameters[name]
                    self.assertTrue(isinstance(p.grad, DTensor))
                    torch.testing.assert_close(
                        ref_p.grad, p.grad.full_tensor(), rtol=1e-5, atol=5e-5
                    )
        elif dp_type == "DDP":
            for partial_model, offset in zip(partial_models, offsets):
                for name, p in partial_model.named_parameters():
                    parts = name.split(".")[1:]  # remove the "module." prefix
                    parts[0] = str(int(parts[0]) + offset)
                    name = ".".join(parts)
                    ref_p = ref_parameters[name]
                    torch.testing.assert_close(ref_p.grad, p.grad, rtol=1e-5, atol=5e-5)

        torch.distributed.destroy_process_group()


instantiate_parametrized_tests(ComposabilityTest)

if __name__ == "__main__":
    run_tests()