# mypy: allow-untyped-defs
import functools
from typing import Optional

import torch
import torch.distributed as dist


class DefaultState:
    r"""
    Stores state needed to perform the default communication algorithm within a communication hook.

    Args:
        process_group (ProcessGroup): The process group to be used.
    """

    __slots__ = [
        "process_group",
        "world_size",
        "gradient_predivide_factor",
        "gradient_postdivide_factor",
    ]

    def __init__(self, process_group: dist.ProcessGroup):
        if process_group is None:
            raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.")
        self.process_group = process_group
        self.world_size = dist.get_world_size(process_group)
        # Setting two factors `self.gradient_predivide_factor`
        # and `self.gradient_postdivide_factor` to avoid underflow and overflow
        self.gradient_predivide_factor = self._get_gradient_predivide_factor(
            self.world_size
        )
        self.gradient_postdivide_factor = (
            self.world_size / self.gradient_predivide_factor
        )

    @staticmethod
    def _get_gradient_predivide_factor(world_size: int) -> float:
        factor: int = 1
        while world_size % factor == 0 and world_size / factor > factor:
            factor *= 2
        return float(factor)


class LowPrecisionState(DefaultState):
    r"""
    Stores state needed to perform gradient communication in a lower precision within a communication hook.

    The communication hook will cast gradients back to the original
    parameter precision specified by ``parameter_type`` (default: torch.float32).
    Builds on top of the :class:`DefaultState`.

    Args:
        parameter_type (torch.dtype): The precision of the model's parameters.
            Required for a hook to cast gradients back to a parameter's precision.
    """

    __slots__ = [
        "parameter_type",
    ]

    def __init__(
        self,
        process_group,
        parameter_type=torch.float32,
    ):
        super().__init__(process_group)
        self.parameter_type = parameter_type


def _decompress(state: LowPrecisionState, grad: torch.Tensor):
    """
    Casts gradients back to full parameter precision so that further computation happens in full precision.
    """
    orig_grad_data = grad.data
    grad.data = grad.data.to(state.parameter_type)
    device_type = ""
    try:
        if grad.device.type == "privateuse1":
            device_type = torch._C._get_privateuse1_backend_name()
        else:
            device_type = grad.device.type
        backend = getattr(torch, device_type)
    except AttributeError as e:
        raise AttributeError(
            f"Device {grad.device} does not have a corresponding backend "
            f"registered as 'torch.{device_type}'."
        ) from e

    # Don't let this memory get reused until after the transfer.
    orig_grad_data.record_stream(backend.current_stream())  # type: ignore[arg-type]


def allreduce_hook(state: DefaultState, grad: torch.Tensor):
    r"""
    Implement the FSDP communication hook for the ``all_reduce`` algorithm and the necessary pre- and post-division of gradients.

    Args:
        state (DefaultState): State information, configures pre- and post-division factors.
        grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks.
    """
    # Average grad by pre-division factor. Together pre- and post-division factors
    # lead to an overall averaging by world_size, required for consistency with PyTorch DDP.
    # This is a two-step process to avoid potential underflow and overflow.
    if state.gradient_predivide_factor > 1:
        grad.div_(state.gradient_predivide_factor)
    dist.all_reduce(grad, group=state.process_group)
    # Average grad by post-division factor.
    if state.gradient_postdivide_factor > 1:
        grad.div_(state.gradient_postdivide_factor)
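
# A minimal usage sketch for ``allreduce_hook`` (illustrative comment only, not
# executed): it assumes the default process group is already initialized and
# uses a made-up gradient shape. For an assumed world size of 64 the
# pre-division factor is 8 and the post-division factor is 64 / 8 == 8, so the
# two divisions together average by the full world size while keeping
# intermediate values away from low-precision overflow and underflow.
#
#   state = DefaultState(process_group=dist.group.WORLD)
#   grad = torch.randn(1024, device="cuda")
#   allreduce_hook(state, grad)  # grad now holds the average across all ranks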


def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor, output: torch.Tensor):
    r"""
    Implement the FSDP communication hook for the ``reduce_scatter`` algorithm.

    Used by sharded FSDP strategies, together with the necessary pre- and post-division of gradients.

    Args:
        state (DefaultState): State information, configures pre- and post-division factors.
        grad (torch.Tensor): An unsharded gradient for the local batch that needs to be
            communicated across ranks.
        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
    """
    # Average grad by pre-division factor.
    if state.gradient_predivide_factor > 1:
        grad.div_(state.gradient_predivide_factor)
    dist.reduce_scatter_tensor(output, grad, group=state.process_group)
    # Average grad's shard by post-division factor.
    if state.gradient_postdivide_factor > 1:
        output.div_(state.gradient_postdivide_factor)


def _low_precision_hook(
    prec: torch.dtype,
    state: LowPrecisionState,
    grad: torch.Tensor,
    output: Optional[torch.Tensor],
):
    if grad.dtype != prec:
        grad.data = grad.data.to(prec)
    if output is not None:
        if output.dtype != prec:
            output.data = output.data.to(prec)
        reduce_scatter_hook(state, grad, output)
        _decompress(state, output)
    else:
        allreduce_hook(state, grad)
        _decompress(state, grad)


def fp16_compress_hook(
    state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None
):
    r"""
    Implement the FSDP communication hook for a simple gradient compression approach.

    Casts ``grad`` to half-precision floating-point format (``torch.float16``).

    It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a
    ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``)
    gradients are averaged by a ``state.gradient_postdivide_factor``.
    Once post-division is done, compressed gradients are cast back to the parameters' precision.

    Args:
        state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision.
        grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision.
        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
    """
    fp16_hook = functools.partial(_low_precision_hook, torch.float16)
    return fp16_hook(state, grad, output)
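
# A direct-invocation sketch for ``fp16_compress_hook`` (illustrative comment
# only, not executed): tensor sizes are made up, and it assumes an initialized
# default process group and a CUDA device. The gradient is cast to
# ``torch.float16``, reduce-scattered and averaged across ranks, and the
# resulting shard is cast back to ``state.parameter_type`` by ``_decompress``.
#
#   state = LowPrecisionState(process_group=dist.group.WORLD)
#   grad = torch.randn(state.world_size * 256, device="cuda")  # unsharded gradient
#   shard = torch.empty(256, device="cuda")                    # this rank's shard
#   fp16_compress_hook(state, grad, shard)
#   assert shard.dtype == torch.float32  # restored to the parameter precision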


def bf16_compress_hook(
    state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None
):
    r"""
    Implement the FSDP communication hook for a simple gradient compression approach.

    Casts ``grad`` to Brain floating-point format (``torch.bfloat16``).

    It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a
    ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``)
    gradients are averaged by a ``state.gradient_postdivide_factor``.
    Once post-division is done, compressed gradients are cast back to the parameters' precision.

    Args:
        state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision.
        grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision.
        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
    """
    bf16_hook = functools.partial(_low_precision_hook, torch.bfloat16)
    return bf16_hook(state, grad, output)
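
# Registration sketch (illustrative comment only, not executed): depending on
# the PyTorch version, FSDP exposes ``register_comm_hook`` on
# ``FullyShardedDataParallel``; ``MyModel`` and ``fsdp_model`` below are
# placeholder names.
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#   fsdp_model = FSDP(MyModel().cuda())
#   state = LowPrecisionState(process_group=dist.group.WORLD)
#   fsdp_model.register_comm_hook(state, fp16_compress_hook)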