# mypy: allow-untyped-decorators
from typing import cast, List, NamedTuple, Optional, Tuple, Union

import torch
import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.tensor import DTensor

from ._fsdp_common import (
    _get_dim0_padded_size,
    _raise_assert_with_print,
    _to_dtype_if_needed,
)
from ._fsdp_param import FSDPParam, ShardedState


class AllGatherResult(NamedTuple):
    all_gather_output: torch.Tensor
    all_gather_event: Optional[torch.cuda.Event]
    all_gather_work: Optional[dist.distributed_c10d.Work]
    # For each parameter, the all-gather input dtype for each input
    param_all_gather_input_dtypes: List[List[torch.dtype]]
    # For each parameter, the all-gather input numel for each input
    param_all_gather_input_numels: List[List[int]]
    # 1D flattened version of `param_all_gather_input_numels` saved to avoid
    # CPU overhead from recomputing
    all_gather_input_split_sizes: List[int]


lib = torch.library.Library("fsdp", "FRAGMENT")  # noqa: TOR901

lib.define(
    """
    all_gather_copy_in(
        Tensor[] all_gather_inputs,
        SymInt[] inp_split_sizes,
        SymInt all_gather_input_numel,
        SymInt world_size,
        SymInt rank,
        ScalarType dtype,
        Device device
    ) -> (Tensor, Tensor)
    """
)


@torch.library.impl(lib, "all_gather_copy_in", "Meta")
def all_gather_copy_in_meta(
    all_gather_inputs: List[torch.Tensor],
    inp_split_sizes: List[int],
    all_gather_input_numel: int,
    world_size: int,
    rank: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
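    """Meta registration for ``fsdp::all_gather_copy_in``: allocates the flat
    all-gather output of numel ``all_gather_input_numel * world_size`` on the
    meta device and returns this rank's input view into it (no data copy)."""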
    all_gather_output = torch.empty(
        (all_gather_input_numel * world_size,), dtype=dtype, device="meta"
    )
    all_gather_input = all_gather_output.narrow(
        0, all_gather_input_numel * rank, all_gather_input_numel
    )
    return all_gather_input, all_gather_output


@torch.library.impl(lib, "all_gather_copy_in", "CUDA")
@torch.library.impl(lib, "all_gather_copy_in", "CPU")
def all_gather_copy_in_cuda(
    all_gather_inputs: List[torch.Tensor],
    inp_split_sizes: List[int],
    all_gather_input_numel: int,
    world_size: int,
    rank: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
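    """CUDA/CPU registration for ``fsdp::all_gather_copy_in``: allocates the
    flat all-gather output, narrows out this rank's slice, and foreach-copies
    the per-parameter all-gather inputs into that slice."""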
    all_gather_output = torch.empty(
        (all_gather_input_numel * world_size,), dtype=dtype, device=device
    )
    all_gather_input = all_gather_output.narrow(
        0, all_gather_input_numel * rank, all_gather_input_numel
    )
    foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    with torch.no_grad():
        torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)
    return all_gather_input, all_gather_output


lib.define(
    "split_with_sizes_copy(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()"
)


@torch.library.impl(lib, "split_with_sizes_copy", "Meta")
@torch.library.impl(lib, "split_with_sizes_copy", "CUDA")
@torch.library.impl(lib, "split_with_sizes_copy", "CPU")
def split_with_sizes_copy(
    all_gather_output: torch.Tensor,
    all_gather_input_split_sizes: List[int],
    dim: int,
    out: List[torch.Tensor],
) -> None:
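    """Thin wrapper around ``torch.split_with_sizes_copy`` that copies splits
    of ``all_gather_output`` along ``dim`` into the preallocated ``out``
    tensors."""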
    torch.split_with_sizes_copy(
        all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
    )


lib.define(
    "chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()"
)


@torch.library.impl(lib, "chunk_cat", "Meta")
@torch.library.impl(lib, "chunk_cat", "CUDA")
@torch.library.impl(lib, "chunk_cat", "CPU")
def chunk_cat(
    tensors: List[torch.Tensor],
    dim: int,
    num_chunks: int,
    out: torch.Tensor,
) -> None:
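    """Thin wrapper around ``torch._chunk_cat`` that chunks each tensor into
    ``num_chunks`` along ``dim`` and concatenates the chunks into the
    preallocated ``out``."""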
    torch._chunk_cat(tensors, dim, num_chunks, out=out)


@torch.no_grad()
def foreach_all_gather(
    fsdp_params: List[FSDPParam],
    group: dist.ProcessGroup,
    async_op: bool,
    all_gather_copy_in_stream: torch.cuda.Stream,
    all_gather_stream: torch.cuda.Stream,
    device: torch.device,
) -> Optional[AllGatherResult]:
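    """Issues a single flat all-gather for the given FSDP parameters.

    The copy-in of the per-parameter all-gather inputs runs in
    ``all_gather_copy_in_stream`` and the collective runs in
    ``all_gather_stream``. Returns an ``AllGatherResult`` carrying the output
    buffer, the recorded event / work handle, and the per-parameter metadata
    consumed by ``foreach_all_gather_copy_out``.
    """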
    world_size, rank = group.size(), group.rank()
    with torch.cuda.stream(all_gather_copy_in_stream):
        param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params)
        (
            param_all_gather_input_dtypes,
            param_all_gather_input_numels,
            dtype,
        ) = _get_all_gather_input_metadatas(param_all_gather_inputs)
        if dtype == torch.uint8:
            all_gather_inputs = [
                t.view(torch.uint8) for ts in param_all_gather_inputs for t in ts
            ]
        else:
            all_gather_inputs = [t for ts in param_all_gather_inputs for t in ts]
        inp_split_sizes = [t.numel() for t in all_gather_inputs]
        all_gather_input_numel = sum(inp_split_sizes)
        all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(
            all_gather_inputs,
            inp_split_sizes,
            all_gather_input_numel,
            world_size,
            rank,
            dtype,
            device,
        )
        del param_all_gather_inputs
    all_gather_stream.wait_stream(all_gather_copy_in_stream)
    with torch.cuda.stream(all_gather_stream):
        all_gather_work = dist.all_gather_into_tensor(
            output_tensor=all_gather_output,
            input_tensor=all_gather_input,
            group=group,
            async_op=async_op,
        )
        all_gather_event = all_gather_stream.record_event()
        return AllGatherResult(
            all_gather_output,
            all_gather_event,
            all_gather_work,
            param_all_gather_input_dtypes,
            param_all_gather_input_numels,
            inp_split_sizes,
        )


@torch.no_grad()
def _get_param_all_gather_inputs(
    fsdp_params: List[FSDPParam],
) -> List[List[torch.Tensor]]:
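    """Returns the all-gather input tensors for each FSDP parameter, batching
    the plain sharded-data parameters into one foreach copy on the fast path
    below and falling back to ``fsdp_param.all_gather_inputs`` otherwise."""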
    if ca.compiled_autograd_enabled:
        return [fsdp_param.all_gather_inputs for fsdp_param in fsdp_params]

    # Intentionally run a fast path that bypasses abstractions for the common
    # FSDP case of bf16/fp32 mixed precision so that we can use a foreach copy
    # for lower CPU overhead and more efficient copying in eager mode
    def use_foreach_copy(fsdp_param: FSDPParam) -> bool:
        return (
            fsdp_param.param_dtype is not None
            and not fsdp_param.offload_to_cpu
            and not hasattr(fsdp_param._sharded_local_tensor, "fsdp_pre_all_gather")
        )

    param_all_gather_inputs: List[List[torch.Tensor]] = [[] for _ in fsdp_params]
    foreach_copy_indices: List[int] = []
    foreach_copy_inputs: List[torch.Tensor] = []
    foreach_copy_input_numels: List[int] = []

    # 1st pass: for foreach-copy parameters, get inputs and metadata for the
    # foreach copy, and for the others, actually get their all-gather inputs
    for i, fsdp_param in enumerate(fsdp_params):
        if use_foreach_copy(fsdp_param):
            foreach_copy_indices.append(i)
            all_gather_input = (
                fsdp_param._sharded_param_data
                if fsdp_param.sharded_state == ShardedState.SHARDED
                else cast(torch.Tensor, fsdp_param._sharded_post_forward_param_data)
            )
            foreach_copy_inputs.append(all_gather_input)
            foreach_copy_input_numels.append(all_gather_input.numel())
        else:
            param_all_gather_inputs[i] = fsdp_param.all_gather_inputs

    # 2nd pass: use foreach copy to compute the remaining all-gather inputs
    if foreach_copy_inputs:
        fsdp_param_0 = fsdp_params[foreach_copy_indices[0]]
        param_dtype, device = fsdp_param_0.param_dtype, fsdp_param_0.device
        flat_foreach_copy_input = torch.empty(
            (sum(foreach_copy_input_numels),), device=device, dtype=param_dtype
        )
        splits = torch.split(flat_foreach_copy_input, foreach_copy_input_numels)
        torch._foreach_copy_(splits, foreach_copy_inputs)
        for i, split in zip(foreach_copy_indices, splits):
            param_all_gather_inputs[i] = [split]

    return param_all_gather_inputs


@torch.no_grad()
def foreach_all_gather_copy_out(
    all_gather_result: AllGatherResult,
    fsdp_params: List[FSDPParam],
    group: dist.ProcessGroup,
) -> None:
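    """Copies the flat all-gather output out into each FSDP parameter's
    unsharded ``all_gather_outputs``, first waiting on the all-gather
    event / work and (re)initializing or allocating the output tensors as
    needed."""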
    (
        all_gather_output,
        all_gather_event,
        all_gather_work,
        param_all_gather_input_dtypes,
        param_all_gather_input_numels,
        all_gather_input_split_sizes,
    ) = all_gather_result
    if all_gather_event is not None:  # sync op
        torch.cuda.current_stream().wait_event(all_gather_event)
    if isinstance(all_gather_work, dist.distributed_c10d.Work):  # async op
        all_gather_work.wait()
    world_size, device = group.size(), all_gather_output.device
    for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip(
        param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params
    ):
        if ca.compiled_autograd_enabled:
            fsdp_param.init_all_gather_outputs(
                all_gather_input_numels,
                all_gather_input_dtypes,
                world_size,
                device,
                # NOTE: Under compile, make sure we always recreate all_gather_outputs
                # per AllGather. See [Note: Invariants for torch.compile Traceable FSDP2].
                force_recreate=True,
            )
        else:
            fsdp_param.init_all_gather_outputs(
                all_gather_input_numels, all_gather_input_dtypes, world_size, device
            )  # no-op after 1st call
            fsdp_param.alloc_all_gather_outputs()
    all_gather_output = all_gather_output.view(world_size, -1)
    gen = (t for fsdp_param in fsdp_params for t in fsdp_param.all_gather_outputs)
    if all_gather_output.dtype == torch.uint8:
        out = [t.view(world_size, -1).view(torch.uint8) for t in gen]
    else:
        out = [t.view(world_size, -1) for t in gen]
    torch.ops.fsdp.split_with_sizes_copy(
        all_gather_output, all_gather_input_split_sizes, dim=1, out=out
    )


@torch.no_grad()
def foreach_reduce(
    fsdp_params: List[FSDPParam],
    unsharded_grads: List[torch.Tensor],
    reduce_scatter_group: dist.ProcessGroup,
    reduce_scatter_stream: torch.cuda.Stream,
    orig_dtype: torch.dtype,
    reduce_dtype: Optional[torch.dtype],
    device: torch.device,
    reduce_scatter_reduce_op: Optional[Union[dist.ReduceOp, dist.ReduceOp.RedOpType]],
    all_reduce_group: Optional[dist.ProcessGroup],  # not `None` iff HSDP
    all_reduce_stream: torch.cuda.Stream,
    all_reduce_grads: bool,
    partial_reduce_output: Optional[torch.Tensor],  # only used for HSDP
) -> Tuple[torch.Tensor, torch.cuda.Event, torch.cuda.Event, Optional[torch.Tensor]]:
    """
    ``unsharded_grads`` owns the references to the gradients computed by
    autograd, so clearing the list frees the gradients.
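    Returns the reduce-scatter input buffer, the reduce-scatter event, the
    post-reduce event, and, for HSDP when gradients are only being accumulated
    (not yet all-reduced), the partial reduce output; otherwise the last
    element is ``None``.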
    """
    grad_dtypes = {grad.dtype for grad in unsharded_grads}
    if len(grad_dtypes) != 1:
        # Check this at runtime since it could be a real runtime error if e.g.
        # fp8 weights do not produce the correct higher precision gradients
        _raise_assert_with_print(
            f"FSDP reduce-scatter expects uniform gradient dtype but got {grad_dtypes}"
        )
    grad_dtype = unsharded_grads[0].dtype
    reduce_dtype = reduce_dtype or grad_dtype
    predivide_factor, postdivide_factor = _get_gradient_divide_factors(
        reduce_scatter_group, all_reduce_group, reduce_dtype
    )
    world_size = reduce_scatter_group.size()
    padded_unsharded_sizes = tuple(
        _get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads
    )
    reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes)
    reduce_scatter_output_numel = reduce_scatter_input_numel // world_size
    reduce_scatter_input = torch.empty(
        (reduce_scatter_input_numel,), dtype=reduce_dtype, device=device
    )
    foreach_reduce_scatter_copy_in(unsharded_grads, reduce_scatter_input, world_size)
    current_stream = torch.cuda.current_stream()
    # Only after the copy-in finishes can we free the gradients
    unsharded_grads.clear()
    reduce_scatter_stream.wait_stream(current_stream)
    with torch.cuda.stream(reduce_scatter_stream):
        reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,))
        _div_if_needed(reduce_scatter_input, predivide_factor)
        if reduce_scatter_reduce_op is None:
            if predivide_factor is None:
                reduce_scatter_reduce_op = ReduceOp.AVG
            else:
                reduce_scatter_reduce_op = ReduceOp.SUM
        dist.reduce_scatter_tensor(
            output=reduce_output,
            input=reduce_scatter_input,
            group=reduce_scatter_group,
            op=reduce_scatter_reduce_op,
        )
        reduce_scatter_event = reduce_scatter_stream.record_event()
        post_reduce_stream = reduce_scatter_stream
        if all_reduce_group is not None:  # HSDP
            # Accumulations must run in the reduce-scatter stream
            if not all_reduce_grads:
                if partial_reduce_output is not None:
                    partial_reduce_output += reduce_output
                else:
                    partial_reduce_output = reduce_output
                return (
                    reduce_scatter_input,
                    reduce_scatter_event,
                    post_reduce_stream.record_event(),
                    partial_reduce_output,
                )
            if partial_reduce_output is not None:
                reduce_output += partial_reduce_output
            post_reduce_stream = all_reduce_stream
            all_reduce_stream.wait_stream(reduce_scatter_stream)
            with torch.cuda.stream(all_reduce_stream):
                dist.all_reduce(
                    reduce_output,
                    group=all_reduce_group,
                    op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
                )
    with torch.cuda.stream(post_reduce_stream):
        _div_if_needed(reduce_output, postdivide_factor)
        reduce_output = _to_dtype_if_needed(reduce_output, orig_dtype)
        # View out and accumulate sharded gradients
        flat_grad_offset = 0  # [0, reduce_scatter_output_numel - 1]
        for padded_unsharded_size, fsdp_param in zip(
            padded_unsharded_sizes, fsdp_params
        ):
            new_sharded_grad = torch.as_strided(
                reduce_output,
                size=fsdp_param.sharded_size,
                stride=fsdp_param.contiguous_sharded_stride,
                storage_offset=flat_grad_offset,
            )
            to_accumulate_grad = fsdp_param.sharded_param.grad is not None
            if fsdp_param.offload_to_cpu:
                # Only overlap the D2H copy (copying to pinned memory) if not
                # accumulating gradients since the CPU add kernel depends on
                # the copy result and we cannot run the add as a callback
                non_blocking = fsdp_param.pin_memory and not to_accumulate_grad
                # Since the GPU sharded gradient is allocated in the RS stream,
                # we can free it here by not keeping a ref without waiting for
                # the D2H copy since future RS-stream ops run after the copy
                new_sharded_grad = new_sharded_grad.to(
                    torch.device("cpu"), non_blocking=non_blocking
                )
                if non_blocking:
                    # Record an event on which to block the CPU thread to
                    # ensure that the D2H copy finishes before the optimizer
                    fsdp_param.grad_offload_event = reduce_scatter_stream.record_event()
            if to_accumulate_grad:
                assert isinstance(fsdp_param.sharded_param.grad, DTensor)
                fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad
            else:
                new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor(
                    new_sharded_grad
                )
                fsdp_param.sharded_param.grad = new_sharded_dtensor_grad
            if not ca.compiled_autograd_enabled:
                for hook in (
                    getattr(fsdp_param.sharded_param, "_post_accumulate_grad_hooks", {})
                    or {}
                ).values():
                    hook(fsdp_param.sharded_param)
            padded_sharded_numel = padded_unsharded_size.numel() // world_size
            flat_grad_offset += padded_sharded_numel
        post_reduce_event = post_reduce_stream.record_event()
    # The RS output is allocated in the RS stream and used in the default
    # stream (for the optimizer). No extra synchronization is needed to keep
    # its memory from being reused by later RSs since the sharded parameters
    # hold refs through the end of backward.
    return reduce_scatter_input, reduce_scatter_event, post_reduce_event, None


def foreach_reduce_scatter_copy_in(
    unsharded_grads: List[torch.Tensor],
    reduce_scatter_input: torch.Tensor,
    world_size: int,
) -> None:
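    """Copies the unsharded gradients into ``reduce_scatter_input``, laid out
    as ``world_size`` contiguous chunks, via ``torch.ops.fsdp.chunk_cat``."""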
    reduce_scatter_input = reduce_scatter_input.view(world_size, -1)
    torch.ops.fsdp.chunk_cat(
        unsharded_grads, dim=0, num_chunks=world_size, out=reduce_scatter_input
    )


def _get_all_gather_input_metadatas(
    param_all_gather_inputs: List[List[torch.Tensor]],
) -> Tuple[List[List[torch.dtype]], List[List[int]], torch.dtype]:
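    """Returns the per-parameter all-gather input dtypes and numels plus the
    overall all-gather dtype, which falls back to ``torch.uint8`` whenever the
    inputs have mixed dtypes."""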
    param_all_gather_input_dtypes: List[List[torch.dtype]] = []
    param_all_gather_input_numels: List[List[int]] = []
    all_gather_dtype = param_all_gather_inputs[0][0].dtype
    for all_gather_inputs in param_all_gather_inputs:
        input_dtypes: List[torch.dtype] = []
        input_numels: List[int] = []
        for all_gather_input in all_gather_inputs:
            if all_gather_input.dtype != all_gather_dtype:
                all_gather_dtype = torch.uint8
            input_dtypes.append(all_gather_input.dtype)
            input_numels.append(all_gather_input.numel())
        param_all_gather_input_dtypes.append(input_dtypes)
        param_all_gather_input_numels.append(input_numels)
    return (
        param_all_gather_input_dtypes,
        param_all_gather_input_numels,
        all_gather_dtype,
    )


def _get_gradient_divide_factors(
    reduce_scatter_group: dist.ProcessGroup,
    all_reduce_group: Optional[dist.ProcessGroup],
    reduce_dtype: torch.dtype,
) -> Union[Tuple[None, None], Tuple[float, float]]:
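    """Returns the pre- and post-reduction division factors for the gradient
    reduction, or ``(None, None)`` when NCCL's built-in averaging can be used
    (fp32/bf16 reduce dtypes)."""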
    # For fp32/bf16, we do not need to worry about overflow/underflow, so we
    # use NCCL's built-in division to avoid separate div kernels
    if reduce_dtype in (torch.float32, torch.bfloat16):
        return None, None
    data_parallel_size = reduce_scatter_group.size()
    if all_reduce_group is not None:
        data_parallel_size *= all_reduce_group.size()
    # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
    # overflow/underflow. For N data parallel workers, each worker computes
    # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
    # overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
    factor: int = 1
    while data_parallel_size % factor == 0 and data_parallel_size / factor > factor:
        factor *= 2
    factor = float(factor)
    return (factor, data_parallel_size / factor)


def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None:
    if div_factor is not None and div_factor > 1:
        tensor.div_(div_factor)