# mypy: allow-untyped-defs
from typing import Any, Callable, cast, Tuple

import torch
import torch.distributed as dist


__all__ = [
    "allreduce_hook",
    "fp16_compress_hook",
    "bf16_compress_hook",
    "fp16_compress_wrapper",
    "bf16_compress_wrapper",
]


def _allreduce_fut(
    process_group: dist.ProcessGroup, tensor: torch.Tensor
) -> torch.futures.Future[torch.Tensor]:
    """Average the input gradient tensor via allreduce and return a future."""
    group_to_use = process_group if process_group is not None else dist.group.WORLD

    # Apply the division first to avoid overflow, especially for FP16.
    tensor.div_(group_to_use.size())

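    # The all-reduce future resolves to a list of tensors; the ``then`` callback
    # unwraps the single reduced tensor so callers receive a Future[Tensor].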
    return (
        dist.all_reduce(tensor, group=group_to_use, async_op=True)
        .get_future()
        .then(lambda fut: fut.value()[0])
    )


def allreduce_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Call ``allreduce`` using ``GradBucket`` tensors.

    The bucket tensor is divided by the process group size before the allreduce,
    so the result returned by the ``then`` callback is the mean of the gradient
    tensors across all workers.

    If a user registers this DDP communication hook, the DDP results are expected
    to be the same as if no hook had been registered. Hence, this hook does not
    change the behavior of DDP, and a user can use it as a reference or modify it
    to log useful information or for any other purpose without affecting DDP
    behavior.

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
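
    A minimal sketch of a logging variant that keeps the allreduce behavior
    (the hook name and the logged fields here are illustrative, not part of
    this module)::

        >>> # xdoctest: +SKIP
        >>> def logging_allreduce_hook(process_group, bucket):
        ...     print(f"bucket {bucket.index()}: {bucket.buffer().numel()} elements")
        ...     return allreduce_hook(process_group, bucket)
        >>> ddp_model.register_comm_hook(process_group, logging_allreduce_hook)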
    """
    return _allreduce_fut(process_group, bucket.buffer())


def fp16_compress_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
    """
    Compress gradients by casting the ``GradBucket`` tensor to ``torch.float16`` and dividing by the process group size.

    This DDP communication hook implements a simple gradient compression
    approach that casts the ``GradBucket`` tensor to half-precision floating-point
    format (``torch.float16``) and then divides it by the process group size.
    It allreduces those ``float16`` gradient tensors. Once the compressed gradient
    tensors are allreduced, the chained callback ``decompress`` casts them back to
    the input data type (such as ``float32``).

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

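    # ``bucket`` may arrive as a raw tuple of tensors (e.g. when the hook is
    # traced/compiled) rather than a ``GradBucket``; use the first tensor then.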
    buffer = (
        cast(Tuple[torch.Tensor, ...], bucket)[0]
        if isinstance(bucket, tuple)
        else bucket.buffer()
    )
    compressed_tensor = buffer.to(torch.float16).div_(world_size)

    def decompress(fut):
        decompressed_tensor = buffer
        # Decompress in place to reduce the peak memory.
        # See: https://github.com/pytorch/pytorch/issues/45968
        value = fut if isinstance(fut, torch.Tensor) else fut.value()[0]
        decompressed_tensor.copy_(value)
        return decompressed_tensor

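    # Under compilation, use the traceable functional collective, which returns
    # a tensor directly (no future), so ``decompress`` is applied eagerly.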
    if torch._utils.is_compiling():
        grad = dist._functional_collectives.all_reduce(
            compressed_tensor, "sum", group_to_use
        )
        return decompress(grad)
    else:
        fut = dist.all_reduce(
            compressed_tensor, group=group_to_use, async_op=True
        ).get_future()
        return fut.then(decompress)


# TODO: create an internal helper function and extract the duplicate code in fp16_compress_hook and bf16_compress_hook.
def bf16_compress_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
    """
    Warning: This API is experimental, and it requires NCCL version later than 2.9.6.

    This DDP communication hook implements a simple gradient compression
    approach that casts the ``GradBucket`` tensor to half-precision
    `Brain floating point format <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_ (``torch.bfloat16``)
    and then divides it by the process group size.
    It allreduces those ``bfloat16`` gradient tensors. Once the compressed gradient
    tensors are allreduced, the chained callback ``decompress`` casts them back to
    the input data type (such as ``float32``).

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    buffer = (
        cast(Tuple[torch.Tensor, ...], bucket)[0]
        if isinstance(bucket, tuple)
        else bucket.buffer()
    )
    compressed_tensor = buffer.to(torch.bfloat16).div_(world_size)

    def decompress(fut):
        decompressed_tensor = buffer
        # Decompress in place to reduce the peak memory.
        # See: https://github.com/pytorch/pytorch/issues/45968
        value = fut if isinstance(fut, torch.Tensor) else fut.value()[0]
        decompressed_tensor.copy_(value)
        return decompressed_tensor

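    # Same eager vs. compiled dispatch as in ``fp16_compress_hook`` above.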
    if torch._utils.is_compiling():
        grad = dist._functional_collectives.all_reduce(
            compressed_tensor, "sum", group_to_use
        )
        return decompress(grad)
    else:
        fut = dist.all_reduce(
            compressed_tensor, group=group_to_use, async_op=True
        ).get_future()
        return fut.then(decompress)


def fp16_compress_wrapper(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    """
    Cast the input tensor to ``torch.float16`` and cast the result of the given hook back to the input dtype.

    This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision
    floating-point format (``torch.float16``), and casts the resulting tensor of the given hook back to
    the input data type, such as ``float32``.
    Therefore, ``fp16_compress_hook`` is equivalent to ``fp16_compress_wrapper(allreduce_hook)``.

    Example::
        >>> # xdoctest: +SKIP
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
        >>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
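
    Since ``fp16_compress_hook`` is equivalent to ``fp16_compress_wrapper(allreduce_hook)``,
    either of the following registrations (register only one per model) gives the same behavior::

        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, fp16_compress_wrapper(allreduce_hook))
        >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)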
    """

    def fp16_compress_wrapper_hook(
        hook_state, bucket: dist.GradBucket
    ) -> torch.futures.Future[torch.Tensor]:
        # Cast bucket tensor to FP16.
        bucket.set_buffer(bucket.buffer().to(torch.float16))

        fut = hook(hook_state, bucket)

        def decompress(fut):
            decompressed_tensor = bucket.buffer()
            # Decompress in place to reduce the peak memory.
            # See: https://github.com/pytorch/pytorch/issues/45968
            decompressed_tensor.copy_(fut.value())
            return decompressed_tensor

        # Decompress after hook has run.
        return fut.then(decompress)

    return fp16_compress_wrapper_hook


def bf16_compress_wrapper(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    """
    Warning: This API is experimental, and it requires NCCL version later than 2.9.6.

    This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision
    `Brain floating point format <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_ (``torch.bfloat16``),
    and casts the resulting tensor of the given hook back to the input data type, such as ``float32``.

    Therefore, ``bf16_compress_hook`` is equivalent to ``bf16_compress_wrapper(allreduce_hook)``.

    Example::
        >>> # xdoctest: +SKIP
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
        >>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))
    """

    def bf16_compress_wrapper_hook(
        hook_state, bucket: dist.GradBucket
    ) -> torch.futures.Future[torch.Tensor]:
        # Cast bucket tensor to BF16.
        bucket.set_buffer(bucket.buffer().to(torch.bfloat16))

        fut = hook(hook_state, bucket)

        def decompress(fut):
            decompressed_tensor = bucket.buffer()
            # Decompress in place to reduce the peak memory.
            # See: https://github.com/pytorch/pytorch/issues/45968
            decompressed_tensor.copy_(fut.value())
            return decompressed_tensor

        # Decompress after hook has run.
        return fut.then(decompress)

    return bf16_compress_wrapper_hook