# mypy: allow-untyped-defs
from typing import Any, Callable, cast, Tuple

import torch
import torch.distributed as dist


__all__ = [
    "allreduce_hook",
    "fp16_compress_hook",
    "bf16_compress_hook",
    "fp16_compress_wrapper",
    "bf16_compress_wrapper",
]


def _allreduce_fut(
    process_group: dist.ProcessGroup, tensor: torch.Tensor
) -> torch.futures.Future[torch.Tensor]:
    """Average the input gradient tensor via allreduce and return a future."""
    group_to_use = process_group if process_group is not None else dist.group.WORLD

    # Apply the division first to avoid overflow, especially for FP16.
    tensor.div_(group_to_use.size())

    return (
        dist.all_reduce(tensor, group=group_to_use, async_op=True)
        .get_future()
        .then(lambda fut: fut.value()[0])
    )


def allreduce_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Call ``allreduce`` using ``GradBucket`` tensors.

    The bucket tensor is pre-divided by the process group size, so the allreduce
    produces the mean of the gradients, which the ``then`` callback returns.

    If a user registers this DDP communication hook,
    DDP results are expected to be the same as the case where no hook was registered.
    Hence, this hook does not change the behavior of DDP, and users can use it as a
    reference or modify it to log useful information or for other purposes without
    affecting DDP behavior.

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
    """
    return _allreduce_fut(process_group, bucket.buffer())


def fp16_compress_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
    """
    Compress gradients by casting the ``GradBucket`` tensor to ``torch.float16`` and dividing by the process group size.

    This DDP communication hook implements a simple gradient compression
    approach that casts the ``GradBucket`` tensor to half-precision floating-point format (``torch.float16``)
    and then divides it by the process group size.
    It allreduces those ``float16`` gradient tensors. Once the compressed gradient
    tensors are allreduced, the chained callback ``decompress`` casts them back to the input data type (such as ``float32``).

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    buffer = (
        cast(Tuple[torch.Tensor, ...], bucket)[0]
        if isinstance(bucket, tuple)
        else bucket.buffer()
    )
    compressed_tensor = buffer.to(torch.float16).div_(world_size)

    def decompress(fut):
        decompressed_tensor = buffer
        # Decompress in place to reduce the peak memory.
        # See: https://github.com/pytorch/pytorch/issues/45968
        value = fut if isinstance(fut, torch.Tensor) else fut.value()[0]
        decompressed_tensor.copy_(value)
        return decompressed_tensor

    if torch._utils.is_compiling():
        grad = dist._functional_collectives.all_reduce(
            compressed_tensor, "sum", group_to_use
        )
        return decompress(grad)
    else:
        fut = dist.all_reduce(
            compressed_tensor, group=group_to_use, async_op=True
        ).get_future()
        return fut.then(decompress)


# TODO: create an internal helper function and extract the duplicate code in FP16_compress and BF16_compress.
def bf16_compress_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
    """
    Warning: This API is experimental, and it requires NCCL version later than 2.9.6.

    This DDP communication hook implements a simple gradient compression
    approach that casts the ``GradBucket`` tensor to half-precision
    `Brain floating point format <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_ (``torch.bfloat16``)
    and then divides it by the process group size.
    It allreduces those ``bfloat16`` gradient tensors. Once the compressed gradient
    tensors are allreduced, the chained callback ``decompress`` casts them back to the input data type (such as ``float32``).

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    buffer = (
        cast(Tuple[torch.Tensor, ...], bucket)[0]
        if isinstance(bucket, tuple)
        else bucket.buffer()
    )
    compressed_tensor = buffer.to(torch.bfloat16).div_(world_size)

    def decompress(fut):
        decompressed_tensor = buffer
        # Decompress in place to reduce the peak memory.
        # See: https://github.com/pytorch/pytorch/issues/45968
        value = fut if isinstance(fut, torch.Tensor) else fut.value()[0]
        decompressed_tensor.copy_(value)
        return decompressed_tensor

    if torch._utils.is_compiling():
        grad = dist._functional_collectives.all_reduce(
            compressed_tensor, "sum", group_to_use
        )
        return decompress(grad)
    else:
        fut = dist.all_reduce(
            compressed_tensor, group=group_to_use, async_op=True
        ).get_future()
        return fut.then(decompress)


def fp16_compress_wrapper(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    """
    Cast the input tensor to ``torch.float16`` and cast the result of the given hook back to the input dtype.

    This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision
    floating point format (``torch.float16``), and casts the resulting tensor of the given hook back to
    the input data type, such as ``float32``.
    Therefore, ``fp16_compress_hook`` is equivalent to ``fp16_compress_wrapper(allreduce_hook)``.

    Example::
        >>> # xdoctest: +SKIP
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
        >>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
    """

    def fp16_compress_wrapper_hook(
        hook_state, bucket: dist.GradBucket
    ) -> torch.futures.Future[torch.Tensor]:
        # Cast bucket tensor to FP16.
        bucket.set_buffer(bucket.buffer().to(torch.float16))

        fut = hook(hook_state, bucket)

        def decompress(fut):
            decompressed_tensor = bucket.buffer()
            # Decompress in place to reduce the peak memory.
            # See: https://github.com/pytorch/pytorch/issues/45968
            decompressed_tensor.copy_(fut.value())
            return decompressed_tensor

        # Decompress after hook has run.
        return fut.then(decompress)

    return fp16_compress_wrapper_hook


def bf16_compress_wrapper(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    """
    Warning: This API is experimental, and it requires NCCL version later than 2.9.6.

    This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision
    `Brain floating point format <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_ (``torch.bfloat16``),
    and casts the resulting tensor of the given hook back to the input data type, such as ``float32``.

    Therefore, ``bf16_compress_hook`` is equivalent to ``bf16_compress_wrapper(allreduce_hook)``.

    Example::
        >>> # xdoctest: +SKIP
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
        >>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))
    """

    def bf16_compress_wrapper_hook(
        hook_state, bucket: dist.GradBucket
    ) -> torch.futures.Future[torch.Tensor]:
        # Cast bucket tensor to BF16.
        bucket.set_buffer(bucket.buffer().to(torch.bfloat16))

        fut = hook(hook_state, bucket)

        def decompress(fut):
            decompressed_tensor = bucket.buffer()
            # Decompress in place to reduce the peak memory.
            # See: https://github.com/pytorch/pytorch/issues/45968
            decompressed_tensor.copy_(fut.value())
            return decompressed_tensor

        # Decompress after hook has run.
        return fut.then(decompress)

    return bf16_compress_wrapper_hook
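

# A minimal usage sketch, assuming a single-process "gloo" setup, a toy model,
# and the hard-coded rendezvous address below; adapt the backend, init method,
# world size, and hook to your environment.
if __name__ == "__main__":
    import os

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    ddp_model = DDP(nn.Linear(8, 4))
    # Any hook in this module can be passed here: allreduce_hook keeps DDP's
    # default numerics, while fp16_compress_hook trades precision for bandwidth.
    ddp_model.register_comm_hook(dist.group.WORLD, allreduce_hook)

    # One forward/backward pass triggers the hook during gradient reduction.
    loss = ddp_model(torch.randn(2, 8)).sum()
    loss.backward()

    dist.destroy_process_group()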