from dataclasses import dataclass
from typing import Any, no_type_check

import torch
import torch.distributed as dist
from torch.autograd import Variable
from torch.distributed.utils import _free_storage


@dataclass
class _AllreduceUpcastHookState:
    """
    State to manage DDP mixed precision in backward / gradient communication.

    This contains a weakref to the DDP module for access to its reducer and
    process group, and a stream on which parameter and gradient upcasts run.
    """

    ddp_weakref: Any
    upcast_stream: torch.cuda.Stream
    wait_for_stream_enqueued: bool = False


@no_type_check
def _reducer_allreduce_and_upcast_hook(
    hook_state: _AllreduceUpcastHookState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Perform allreduce in ``reduce_dtype``, then upcast to prepare for the optimizer.

    Performs the allreduce in the reduced precision given by DDP's mixed
    precision ``reduce_dtype``, and upcasts parameters and gradients to fp32
    in preparation for running the optimizer step.
    """
    ddp_weakref = hook_state.ddp_weakref
    reducer, process_group = ddp_weakref().reducer, ddp_weakref().process_group
    gradient_is_bucket_view = ddp_weakref().gradient_as_bucket_view
    # Cast the bucket buffer to reduce_dtype if it differs from param_dtype.
    if (
        ddp_weakref().mixed_precision.param_dtype
        != ddp_weakref().mixed_precision.reduce_dtype
    ):
        bucket.set_buffer(
            bucket.buffer().to(ddp_weakref().mixed_precision.reduce_dtype)
        )
    fut = reducer._run_allreduce_hook(bucket)
    ret_fut = torch.futures.Future()
    stream = hook_state.upcast_stream
    with torch.cuda.stream(stream):
        fut.wait()
        # The allreduce sums gradients across ranks; divide by the world size
        # to average them, then resolve the returned future with the result.
        bucket.buffer().div_(process_group.size())
        ret_fut.set_result(bucket.buffer())

        # Upcast parameters and gradients so the optimizer step can run in fp32.
        params, grads = bucket.parameters(), bucket.gradients()
        for p, g in zip(params, grads):
            p.data = p._fp_param
            # Free storage for the low precision param, as it will be
            # allocated again in the next forward pass.
            _free_storage(p._mp_param)
            p.grad.data = p.grad.to(p.data.dtype)

    # Enqueue a callback to wait for this stream at the end of backward.
    def wait_for_stream_cb():
        torch.cuda.current_stream().wait_stream(stream)
        # Remove post-backward hooks since they are re-installed in the next
        # iteration, similar to FSDP.
        # Parameters that don't require grad still need to be cast since they
        # may participate in computation. However, they would not be recast by
        # the hook above as they don't have a grad hook installed, so cast
        # them back here.
        for n, p in ddp_weakref().module.named_parameters():
            if hasattr(p, "_ddp_mp_hook_state"):
                p._ddp_mp_hook_state[1].remove()
                delattr(p, "_ddp_mp_hook_state")
            if not p.requires_grad and not hasattr(p, "_ddp_ignored"):
                p.data = p._fp_param

        # Reset for the next backward pass.
        hook_state.wait_for_stream_enqueued = False

    if not hook_state.wait_for_stream_enqueued:
        Variable._execution_engine.queue_callback(wait_for_stream_cb)
        # Mark that the callback has been enqueued.
        hook_state.wait_for_stream_enqueued = True

    return ret_fut
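

# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of this module's API.
#
# DistributedDataParallel is expected to wire this hook up itself when its
# (prototype) mixed precision mode is enabled; the helper below merely shows
# how the pieces above fit together if registered manually. It assumes
# `ddp_model` was constructed with DDP mixed precision so that parameters
# already carry the `_fp_param` / `_mp_param` attributes the hook relies on.
# `_register_allreduce_upcast_hook` is a hypothetical name, not an existing
# PyTorch API.
# ---------------------------------------------------------------------------
def _register_allreduce_upcast_hook(
    ddp_model: torch.nn.parallel.DistributedDataParallel,
) -> _AllreduceUpcastHookState:
    import weakref

    # A weakref avoids a reference cycle between the hook state and the DDP
    # module; the dedicated CUDA stream lets the fp32 upcast overlap with the
    # rest of the backward pass.
    hook_state = _AllreduceUpcastHookState(
        ddp_weakref=weakref.ref(ddp_model),
        upcast_stream=torch.cuda.Stream(),
    )
    # Run gradient communication through the hook above: allreduce in
    # reduce_dtype, then upcast params/grads before the optimizer step.
    ddp_model.register_comm_hook(hook_state, _reducer_allreduce_and_upcast_hook)
    return hook_state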