# mypy: allow-untyped-decorators
from typing import cast, List, NamedTuple, Optional, Tuple, Union

import torch
import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.tensor import DTensor

from ._fsdp_common import (
    _get_dim0_padded_size,
    _raise_assert_with_print,
    _to_dtype_if_needed,
)
from ._fsdp_param import FSDPParam, ShardedState


class AllGatherResult(NamedTuple):
    all_gather_output: torch.Tensor
    all_gather_event: Optional[torch.cuda.Event]
    all_gather_work: Optional[dist.distributed_c10d.Work]
    # For each parameter, the all-gather input dtype for each input
    param_all_gather_input_dtypes: List[List[torch.dtype]]
    # For each parameter, the all-gather input numel for each input
    param_all_gather_input_numels: List[List[int]]
    # 1D flattened version of `param_all_gather_input_numels` saved to avoid
    # CPU overhead from recomputing
    all_gather_input_split_sizes: List[int]


lib = torch.library.Library("fsdp", "FRAGMENT")  # noqa: TOR901

lib.define(
    """
    all_gather_copy_in(
        Tensor[] all_gather_inputs,
        SymInt[] inp_split_sizes,
        SymInt all_gather_input_numel,
        SymInt world_size,
        SymInt rank,
        ScalarType dtype,
        Device device
    ) -> (Tensor, Tensor)
    """
)


@torch.library.impl(lib, "all_gather_copy_in", "Meta")
def all_gather_copy_in_meta(
    all_gather_inputs: List[torch.Tensor],
    inp_split_sizes: List[int],
    all_gather_input_numel: int,
    world_size: int,
    rank: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    all_gather_output = torch.empty(
        (all_gather_input_numel * world_size,), dtype=dtype, device="meta"
    )
    all_gather_input = all_gather_output.narrow(
        0, all_gather_input_numel * rank, all_gather_input_numel
    )
    return all_gather_input, all_gather_output


@torch.library.impl(lib, "all_gather_copy_in", "CUDA")
@torch.library.impl(lib, "all_gather_copy_in", "CPU")
def all_gather_copy_in_cuda(
    all_gather_inputs: List[torch.Tensor],
    inp_split_sizes: List[int],
    all_gather_input_numel: int,
    world_size: int,
    rank: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    all_gather_output = torch.empty(
        (all_gather_input_numel * world_size,), dtype=dtype, device=device
    )
    all_gather_input = all_gather_output.narrow(
        0, all_gather_input_numel * rank, all_gather_input_numel
    )
    foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    with torch.no_grad():
        torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)
    return all_gather_input, all_gather_output

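# Illustrative sketch (comment only, not part of the module's logic): how the
# `all_gather_copy_in` op above lays out its output. Assuming world_size=2,
# rank=1, and two all-gather inputs with numels [3, 5] (so
# all_gather_input_numel=8), it allocates a flat 16-element output and copies
# the inputs into this rank's slice:
#
#   inputs = [torch.ones(3, device="cuda"), torch.zeros(5, device="cuda")]
#   inp, out = torch.ops.fsdp.all_gather_copy_in(
#       inputs, [3, 5], 8, 2, 1, torch.float32, torch.device("cuda")
#   )
#   # out.numel() == 16; inp is the view out.narrow(0, 8, 8), now holding the
#   # concatenated inputs, so the all-gather can read directly from it.
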
lib.define(
    "split_with_sizes_copy(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()"
)


@torch.library.impl(lib, "split_with_sizes_copy", "Meta")
@torch.library.impl(lib, "split_with_sizes_copy", "CUDA")
@torch.library.impl(lib, "split_with_sizes_copy", "CPU")
def split_with_sizes_copy(
    all_gather_output: torch.Tensor,
    all_gather_input_split_sizes: List[int],
    dim: int,
    out: List[torch.Tensor],
) -> None:
    torch.split_with_sizes_copy(
        all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
    )


lib.define(
    "chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()"
)


@torch.library.impl(lib, "chunk_cat", "Meta")
@torch.library.impl(lib, "chunk_cat", "CUDA")
@torch.library.impl(lib, "chunk_cat", "CPU")
def chunk_cat(
    tensors: List[torch.Tensor],
    dim: int,
    num_chunks: int,
    out: torch.Tensor,
) -> None:
    torch._chunk_cat(tensors, dim, num_chunks, out=out)


@torch.no_grad()
def foreach_all_gather(
    fsdp_params: List[FSDPParam],
    group: dist.ProcessGroup,
    async_op: bool,
    all_gather_copy_in_stream: torch.cuda.Stream,
    all_gather_stream: torch.cuda.Stream,
    device: torch.device,
) -> Optional[AllGatherResult]:
    world_size, rank = group.size(), group.rank()
    with torch.cuda.stream(all_gather_copy_in_stream):
        param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params)
        (
            param_all_gather_input_dtypes,
            param_all_gather_input_numels,
            dtype,
        ) = _get_all_gather_input_metadatas(param_all_gather_inputs)
        if dtype == torch.uint8:
            all_gather_inputs = [
                t.view(torch.uint8) for ts in param_all_gather_inputs for t in ts
            ]
        else:
            all_gather_inputs = [t for ts in param_all_gather_inputs for t in ts]
        inp_split_sizes = [t.numel() for t in all_gather_inputs]
        all_gather_input_numel = sum(inp_split_sizes)
        all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(
            all_gather_inputs,
            inp_split_sizes,
            all_gather_input_numel,
            world_size,
            rank,
            dtype,
            device,
        )
        del param_all_gather_inputs
    all_gather_stream.wait_stream(all_gather_copy_in_stream)
    with torch.cuda.stream(all_gather_stream):
        all_gather_work = dist.all_gather_into_tensor(
            output_tensor=all_gather_output,
            input_tensor=all_gather_input,
            group=group,
            async_op=async_op,
        )
        all_gather_event = all_gather_stream.record_event()
        return AllGatherResult(
            all_gather_output,
            all_gather_event,
            all_gather_work,
            param_all_gather_input_dtypes,
            param_all_gather_input_numels,
            inp_split_sizes,
        )

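# Usage sketch (hypothetical caller, for illustration only; the real callers
# live in the FSDP param-group code): `foreach_all_gather` stages the copy-in
# and issues the all-gather on the side streams, and the returned
# `AllGatherResult` is later consumed in whichever stream will use the
# unsharded parameters:
#
#   result = foreach_all_gather(
#       fsdp_params, group, async_op=False,
#       all_gather_copy_in_stream=copy_in_stream,
#       all_gather_stream=ag_stream, device=device,
#   )
#   # ... later, in the consumer stream:
#   foreach_all_gather_copy_out(result, fsdp_params, group)
#   # copy-out waits on `result.all_gather_event` (sync op) or
#   # `result.all_gather_work` (async op) before splitting the output.
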
@torch.no_grad()
def _get_param_all_gather_inputs(
    fsdp_params: List[FSDPParam],
) -> List[List[torch.Tensor]]:
    if ca.compiled_autograd_enabled:
        return [fsdp_param.all_gather_inputs for fsdp_param in fsdp_params]

    # Intentionally try to run a fast-path that bypasses abstractions for the
    # common FSDP case of bf16/fp32 mixed precision in order to use foreach
    # copy for lower CPU overhead and more efficient copying in eager
    def use_foreach_copy(fsdp_param: FSDPParam) -> bool:
        return (
            fsdp_param.param_dtype is not None
            and not fsdp_param.offload_to_cpu
            and not hasattr(fsdp_param._sharded_local_tensor, "fsdp_pre_all_gather")
        )

    param_all_gather_inputs: List[List[torch.Tensor]] = [[] for _ in fsdp_params]
    foreach_copy_indices: List[int] = []
    foreach_copy_inputs: List[torch.Tensor] = []
    foreach_copy_input_numels: List[int] = []

    # 1st pass: for foreach-copy parameters, get inputs and metadata for the
    # foreach copy, and for the others, actually get their all-gather inputs
    for i, fsdp_param in enumerate(fsdp_params):
        if use_foreach_copy(fsdp_param):
            foreach_copy_indices.append(i)
            all_gather_input = (
                fsdp_param._sharded_param_data
                if fsdp_param.sharded_state == ShardedState.SHARDED
                else cast(torch.Tensor, fsdp_param._sharded_post_forward_param_data)
            )
            foreach_copy_inputs.append(all_gather_input)
            foreach_copy_input_numels.append(all_gather_input.numel())
        else:
            param_all_gather_inputs[i] = fsdp_param.all_gather_inputs

    # 2nd pass: use foreach copy to compute the remaining all-gather inputs
    if foreach_copy_inputs:
        fsdp_param_0 = fsdp_params[foreach_copy_indices[0]]
        param_dtype, device = fsdp_param_0.param_dtype, fsdp_param_0.device
        flat_foreach_copy_input = torch.empty(
            (sum(foreach_copy_input_numels),), device=device, dtype=param_dtype
        )
        splits = torch.split(flat_foreach_copy_input, foreach_copy_input_numels)
        torch._foreach_copy_(splits, foreach_copy_inputs)
        for i, split in zip(foreach_copy_indices, splits):
            param_all_gather_inputs[i] = [split]

    return param_all_gather_inputs


@torch.no_grad()
def foreach_all_gather_copy_out(
    all_gather_result: AllGatherResult,
    fsdp_params: List[FSDPParam],
    group: dist.ProcessGroup,
) -> None:
    (
        all_gather_output,
        all_gather_event,
        all_gather_work,
        param_all_gather_input_dtypes,
        param_all_gather_input_numels,
        all_gather_input_split_sizes,
    ) = all_gather_result
    if all_gather_event is not None:  # sync op
        torch.cuda.current_stream().wait_event(all_gather_event)
    if isinstance(all_gather_work, dist.distributed_c10d.Work):  # async op
        all_gather_work.wait()
    world_size, device = group.size(), all_gather_output.device
    for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip(
        param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params
    ):
        if ca.compiled_autograd_enabled:
            fsdp_param.init_all_gather_outputs(
                all_gather_input_numels,
                all_gather_input_dtypes,
                world_size,
                device,
                # NOTE: Under compile, make sure we always recreate all_gather_outputs
                # per AllGather. See [Note: Invariants for torch.compile Traceable FSDP2].
                force_recreate=True,
            )
        else:
            fsdp_param.init_all_gather_outputs(
                all_gather_input_numels, all_gather_input_dtypes, world_size, device
            )  # no-op after 1st call
            fsdp_param.alloc_all_gather_outputs()
    all_gather_output = all_gather_output.view(world_size, -1)
    gen = (t for fsdp_param in fsdp_params for t in fsdp_param.all_gather_outputs)
    if all_gather_output.dtype == torch.uint8:
        out = [t.view(world_size, -1).view(torch.uint8) for t in gen]
    else:
        out = [t.view(world_size, -1) for t in gen]
    torch.ops.fsdp.split_with_sizes_copy(
        all_gather_output, all_gather_input_split_sizes, dim=1, out=out
    )

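# Layout sketch (illustration only): continuing the world_size=2 example with
# per-parameter all-gather input numels [3, 5], `all_gather_output` holds
# [rank0: 3 | 5][rank1: 3 | 5]. Viewing it as (world_size, 8) and splitting
# along dim=1 with sizes [3, 5] copies each parameter's rows into its own
# (world_size, numel) unsharded output, which is why every tensor in `out`
# above is viewed as (world_size, -1) before the split-with-sizes copy.
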
@torch.no_grad()
def foreach_reduce(
    fsdp_params: List[FSDPParam],
    unsharded_grads: List[torch.Tensor],
    reduce_scatter_group: dist.ProcessGroup,
    reduce_scatter_stream: torch.cuda.Stream,
    orig_dtype: torch.dtype,
    reduce_dtype: Optional[torch.dtype],
    device: torch.device,
    reduce_scatter_reduce_op: Optional[Union[dist.ReduceOp, dist.ReduceOp.RedOpType]],
    all_reduce_group: Optional[dist.ProcessGroup],  # not `None` iff HSDP
    all_reduce_stream: torch.cuda.Stream,
    all_reduce_grads: bool,
    partial_reduce_output: Optional[torch.Tensor],  # only used for HSDP
) -> Tuple[torch.Tensor, torch.cuda.Event, torch.cuda.Event, Optional[torch.Tensor]]:
    """
    ``unsharded_grads`` owns the references to the gradients computed by
    autograd, so clearing the list frees the gradients.
    """
    grad_dtypes = {grad.dtype for grad in unsharded_grads}
    if len(grad_dtypes) != 1:
        # Check this at runtime since it could be a real runtime error if e.g.
        # fp8 weights do not produce the correct higher precision gradients
        _raise_assert_with_print(
            f"FSDP reduce-scatter expects uniform gradient dtype but got {grad_dtypes}"
        )
    grad_dtype = unsharded_grads[0].dtype
    reduce_dtype = reduce_dtype or grad_dtype
    predivide_factor, postdivide_factor = _get_gradient_divide_factors(
        reduce_scatter_group, all_reduce_group, reduce_dtype
    )
    world_size = reduce_scatter_group.size()
    padded_unsharded_sizes = tuple(
        _get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads
    )
    reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes)
    reduce_scatter_output_numel = reduce_scatter_input_numel // world_size
    reduce_scatter_input = torch.empty(
        (reduce_scatter_input_numel,), dtype=reduce_dtype, device=device
    )
    foreach_reduce_scatter_copy_in(unsharded_grads, reduce_scatter_input, world_size)
    current_stream = torch.cuda.current_stream()
    # Only after the copy-in finishes can we free the gradients
    unsharded_grads.clear()
    reduce_scatter_stream.wait_stream(current_stream)
    with torch.cuda.stream(reduce_scatter_stream):
        reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,))
        _div_if_needed(reduce_scatter_input, predivide_factor)
        if reduce_scatter_reduce_op is None:
            if predivide_factor is None:
                reduce_scatter_reduce_op = ReduceOp.AVG
            else:
                reduce_scatter_reduce_op = ReduceOp.SUM
        dist.reduce_scatter_tensor(
            output=reduce_output,
            input=reduce_scatter_input,
            group=reduce_scatter_group,
            op=reduce_scatter_reduce_op,
        )
        reduce_scatter_event = reduce_scatter_stream.record_event()
        post_reduce_stream = reduce_scatter_stream
        if all_reduce_group is not None:  # HSDP
            # Accumulations must run in the reduce-scatter stream
            if not all_reduce_grads:
                if partial_reduce_output is not None:
                    partial_reduce_output += reduce_output
                else:
                    partial_reduce_output = reduce_output
                return (
                    reduce_scatter_input,
                    reduce_scatter_event,
                    post_reduce_stream.record_event(),
                    partial_reduce_output,
                )
            if partial_reduce_output is not None:
                reduce_output += partial_reduce_output
            post_reduce_stream = all_reduce_stream
            all_reduce_stream.wait_stream(reduce_scatter_stream)
            with torch.cuda.stream(all_reduce_stream):
                dist.all_reduce(
                    reduce_output,
                    group=all_reduce_group,
                    op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
                )
    with torch.cuda.stream(post_reduce_stream):
        _div_if_needed(reduce_output, postdivide_factor)
        reduce_output = _to_dtype_if_needed(reduce_output, orig_dtype)
        # View out and accumulate sharded gradients
        flat_grad_offset = 0  # [0, reduce_scatter_output_numel - 1]
        for padded_unsharded_size, fsdp_param in zip(
            padded_unsharded_sizes, fsdp_params
        ):
            new_sharded_grad = torch.as_strided(
                reduce_output,
                size=fsdp_param.sharded_size,
                stride=fsdp_param.contiguous_sharded_stride,
                storage_offset=flat_grad_offset,
            )
            to_accumulate_grad = fsdp_param.sharded_param.grad is not None
            if fsdp_param.offload_to_cpu:
                # Only overlap the D2H copy (copying to pinned memory) if not
                # accumulating gradients since the CPU add kernel depends on
                # the copy result and we cannot run the add as a callback
                non_blocking = fsdp_param.pin_memory and not to_accumulate_grad
                # Since the GPU sharded gradient is allocated in the RS stream,
                # we can free it here by not keeping a ref without waiting for
                # the D2H copy since future RS-stream ops run after the copy
                new_sharded_grad = new_sharded_grad.to(
                    torch.device("cpu"), non_blocking=non_blocking
                )
                if non_blocking:
                    # Record an event on which to block the CPU thread to
                    # ensure that the D2H copy finishes before the optimizer
                    fsdp_param.grad_offload_event = reduce_scatter_stream.record_event()
            if to_accumulate_grad:
                assert isinstance(fsdp_param.sharded_param.grad, DTensor)
                fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad
            else:
                new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor(
                    new_sharded_grad
                )
                fsdp_param.sharded_param.grad = new_sharded_dtensor_grad
            if not ca.compiled_autograd_enabled:
                for hook in (
                    getattr(fsdp_param.sharded_param, "_post_accumulate_grad_hooks", {})
                    or {}
                ).values():
                    hook(fsdp_param.sharded_param)
            padded_sharded_numel = padded_unsharded_size.numel() // world_size
            flat_grad_offset += padded_sharded_numel
    post_reduce_event = post_reduce_stream.record_event()
    # The RS output is allocated in the RS stream and used in the default
    # stream (for optimizer). To ensure its memory is not reused for later
    # RSs, we do not need extra synchronization since the sharded parameters
    # hold refs through the end of backward.
    return reduce_scatter_input, reduce_scatter_event, post_reduce_event, None

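# HSDP note (illustration, describing the contract visible in `foreach_reduce`
# above rather than any specific caller): when `all_reduce_grads=False`, the
# function skips the all-reduce and returns the running `partial_reduce_output`
# accumulated in the reduce-scatter stream; a caller deferring the all-reduce
# across microbatches is expected to pass that tensor back as
# `partial_reduce_output` on the next call and eventually call with
# `all_reduce_grads=True` so the reduction across the replica group happens.
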
def foreach_reduce_scatter_copy_in(
    unsharded_grads: List[torch.Tensor],
    reduce_scatter_input: torch.Tensor,
    world_size: int,
) -> None:
    reduce_scatter_input = reduce_scatter_input.view(world_size, -1)
    torch.ops.fsdp.chunk_cat(
        unsharded_grads, dim=0, num_chunks=world_size, out=reduce_scatter_input
    )

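# Layout sketch (illustration only): with world_size=2 and unsharded gradients
# of dim-0-padded sizes (4, k) and (2, k), `torch._chunk_cat` splits each
# gradient into 2 chunks along dim 0 and concatenates per rank, so row 0 of the
# (world_size, -1) view holds [grad0 rows 0-1 | grad1 row 0] and row 1 holds
# [grad0 rows 2-3 | grad1 row 1]. Reduce-scattering this buffer then leaves
# each rank exactly its own contiguous shard of every gradient.
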
def _get_all_gather_input_metadatas(
    param_all_gather_inputs: List[List[torch.Tensor]],
) -> Tuple[List[List[torch.dtype]], List[List[int]], torch.dtype]:
    param_all_gather_input_dtypes: List[List[torch.dtype]] = []
    param_all_gather_input_numels: List[List[int]] = []
    all_gather_dtype = param_all_gather_inputs[0][0].dtype
    for all_gather_inputs in param_all_gather_inputs:
        input_dtypes: List[torch.dtype] = []
        input_numels: List[int] = []
        for all_gather_input in all_gather_inputs:
            if all_gather_input.dtype != all_gather_dtype:
                all_gather_dtype = torch.uint8
            input_dtypes.append(all_gather_input.dtype)
            input_numels.append(all_gather_input.numel())
        param_all_gather_input_dtypes.append(input_dtypes)
        param_all_gather_input_numels.append(input_numels)
    return (
        param_all_gather_input_dtypes,
        param_all_gather_input_numels,
        all_gather_dtype,
    )


def _get_gradient_divide_factors(
    reduce_scatter_group: dist.ProcessGroup,
    all_reduce_group: Optional[dist.ProcessGroup],
    reduce_dtype: torch.dtype,
) -> Union[Tuple[None, None], Tuple[float, float]]:
    # For fp32/bf16, we do not need to worry about overflow/underflow, so we
    # use NCCL's built-in division to avoid separate div kernels
    if reduce_dtype in (torch.float32, torch.bfloat16):
        return None, None
    data_parallel_size = reduce_scatter_group.size()
    if all_reduce_group is not None:
        data_parallel_size *= all_reduce_group.size()
    # Since fp16 has a smaller dynamic range than fp32/bf16, we want to avoid
    # overflow/underflow. For N data parallel workers, each worker computes
    # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
    # overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
    factor: int = 1
    while data_parallel_size % factor == 0 and data_parallel_size / factor > factor:
        factor *= 2
    factor = float(factor)
    return (factor, data_parallel_size / factor)


def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None:
    if div_factor is not None and div_factor > 1:
        tensor.div_(div_factor)

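# Worked example (illustration only): for fp16 gradients with N = 8 data
# parallel workers, `_get_gradient_divide_factors` returns (4.0, 2.0), so
# `_div_if_needed` divides the reduce-scatter input by 4 before the collective
# and the output by 2 afterward. The two divisions compose to the 1/N average
# while keeping intermediate values closer to their original scale, reducing
# fp16 overflow/underflow risk.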