from dataclasses import dataclass
from typing import Any, no_type_check

import torch
import torch.distributed as dist
from torch.autograd import Variable
from torch.distributed.utils import _free_storage


@dataclass
class _AllreduceUpcastHookState:
    """
    State to manage DDP mixed precision in backward / gradient communication.

    This contains a weakref to the DDP module for access to the reducer and
    process group, a stream on which parameter and gradient upcasts are run,
    and a flag tracking whether the end-of-backward stream-wait callback has
    been enqueued.
    """

    ddp_weakref: Any
    upcast_stream: torch.cuda.Stream
    wait_for_stream_enqueued: bool = False

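# A minimal sketch (not part of DDP's public API) of how this state might be
# constructed, assuming ``ddp`` is an already-initialized
# ``DistributedDataParallel`` module with ``mixed_precision`` configured.
# DDP is expected to build this state internally when mixed precision is
# enabled, so this is purely illustrative::
#
#     import weakref
#
#     hook_state = _AllreduceUpcastHookState(
#         ddp_weakref=weakref.ref(ddp),
#         upcast_stream=torch.cuda.Stream(),
#     )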

@no_type_check
def _reducer_allreduce_and_upcast_hook(
    hook_state: _AllreduceUpcastHookState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Perform allreduce in ``reduce_dtype`` precision and upcast to prepare for the optimizer.

    Performs allreduce in the reduced precision given by DDP's mixed precision
    ``reduce_dtype``, then upcasts parameters and gradients to fp32 in
    preparation for running the optimizer step.
    """
    ddp_weakref = hook_state.ddp_weakref
    reducer, process_group = ddp_weakref().reducer, ddp_weakref().process_group
    gradient_is_bucket_view = ddp_weakref().gradient_as_bucket_view
    # Cast the bucket if reduce_dtype differs from param_dtype.
    if (
        ddp_weakref().mixed_precision.param_dtype
        != ddp_weakref().mixed_precision.reduce_dtype
    ):
        # Cast bucket tensor to reduce_dtype
        bucket.set_buffer(
            bucket.buffer().to(ddp_weakref().mixed_precision.reduce_dtype)
        )
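    # Run the reducer's default allreduce on the (possibly downcast) bucket.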
    fut = reducer._run_allreduce_hook(bucket)
    ret_fut = torch.futures.Future()
    stream = hook_state.upcast_stream
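    # Wait for the allreduce and run the division and upcasts on a side stream
    # so they can overlap with the rest of backward; the end-of-backward
    # callback below synchronizes the current stream with it.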
    with torch.cuda.stream(stream):
        fut.wait()
        bucket.buffer().div_(process_group.size())
        ret_fut.set_result(bucket.buffer())

        # Upcast parameters and gradients so the optimizer step can run in fp32.
        params, grads = bucket.parameters(), bucket.gradients()
        for p, g in zip(params, grads):
            p.data = p._fp_param
            # Free storage for the low precision (mp) param, as it will be
            # allocated again in the next forward pass.
            _free_storage(p._mp_param)
            p.grad.data = p.grad.to(p.data.dtype)

    # Enqueue a callback to wait for this stream at the end of backward.
    def wait_for_stream_cb():
        torch.cuda.current_stream().wait_stream(stream)
        # Remove post-backward hooks since they are re-installed in the next
        # iteration, similar to FSDP.
        # Parameters that don't require grad still needed to be cast since
        # they may participate in computation. However, they would not be
        # recast by the hook above as they don't have a grad hook installed,
        # so cast them back here.
        for n, p in ddp_weakref().module.named_parameters():
            if hasattr(p, "_ddp_mp_hook_state"):
                p._ddp_mp_hook_state[1].remove()
                delattr(p, "_ddp_mp_hook_state")
            if not p.requires_grad and not hasattr(p, "_ddp_ignored"):
                p.data = p._fp_param

        # Reset for the next backward pass.
        hook_state.wait_for_stream_enqueued = False

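    # The hook fires once per gradient bucket, so only enqueue the
    # end-of-backward callback the first time through.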
    if not hook_state.wait_for_stream_enqueued:
        Variable._execution_engine.queue_callback(wait_for_stream_cb)
        # mark that the callback is enqueued
        hook_state.wait_for_stream_enqueued = True

    return ret_fut
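
# A hedged usage sketch, not taken from this module: assuming ``ddp`` is a
# ``DistributedDataParallel`` instance with ``mixed_precision`` configured and
# ``hook_state`` is an ``_AllreduceUpcastHookState`` as constructed above, the
# hook could be registered through DDP's public comm-hook API::
#
#     ddp.register_comm_hook(hook_state, _reducer_allreduce_and_upcast_hook)
#
# The hook then runs once per gradient bucket during backward: it allreduces
# in ``reduce_dtype``, divides by the world size, and upcasts parameters and
# gradients on ``hook_state.upcast_stream``.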