1# Owner(s): ["oncall: distributed"] 2 3import sys 4from contextlib import nullcontext 5from enum import auto, Enum 6from typing import List, Optional 7from unittest.mock import patch 8 9import torch 10import torch.nn as nn 11import torch.nn.functional as F 12from torch import distributed as dist 13from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 14from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy 15from torch.distributed.fsdp.wrap import ModuleWrapPolicy 16from torch.nn.parallel.distributed import DistributedDataParallel as DDP 17from torch.testing._internal.common_distributed import skip_if_lt_x_gpu 18from torch.testing._internal.common_fsdp import ( 19 CUDAInitMode, 20 FSDPInitMode, 21 FSDPTest, 22 MLP, 23 NestedWrappedModule, 24 TransformerWithSharedParams, 25) 26from torch.testing._internal.common_utils import ( 27 instantiate_parametrized_tests, 28 parametrize, 29 run_tests, 30 TEST_WITH_DEV_DBG_ASAN, 31) 32 33 34if not dist.is_available(): 35 print("Distributed not available, skipping tests", file=sys.stderr) 36 sys.exit(0) 37 38if TEST_WITH_DEV_DBG_ASAN: 39 print( 40 "Skip dev-asan as torch + multiprocessing spawn have known issues", 41 file=sys.stderr, 42 ) 43 sys.exit(0) 44 45 46class PassType(Enum): 47 __order__ = "FWD BWD" 48 FWD = auto() 49 BWD = auto() 50 51 52class TestCommunication(FSDPTest): 53 """Tests ``FullyShardedDataParallel``'s collective communication usage.""" 54 55 def _init_model( 56 self, 57 nested_model: bool, 58 sharding_strategy: ShardingStrategy, 59 device: torch.device, 60 ): 61 fsdp_kwargs = {"sharding_strategy": sharding_strategy} 62 if nested_model: 63 model = NestedWrappedModule.init( 64 self.process_group, 65 FSDPInitMode.RECURSIVE, 66 CUDAInitMode.CUDA_AFTER, 67 fsdp_kwargs, 68 ) 69 fsdp_model: FSDP = FSDP( 70 model, 71 self.process_group, 72 **fsdp_kwargs, 73 ).to(device) 74 else: 75 fsdp_model: FSDP = TransformerWithSharedParams.init( 76 self.process_group, 77 FSDPInitMode.RECURSIVE, 78 CUDAInitMode.CUDA_BEFORE, 79 fsdp_kwargs, 80 ) 81 return fsdp_model 82 83 def _run_iter(self, fsdp_model, batch, use_no_sync: bool): 84 """Runs an iteration inside or outside the ``no_sync()`` context.""" 85 context = fsdp_model.no_sync() if use_no_sync else nullcontext() 86 with context: 87 output = fsdp_model(*batch) 88 loss = fsdp_model.module.get_loss(batch, output) 89 loss.backward() 90 91 def _get_ref_num_reduce_scatters( 92 self, 93 num_fsdp: int, 94 in_no_sync: bool, 95 ) -> int: 96 """Returns the reference number of reduce-scatters for an iteration 97 in the ``no_sync()`` context.""" 98 return num_fsdp if not in_no_sync else 0 99 100 def _get_ref_num_all_gathers( 101 self, 102 num_fsdp: int, 103 sharding_strategy: Optional[ShardingStrategy], 104 is_first_iter: bool, 105 is_last_iter_no_sync: bool, 106 ) -> int: 107 """Returns the reference number of all-gathers in an iteration, summing 108 over the forward and backward passes.""" 109 return sum( 110 self._get_ref_num_all_gathers_in_pass( 111 num_fsdp, 112 sharding_strategy, 113 pass_type, 114 is_first_iter, 115 is_last_iter_no_sync, 116 ) 117 for pass_type in PassType 118 ) 119 120 def _get_ref_num_all_gathers_in_pass( 121 self, 122 num_fsdp: int, 123 sharding_strategy: Optional[ShardingStrategy], 124 pass_type: PassType, 125 is_first_iter: bool, 126 is_last_iter_no_sync: bool, 127 ): 128 """Returns the reference number of all-gathers for a given setting.""" 129 if sharding_strategy is None: 130 sharding_strategy = ShardingStrategy.FULL_SHARD # default 131 # 
        if (
            pass_type == PassType.FWD
            and sharding_strategy == ShardingStrategy.SHARD_GRAD_OP
            and is_last_iter_no_sync
        ):
            # Modules do not free the full parameters in the last
            # iteration's backward pass if it was in `no_sync()`
            num_all_gathers = 0
        elif pass_type == PassType.FWD:
            # Otherwise, all modules all-gather the full parameters in the
            # forward pass
            num_all_gathers = num_fsdp
        # Backward pass:
        elif (
            pass_type == PassType.BWD
            and sharding_strategy == ShardingStrategy.FULL_SHARD
        ):
            # Root does not free the full parameters at the end of the
            # forward pass
            num_all_gathers = num_fsdp - 1
        elif (
            pass_type == PassType.BWD
            and sharding_strategy == ShardingStrategy.SHARD_GRAD_OP
        ):
            # Modules do not free the full parameters at the end of the
            # forward pass
            num_all_gathers = 0
        else:
            assert 0, (
                f"Unsupported: add a branch for pass_type={pass_type} "
                f"is_first_iter={is_first_iter} "
                f"is_last_iter_no_sync={is_last_iter_no_sync} "
                f"sharding_strategy={sharding_strategy}"
            )
        if is_first_iter and pass_type == PassType.FWD:
            # With execution order validation, on the first iteration, we have
            # an additional two all-gathers before every actual all-gather in
            # the forward pass
            num_all_gathers *= 3
        return num_all_gathers

    def _print_ref_num_all_gathers_in_pass(
        self,
        num_fsdp: int,
        sharding_strategy: ShardingStrategy,
        pass_type: PassType,
        is_first_iter: bool,
        is_last_iter_no_sync: bool,
    ):
        """Helper method for printing the number of all-gathers for a specific
        setting. This may be helpful since the branching is complex."""
        if self.rank != 0:
            return  # only print on one rank
        num_all_gathers = self._get_ref_num_all_gathers_in_pass(
            num_fsdp,
            sharding_strategy,
            pass_type,
            is_first_iter,
            is_last_iter_no_sync,
        )
        print(
            f"Pass: {pass_type}\n"
            f"Is First Iteration: {is_first_iter}\n"
            f"Sharding Strategy: {sharding_strategy}\n"
            f"Last iteration in `no_sync()`: {is_last_iter_no_sync}\n"
            f"Number of all-gathers: {num_all_gathers}"
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("nested_model", [False, True])
    @parametrize("use_no_sync", [False, True])
    @parametrize("sharding_strategy", [ShardingStrategy.SHARD_GRAD_OP, None])
    def test_communication(
        self,
        nested_model: bool,
        use_no_sync: bool,
        sharding_strategy: Optional[ShardingStrategy],
    ):
        """
        Tests FSDP's communication cost in terms of calls to collective
        communication primitives (i.e. all-gather and reduce-scatter).

        Arguments:
            nested_model (bool): If ``True``, uses ``NestedWrappedModule``,
                which has nested FSDP instances; if ``False``, uses the default
                model, which does not have nested FSDP instances.
            use_no_sync (bool): If ``True``, runs some iterations inside the
                ``no_sync()`` context manager to accumulate gradients, followed
                by some iterations outside the context manager; if ``False``,
                only runs some iterations outside the context manager.
            sharding_strategy (Optional[ShardingStrategy]): Configures the
                FSDP algorithm.
        """
        # Enable execution order checking
        dist.set_debug_level(dist.DebugLevel.DETAIL)
        # Initialize the model and inputs
        device = torch.device("cuda")
        fsdp_model = self._init_model(nested_model, sharding_strategy, device)
        batch = fsdp_model.module.get_input(device)

        # Count the number of FSDP instances that manage parameters since the
        # number of collectives is a function of this number
        num_fsdp = sum(
            (isinstance(m, FSDP) and len(m.params) > 0) for m in fsdp_model.modules()
        )

        # If `use_no_sync=True`, we run `num_iters` iterations inside
        # `no_sync()` followed by `num_iters` iterations outside `no_sync()`,
        # and if `use_no_sync=False`, we only run `num_iters` iterations
        # outside `no_sync()`
        num_iters = 3
        with patch(
            "torch.distributed.all_gather_into_tensor"
        ) as mock_all_gather, patch(
            "torch.distributed.reduce_scatter_tensor"
        ) as mock_reduce_scatter:

            def reset_mocks():
                mock_all_gather.reset_mock()
                mock_reduce_scatter.reset_mock()

            # Check the communication cost when using `no_sync()`
            if use_no_sync:
                for i in range(num_iters):
                    reset_mocks()
                    self._run_iter(fsdp_model, batch, use_no_sync=True)
                    num_all_gathers = mock_all_gather.call_count
                    num_reduce_scatters = mock_reduce_scatter.call_count
                    ref_num_all_gathers = self._get_ref_num_all_gathers(
                        num_fsdp,
                        sharding_strategy,
                        is_first_iter=i == 0,
                        is_last_iter_no_sync=i > 0,
                    )
                    ref_num_reduce_scatters = self._get_ref_num_reduce_scatters(
                        num_fsdp,
                        in_no_sync=True,
                    )
                    self.assertEqual(num_all_gathers, ref_num_all_gathers)
                    self.assertEqual(num_reduce_scatters, ref_num_reduce_scatters)
            # Check the normal communication cost (when not using `no_sync()`)
            for i in range(num_iters):
                reset_mocks()
                self._run_iter(fsdp_model, batch, use_no_sync=False)
                num_all_gathers = mock_all_gather.call_count
                num_reduce_scatters = mock_reduce_scatter.call_count
                ref_num_all_gathers = self._get_ref_num_all_gathers(
                    num_fsdp,
                    sharding_strategy,
                    is_first_iter=not use_no_sync and i == 0,
                    is_last_iter_no_sync=use_no_sync and i == 0,
                )
                ref_num_reduce_scatters = self._get_ref_num_reduce_scatters(
                    num_fsdp,
                    in_no_sync=False,
                )
                self.assertEqual(num_all_gathers, ref_num_all_gathers)
                self.assertEqual(num_reduce_scatters, ref_num_reduce_scatters)


class TestExplicitUnshard(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 2)

    @skip_if_lt_x_gpu(2)
    @parametrize("use_orig_params", [False, True])
    def test_unshard_async(self, use_orig_params: bool):
        """Tests explicitly prefetching FSDP-managed parameters via
        ``_unshard(async_op=True)`` while other collectives are in flight and
        checks numerical parity against a DDP-only reference model."""

        class ReduceModule(nn.Module):
            def __init__(self, dim: int, group: dist.ProcessGroup):
                super().__init__()
                self.group = group
                self.weight = nn.Parameter(torch.randn(dim, dim))

            def forward(self, x: torch.Tensor):
                y = F.relu(x @ self.weight)
                # NOTE: This all-reduce is not differentiable and is included
                # to exercise the overlap.
                work = dist.all_reduce(y, group=self.group, async_op=True)
                return y, work

        class MLPs(nn.Module):
            def __init__(self, dim: int):
                super().__init__()
                self.mlp1 = MLP(dim)
                self.mlp2 = MLP(dim)
                self.mlp3 = MLP(dim)

            def forward(self, ys: List[torch.Tensor], works: List[dist.Work]):
                (y1, y2, y3), (work1, work2, work3) = ys, works
                work1.wait()
                z1 = self.mlp1(y1)
                work2.wait()
                z2 = self.mlp2(y2)
                work3.wait()
                z3 = self.mlp3(y3)
                return z1 + z2 + z3

        class ReduceModel(nn.Module):
            def __init__(self, dim: int, group: dist.ProcessGroup):
                super().__init__()
                self.reduce_module1 = ReduceModule(dim, group)
                self.reduce_module2 = ReduceModule(dim, group)
                self.reduce_module3 = ReduceModule(dim, group)
                self.mlps = MLPs(dim)

            def forward(self, x: torch.Tensor):
                y1, work1 = self.reduce_module1(x)
                if isinstance(self.mlps.mlp1, FSDP):
                    self.mlps.mlp1._unshard(async_op=True)
                y2, work2 = self.reduce_module2(x)
                if isinstance(self.mlps.mlp2, FSDP):
                    self.mlps.mlp2._unshard(async_op=True)
                y3, work3 = self.reduce_module3(x)
                if isinstance(self.mlps.mlp3, FSDP):
                    self.mlps.mlp3._unshard(async_op=True)
                return self.mlps([y1, y2, y3], [work1, work2, work3])

        group = self.process_group
        batch_size, dim = 2, 8
        torch.manual_seed(42)
        ref_model = DDP(ReduceModel(dim, group).cuda(), device_ids=[self.rank])
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)

        torch.manual_seed(42)
        model = ReduceModel(dim, group)
        model.mlps = FSDP(
            model.mlps,
            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
            auto_wrap_policy=ModuleWrapPolicy((MLP,)),
            device_id=self.rank,
            use_orig_params=use_orig_params,
        )
        model.mlps.check_is_root()
        mlp_params = set(model.mlps.parameters())
        mlp_param_names = {n for n, p in model.named_parameters() if p in mlp_params}
        DDP._set_params_and_buffers_to_ignore_for_model(model, mlp_param_names)
        model = DDP(model.cuda(), device_ids=[self.rank])
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((batch_size, dim), device="cuda")

        for _ in range(10):
            losses: List[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                losses.append(_model(inp).sum())
                losses[-1].backward()
                _optim.step()
                _optim.zero_grad()
            self.assertEqual(losses[0], losses[1])
            model.module.mlps._wait_unshard_streams_on_current_stream()


instantiate_parametrized_tests(TestCommunication)
instantiate_parametrized_tests(TestExplicitUnshard)

if __name__ == "__main__":
    run_tests()