# mypy: allow-untyped-defs
import functools
from typing import Optional

import torch
import torch.distributed as dist


class DefaultState:
    r"""
    Stores state needed to perform the default communication algorithm within a communication hook.

    Args:
        process_group (ProcessGroup): The process group to be used.
    """

    __slots__ = [
        "process_group",
        "world_size",
        "gradient_predivide_factor",
        "gradient_postdivide_factor",
    ]

    def __init__(self, process_group: dist.ProcessGroup):
        if process_group is None:
            raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.")
        self.process_group = process_group
        self.world_size = dist.get_world_size(process_group)
        # Setting two factors `self.gradient_predivide_factor`
        # and `self.gradient_postdivide_factor` to avoid underflow and overflow
        self.gradient_predivide_factor = self._get_gradient_predivide_factor(
            self.world_size
        )
        self.gradient_postdivide_factor = (
            self.world_size / self.gradient_predivide_factor
        )

    @staticmethod
    def _get_gradient_predivide_factor(world_size: int) -> float:
        factor: int = 1
        while world_size % factor == 0 and world_size / factor > factor:
            factor *= 2
        return float(factor)

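# Illustrative sketch, not part of the upstream API: how the two division factors
# combine. For a hypothetical world size of 8, ``_get_gradient_predivide_factor(8)``
# returns 4.0, so gradients are divided by 4 before the collective and by 8 / 4 = 2
# afterwards -- an overall division by ``world_size`` split into two smaller steps
# to reduce the risk of low-precision underflow/overflow:
#
#   state = DefaultState(process_group)  # assumes an initialized ProcessGroup of 8 ranks
#   state.gradient_predivide_factor      # 4.0
#   state.gradient_postdivide_factor     # 2.0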

class LowPrecisionState(DefaultState):
    r"""
    Stores state needed to perform gradient communication in a lower precision within a communication hook.

    The communication hook casts gradients back to the original
    parameter precision specified by ``parameter_type`` (default: ``torch.float32``).
    Builds on top of :class:`DefaultState`.

    Args:
        parameter_type (torch.dtype): The precision of the model's parameters.
            Required for a hook to cast gradients back to a parameter's precision.
    """

    __slots__ = [
        "parameter_type",
    ]

    def __init__(
        self,
        process_group,
        parameter_type=torch.float32,
    ):
        super().__init__(process_group)
        self.parameter_type = parameter_type

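# Illustrative sketch, not part of the upstream API: constructing the low-precision
# state for a model whose parameters are kept in fp32. The hooks defined below read
# ``state.parameter_type`` to cast gradients back after communication:
#
#   state = LowPrecisionState(process_group, parameter_type=torch.float32)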

def _decompress(state: LowPrecisionState, grad: torch.Tensor):
    """
    Casts gradients back to full parameter precision so that further computation happens in full precision.
    """
    orig_grad_data = grad.data
    grad.data = grad.data.to(state.parameter_type)
    device_type = ""
    try:
        if grad.device.type == "privateuse1":
            device_type = torch._C._get_privateuse1_backend_name()
        else:
            device_type = grad.device.type
        backend = getattr(torch, device_type)
    except AttributeError as e:
        raise AttributeError(
            f"Device {grad.device} does not have a corresponding "
            f"backend registered as 'torch.{device_type}'."
        ) from e

    # Don't let this memory get reused until after the transfer.
    orig_grad_data.record_stream(backend.current_stream())  # type: ignore[arg-type]

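# Illustrative sketch, not part of the upstream API, of the ``record_stream`` pattern
# used above: when a tensor's memory is read or written by work enqueued on a stream
# other than the one it was allocated on, ``Tensor.record_stream`` keeps the caching
# allocator from recycling that memory until the recorded stream's pending work has
# finished. A minimal CUDA-only analogue (assumes a CUDA build and device):
#
#   side = torch.cuda.Stream()
#   with torch.cuda.stream(side):
#       fp32_grad = grad.to(torch.float32)  # async copy reading `grad` on `side`
#   grad.record_stream(side)                # keep `grad`'s storage alive until `side` drains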

def allreduce_hook(state: DefaultState, grad: torch.Tensor):
    r"""
    Implement the FSDP communication hook for the ``all_reduce`` algorithm and the necessary pre- and post-division of gradients.

    Args:
        state (DefaultState): State information, configures pre- and post-division factors.
        grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks.
    """
    # Average grad by pre-division factor. Together pre- and post-division factors
    # lead to an overall averaging by world_size, required for consistency with PyTorch DDP.
    # This is a two-step process to avoid potential underflow and overflow.
    if state.gradient_predivide_factor > 1:
        grad.div_(state.gradient_predivide_factor)
    dist.all_reduce(grad, group=state.process_group)
    # Average grad by post-division factor.
    if state.gradient_postdivide_factor > 1:
        grad.div_(state.gradient_postdivide_factor)

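# Illustrative sketch, not part of the upstream API: the two-step division above is
# numerically equivalent to averaging by ``world_size``, matching PyTorch DDP. For a
# hypothetical world size of 8 (predivide 4.0, postdivide 2.0):
#
#   g = [torch.randn(4) for _ in range(8)]          # per-rank local gradients
#   ddp_mean = torch.stack(g).mean(dim=0)
#   hook_result = sum(t / 4.0 for t in g) / 2.0     # pre-divide, sum (all_reduce), post-divide
#   torch.allclose(ddp_mean, hook_result)           # True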

def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor, output: torch.Tensor):
    r"""
    Implement the FSDP communication hook for the ``reduce_scatter`` algorithm.

    Used for sharded FSDP strategies; includes the necessary pre- and post-division of gradients.

    Args:
        state (DefaultState): State information, configures pre- and post-division factors.
        grad (torch.Tensor): An unsharded gradient for the local batch that needs to be
            communicated across ranks.
        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
    """
    # Average grad by pre-division factor.
    if state.gradient_predivide_factor > 1:
        grad.div_(state.gradient_predivide_factor)
    dist.reduce_scatter_tensor(output, grad, group=state.process_group)
    # Average grad's shard by post-division factor.
    if state.gradient_postdivide_factor > 1:
        output.div_(state.gradient_postdivide_factor)

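# Illustrative sketch, not part of the upstream API: ``dist.reduce_scatter_tensor``
# reduces the unsharded input across ranks and leaves one shard per rank, so
# ``grad.numel()`` must equal ``world_size * output.numel()``. For a hypothetical
# world size of 8 and a flattened 1024-element gradient:
#
#   grad = torch.randn(1024, device="cuda")    # unsharded gradient, same shape on every rank
#   output = torch.empty(128, device="cuda")   # 1024 / 8 elements for this rank's shard
#   reduce_scatter_hook(state, grad, output)   # `output` now holds this rank's averaged shard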

def _low_precision_hook(
    prec: torch.dtype,
    state: LowPrecisionState,
    grad: torch.Tensor,
    output: Optional[torch.Tensor],
):
    if grad.dtype != prec:
        grad.data = grad.data.to(prec)
    if output is not None:
        if output.dtype != prec:
            output.data = output.data.to(prec)
        reduce_scatter_hook(state, grad, output)
        _decompress(state, output)
    else:
        allreduce_hook(state, grad)
        _decompress(state, grad)


def fp16_compress_hook(
    state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None
):
    r"""
    Implement the FSDP communication hook for a simple gradient compression approach.
    Casts ``grad`` to half-precision floating-point format (``torch.float16``).

    It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by
    ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``)
    gradients are averaged by ``state.gradient_postdivide_factor``.
    Once post-division is done, compressed gradients are cast back to the parameters' precision.

    Args:
        state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision.
        grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision.
        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
    """
    fp16_hook = functools.partial(_low_precision_hook, torch.float16)
    return fp16_hook(state, grad, output)

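# Illustrative sketch, not part of this module: the low-precision hooks are typically
# attached to an FSDP-wrapped model via ``register_comm_hook`` (assumes an initialized
# default process group and that this FSDP version exposes the method):
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#   fsdp_model = FSDP(model)                                  # `model` is a hypothetical nn.Module
#   state = LowPrecisionState(process_group=dist.group.WORLD)
#   fsdp_model.register_comm_hook(state, fp16_compress_hook)  # communicate gradients in fp16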

def bf16_compress_hook(
    state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None
):
    r"""
    Implement the FSDP communication hook for a simple gradient compression approach.
    Casts ``grad`` to Brain floating point format (``torch.bfloat16``).

    It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by
    ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``)
    gradients are averaged by ``state.gradient_postdivide_factor``.
    Once post-division is done, compressed gradients are cast back to the parameters' precision.

    Args:
        state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision.
        grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision.
        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
    """
    bf16_hook = functools.partial(_low_precision_hook, torch.bfloat16)
    return bf16_hook(state, grad, output)