# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
import logging
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    TYPE_CHECKING,
)

import torch
import torch._dynamo.compiled_autograd as ca
import torch.nn as nn
from torch._logging import warning_once
from torch.autograd import Variable
from torch.autograd.graph import _MultiHandle
from torch.distributed._composable_state import (
    _get_module_state,
    _insert_module_state,
    _State,
)
from torch.distributed.utils import _to_kwargs
from torch.utils._pytree import tree_flatten, tree_map

from ._fsdp_api import MixedPrecisionPolicy
from ._fsdp_common import _cast_fp_tensor, TrainingState
from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup


if TYPE_CHECKING:
    from ._fsdp_param import FSDPParam


logger = logging.getLogger("torch.distributed._composable.fsdp")


class FSDPStateContext:
    """This has state shared across FSDP states."""

    def __init__(self) -> None:
        # All FSDP states in the root state's module tree
        self.all_states: List[FSDPState] = []
        # Iteration's forward root runs the once-per-forward logic; this root
        # may not be the overall root set by lazy initialization in cases where
        # only a submodule runs forward (e.g. encoder-only for eval)
        self.iter_forward_root: Optional[FSDPState] = None
        # Final callback should only be queued once per backward
        self.post_backward_final_callback_queued: bool = False
        # Whether to finalize backward in this backward's final callback
        self.is_last_backward: bool = True
        # Optional user-provided event recorded after optimizer for the
        # all-gather streams to wait on in the root pre-forward
        self.post_optim_event: Optional[torch.cuda.Event] = None


def disable_if_config_true(func):
    @functools.wraps(func)
    def fsdp_hook_wrapper(*args, **kwargs):
        if torch._dynamo.config.skip_fsdp_hooks:
            return torch._dynamo.disable(func, recursive=True)(*args, **kwargs)
        else:
            return func(*args, **kwargs)

    return fsdp_hook_wrapper


class FSDPState(_State):
    def __init__(self) -> None:
        super().__init__()
        self._fsdp_param_group: Optional[FSDPParamGroup] = None
        self._is_root: Optional[bool] = None  # root set during lazy init
        self._state_ctx = FSDPStateContext()
        self._comm_ctx = FSDPCommContext()
        self._training_state: TrainingState = TrainingState.IDLE
        self._states_to_forward_prefetch: List[FSDPState] = []
        self._states_to_backward_prefetch: List[FSDPState] = []
        self._modules_to_run_forward: Set[nn.Module] = set()

    # Define a separate init since `__init__` is called in the contract
    def init(
        self,
        modules: Tuple[nn.Module, ...],
        device: torch.device,
        mp_policy: MixedPrecisionPolicy,
    ) -> None:
        for module in modules:
            _insert_module_state(module, self)
        self._modules = modules
        self._device = device
        self._mp_policy = mp_policy
        if len(modules) == 1:
            self._pre_forward_hook_handle = modules[0].register_forward_pre_hook(
                self._pre_forward, prepend=True, with_kwargs=True
            )
            self._post_forward_hook_handle = modules[0].register_forward_hook(
                self._post_forward, prepend=False
            )
        else:
            hook_handle = _register_group_forward_hooks(
                modules,
                self._pre_forward,
                self._post_forward,
                self._modules_to_run_forward,
            )
            self._pre_forward_hook_handle = hook_handle
            self._post_forward_hook_handle = hook_handle
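
    # NOTE: The forward hooks registered in `init` drive this state machine:
    # `_pre_forward`/`_post_forward` delegate to the owned `FSDPParamGroup`,
    # while the backward side is driven by tensor gradient hooks registered on
    # the forward outputs (`_register_pre_backward_hook`) and by an autograd
    # final callback (`_register_root_post_backward_final_callback`).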

    def _root_pre_forward(
        self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        self._lazy_init()
        if self._state_ctx.iter_forward_root is not None:
            return args, kwargs
        if not ca.compiled_autograd_enabled:
            logger.debug("FSDP::root_pre_forward")
        self._state_ctx.iter_forward_root = self
        with torch.profiler.record_function("FSDP::root_pre_forward"):
            # Wait for optimizer before implicitly prefetched all-gathers
            if (event := self._state_ctx.post_optim_event) is not None:
                self._comm_ctx.all_gather_copy_in_stream.wait_event(event)
                self._comm_ctx.all_gather_stream.wait_event(event)
                self._state_ctx.post_optim_event = None
            else:
                current_stream = torch.cuda.current_stream()
                self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream)
                self._comm_ctx.all_gather_stream.wait_stream(current_stream)
            if self._device.type == "cuda":
                with torch.profiler.record_function("FSDP::inputs_to_device"):
                    args_tuple, kwargs_tuple = _to_kwargs(
                        args, kwargs, self._device, False
                    )  # same as DDP
                args, kwargs = args_tuple[0], kwargs_tuple[0]
        return args, kwargs

    def _lazy_init(self) -> None:
        """
        Lazy initialization represents when all modules' parallelisms have
        finalized (e.g. FSDP has been applied to all desired modules). This
        means that we can determine which state is the root, and we do so by
        designating the first state to run forward as the root.
        """
        if self._is_root is not None:
            return  # no-op: already initialized
        self._is_root = True
        if len(self._modules) > 1:
            raise RuntimeError(
                f"FSDP requires a single root module but got {self._modules}"
            )
        root_module = self._modules[0]
        visited_states: Set[FSDPState] = set()
        for module_name, module in root_module.named_modules():
            if (state := _get_module_fsdp_state(module)) is None:
                continue
            if module is not root_module:
                if state not in visited_states and state._is_root is not None:
                    raise RuntimeError(
                        "FSDP state has already been lazily initialized for "
                        f"{module_name}\nFSDP requires running forward through "
                        "the root module first"
                    )
                state._is_root = False
            self._state_ctx.all_states.append(state)
            visited_states.add(state)
        if self._fsdp_param_group:
            # For the root, do not reshard after forward since for training,
            # the parameters would be freed and all-gathered immediately
            self._fsdp_param_group.post_forward_mesh_info = None
        self._init_fqns()
        self._init_shared_state()
        # Run parameter group lazy inits after initializing FQNs for improved
        # error messages
        for state in self._state_ctx.all_states:
            if state._fsdp_param_group:
                state._fsdp_param_group.lazy_init()
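
    # The root state shares its `FSDPStateContext` and `FSDPCommContext` with
    # every state in its module tree so that all states observe the same
    # per-iteration flags (e.g. `iter_forward_root`, `is_last_backward`) and
    # reuse the same communication context (e.g. the all-gather streams).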

    def _init_shared_state(self) -> None:
        self._comm_ctx.lazy_init()
        for state in self._state_ctx.all_states:
            state._state_ctx = self._state_ctx
            state._comm_ctx = self._comm_ctx
            if fsdp_param_group := state._fsdp_param_group:
                fsdp_param_group.comm_ctx = self._comm_ctx

    def _init_fqns(self) -> None:
        """Sets module and parameter FQN attributes for debugging."""
        assert self._is_root
        root_module = self._modules[0]
        param_to_fsdp_param: Dict[nn.Parameter, FSDPParam] = {}
        module_to_fsdp_param_group: Dict[nn.Module, FSDPParamGroup] = {}
        for state in self._state_ctx.all_states:
            if fsdp_param_group := state._fsdp_param_group:
                for fsdp_param in fsdp_param_group.fsdp_params:
                    param_to_fsdp_param[fsdp_param.sharded_param] = fsdp_param
                for module in fsdp_param_group.modules:
                    module_to_fsdp_param_group[module] = fsdp_param_group
        for param_name, param in root_module.named_parameters():
            if param in param_to_fsdp_param:
                param_to_fsdp_param[param]._param_fqn = param_name
        for module_name, module in root_module.named_modules():
            if module in module_to_fsdp_param_group:
                module_fqn = module_to_fsdp_param_group[module]._module_fqn
                if module_fqn is None:
                    module_to_fsdp_param_group[module]._module_fqn = module_name
                else:
                    assert isinstance(module_fqn, str), f"{module_fqn}"
                    module_fqn += f", {module_name}"
                    module_to_fsdp_param_group[module]._module_fqn = module_fqn

    @disable_if_config_true
    def _pre_forward(
        self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        # When composing with module-hook-based activation checkpointing, the
        # pre-backward hook is responsible for the unshard
        if self._training_state == TrainingState.PRE_BACKWARD:
            return args, kwargs
        self._training_state = TrainingState.FORWARD
        args, kwargs = self._root_pre_forward(module, args, kwargs)
        if self._mp_policy.cast_forward_inputs and self._mp_policy.param_dtype:
            with torch.profiler.record_function("FSDP::cast_forward_inputs"):
                cast_fn = functools.partial(
                    _cast_fp_tensor, self._mp_policy.param_dtype
                )
                args, kwargs = tree_map(cast_fn, args), tree_map(cast_fn, kwargs)
        if self._fsdp_param_group:
            args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
        for fsdp_state in self._states_to_forward_prefetch:
            if (target_param_group := fsdp_state._fsdp_param_group) is not None:
                FSDPParamGroup._prefetch_unshard(target_param_group, "forward")
        return args, kwargs

    @disable_if_config_true
    def _post_forward(self, module: nn.Module, input: Any, output: Any) -> Any:
        # When composing with module-hook-based activation checkpointing, the
        # post-backward hook is responsible for the reshard
        if self._training_state == TrainingState.PRE_BACKWARD:
            return output
        if self._fsdp_param_group:
            output = self._fsdp_param_group.post_forward(module, input, output)
        output = self._register_pre_backward_hook(output)
        self._training_state = TrainingState.IDLE
        if self._state_ctx.iter_forward_root is self:
            if all_gather_state := self._comm_ctx.all_gather_state:
                # Free the last all-gather result if needed; refer to
                # [Note: Overlapping all-gather copy-in and all-gather]
                self._comm_ctx.all_gather_copy_in_stream.wait_event(
                    all_gather_state.event
                )
                self._comm_ctx.all_gather_stream.wait_event(all_gather_state.event)
                self._comm_ctx.all_gather_state = None  # free the all-gather result
            self._state_ctx.iter_forward_root = None
        if self._mp_policy.output_dtype is not None:
            with torch.profiler.record_function("FSDP::cast_forward_outputs"):
                output = tree_map(
                    functools.partial(_cast_fp_tensor, self._mp_policy.output_dtype),
                    output,
                )
        return output
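
    # `_pre_backward` is not a module hook; it is registered as a gradient
    # hook on each tensor output of forward (see `_register_pre_backward_hook`
    # below), so it runs when autograd computes the gradient for one of those
    # outputs. It also queues the root post-backward final callback so that
    # finalization runs exactly once per backward.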

    def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor:
        self._training_state = TrainingState.PRE_BACKWARD
        self._register_root_post_backward_final_callback()
        if self._fsdp_param_group:
            default_prefetch = len(self._states_to_backward_prefetch) == 0
            self._fsdp_param_group.pre_backward(default_prefetch)
        for fsdp_state in self._states_to_backward_prefetch:
            if (target_param_group := fsdp_state._fsdp_param_group) is not None:
                FSDPParamGroup._prefetch_unshard(target_param_group, "backward")
        return grad

    def _root_post_backward_final_callback(self) -> None:
        if not ca.compiled_autograd_enabled:
            logger.debug("FSDP::root_post_backward")
        with torch.profiler.record_function("FSDP::root_post_backward_callback"):
            for state in self._state_ctx.all_states:
                if state._fsdp_param_group and state._fsdp_param_group.is_unsharded:
                    # Run post-backward in case the forward inputs did not
                    # require gradient, so autograd never ran post-backward
                    state._fsdp_param_group.post_backward()
                state._training_state = TrainingState.IDLE
                if state._fsdp_param_group:
                    state._fsdp_param_group._training_state = TrainingState.IDLE
                if self._state_ctx.is_last_backward:
                    state._finalize_backward()
            if self._state_ctx.is_last_backward:
                self._comm_ctx.post_forward_order.clear()
                if self._comm_ctx.reduce_scatter_state is not None:
                    torch.cuda.current_stream().wait_event(
                        self._comm_ctx.reduce_scatter_state.event
                    )
                    self._comm_ctx.reduce_scatter_state = None
            self._state_ctx.post_backward_final_callback_queued = False

    def _finalize_backward(self) -> None:
        if self._modules_to_run_forward:
            msg = (
                f"{len(self._modules_to_run_forward)} of the {len(self._modules)} "
                f"modules passed to fully_shard did not run forward before backward, "
                "which is error-prone since FSDP post-forward/pre-backward logic "
                "will not run for these modules. We recommend passing only modules "
                "that run forward together. Modules that did not run forward: "
                f"{list(self._modules_to_run_forward)}"
            )
            warning_once(logger, msg, stacklevel=2)
            # Clear since we want the next forward to run
            self._modules_to_run_forward.clear()
        if self._fsdp_param_group:
            self._fsdp_param_group.finalize_backward()

    def _register_pre_backward_hook(self, output: Any) -> Any:
        if not torch.is_grad_enabled():
            return output
        flat_outputs, _ = tree_flatten(output)
        for t in flat_outputs:
            if torch.is_tensor(t) and t.requires_grad:
                t.register_hook(self._pre_backward)
        return output

    def _register_root_post_backward_final_callback(self):
        if self._state_ctx.post_backward_final_callback_queued:
            return
        self._state_ctx.post_backward_final_callback_queued = True
        Variable._execution_engine.queue_callback(
            self._root_post_backward_final_callback
        )


def _get_module_fsdp_state(module: nn.Module) -> Optional[FSDPState]:
    state = _get_module_state(module)
    if isinstance(state, FSDPState):
        return state
    return None


def _register_group_forward_hooks(
    modules: Sequence[nn.Module],
    pre_hook: Callable,
    post_hook: Callable,
    modules_to_run: Set[nn.Module],
):
    """
    Registers group forward pre- and post-hooks. The pre-hook runs upon the
    first module's pre-forward, and the post-hook runs upon the last module's
    post-forward. If at least one module in the group does not run forward,
    then the post-hook does not run.
    """
    modules_set = set(modules)

    @disable_if_config_true
    @functools.wraps(pre_hook)
    def wrapped_pre_hook(*args: Any, **kwargs: Any):
        if len(modules_to_run) == 0:  # first to run
            modules_to_run.update(modules_set)
            return pre_hook(*args, **kwargs)

    @disable_if_config_true
    def get_wrapped_post_hook(module: nn.Module):
        @functools.wraps(post_hook)
        def wrapped_post_hook(*args: Any, **kwargs: Any):
            modules_to_run.discard(module)
            if len(modules_to_run) == 0:
                return post_hook(*args, **kwargs)

        return wrapped_post_hook

    pre_handles = [
        module.register_forward_pre_hook(
            wrapped_pre_hook, prepend=True, with_kwargs=True
        )
        for module in modules
    ]
    post_handles = [
        module.register_forward_hook(
            get_wrapped_post_hook(module), prepend=False, always_call=True
        )
        for module in modules
    ]
    return _MultiHandle(tuple(pre_handles + post_handles))
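

# Illustrative sketch of the group-hook semantics above. The modules and hook
# functions below (`lin1`, `lin2`, `pre`, `post`, `ran`) are hypothetical and
# exist only for documentation; they are not part of this module:
#
#   ran: Set[nn.Module] = set()
#   lin1, lin2 = nn.Linear(4, 4), nn.Linear(4, 4)
#
#   def pre(module, args, kwargs):  # fires once, on the first pre-forward
#       return args, kwargs
#
#   def post(module, inputs, output):  # fires once, after the last forward
#       return output
#
#   handle = _register_group_forward_hooks((lin1, lin2), pre, post, ran)
#   x = torch.randn(2, 4)
#   lin1(x)  # runs `pre` only; `ran` now tracks the modules yet to run forward
#   lin2(x)  # last module in the group to run forward, so `post` runs here
#   handle.remove()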