# mypy: allow-untyped-defs
import itertools
import warnings
from enum import auto, Enum
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.fsdp._common_utils import _FSDPState, _get_param_to_fqns
from torch.distributed.fsdp._flat_param import FlatParamHandle


class _ExecOrderWarnStatus(Enum):
    """Used internally for execution order validation."""

    NONE = auto()  # no deviation yet
    WARNING = auto()  # deviated this iteration; currently issuing warnings
    WARNED = auto()  # deviated in a previous iteration


class _ExecOrderData:
    """
    This contains the data structures to track the execution order. We track
    the pre-forward order on the *first* iteration for forward prefetching
    (which thus assumes a static graph) and the post-forward order on *every*
    iteration for backward prefetching (which thus does not assume a static
    graph but may provide an incorrect order).
    """

    def __init__(
        self,
        debug_level: dist.DebugLevel,
        backward_prefetch_limit: int,
        forward_prefetch_limit: int,
    ) -> None:
        # Tracks the (static) pre-forward order for execution order validation
        # and forward prefetching
        self.handles_pre_forward_order: List[FlatParamHandle] = []
        # Tracks the post-forward order for pre-backward prefetching
        self.handles_post_forward_order: List[Optional[FlatParamHandle]] = []
        self._iter = 0

        # Gives the max number of backward/forward prefetched all-gathers by a
        # single module
        self._backward_prefetch_limit = backward_prefetch_limit
        self._forward_prefetch_limit = forward_prefetch_limit

        # Data structures for execution order validation
        self._checking_order: bool = debug_level == dist.DebugLevel.DETAIL
        self.process_group: Optional[dist.ProcessGroup] = None
        self.world_size: Optional[int] = None
        self.all_handles: List[FlatParamHandle] = []
        # Names are prefixed from the root module
        self.param_to_fqn: Dict[nn.Parameter, List[str]] = {}
        # Current index in the pre-forward execution order
        self.current_order_index = 0
        self.warn_status = _ExecOrderWarnStatus.NONE

    def init(
        self,
        state: _FSDPState,
        root_module: nn.Module,
        process_group: dist.ProcessGroup,
    ) -> None:
        """
        Initializes the data structures needed for checking the forward order.
        This should be called after a root FSDP instance has been set during
        lazy initialization.
        """
        self.process_group = process_group
        self.rank = process_group.rank()
        self.world_size = process_group.size()
        # Fix an order over the handles, which should be the same across ranks
        for handle in traversal_utils._get_fsdp_handles(root_module):
            index = len(self.all_handles)
            self.all_handles.append(handle)
            handle._handle_index = index
        self.param_to_fqn = _get_param_to_fqns(root_module)
        # TODO (awgu): We can broadcast the metadata of rank 0's `all_handles`
        # to check that all ranks have the same handles in the same order.
        # https://github.com/pytorch/pytorch/issues/79620

    @property
    def is_first_iter(self) -> bool:
        return self._iter == 0

    def get_handle_to_backward_prefetch(
        self,
        current_handle: FlatParamHandle,
    ) -> Optional[FlatParamHandle]:
        """
        Returns the handle to backward prefetch given the current handle, or
        ``None`` if there is no valid handle to prefetch.
        """
        current_index = current_handle._post_forward_index
        if current_index is None:
            return None
        target_index = current_index - 1
        target_handle: Optional[FlatParamHandle] = None
        for _ in range(self._backward_prefetch_limit):
            if target_index < 0:
                break
            target_handle = self.handles_post_forward_order[target_index]
            target_index -= 1
        return target_handle
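
    # Illustrative walkthrough (hypothetical handles h0, h1, h2; not
    # executed): with a recorded post-forward order of [h0, h1, h2] and
    # ``_backward_prefetch_limit=2``, calling the method above with
    # ``current_handle=h2`` (post-forward index 2) walks ``target_index``
    # 1 -> 0 and returns ``h0``, i.e. the handle up to
    # ``_backward_prefetch_limit`` positions earlier in the post-forward
    # order, clamped at the beginning of the list.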

    def get_handle_to_forward_prefetch(
        self,
        current_handle: FlatParamHandle,
    ) -> Optional[FlatParamHandle]:
        """
        Returns the handle to forward prefetch given the current handle, or
        ``None`` if there is no valid handle to prefetch.
        """
        current_index = current_handle._pre_forward_order_index
        if current_index is None:
            return None
        target_index = current_index + 1
        target_handle: Optional[FlatParamHandle] = None
        for _ in range(self._forward_prefetch_limit):
            if target_index >= len(self.handles_pre_forward_order):
                break
            target_handle = self.handles_pre_forward_order[target_index]
            target_index += 1
        return target_handle
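
    # Illustrative walkthrough (hypothetical handles h0..h3; not executed):
    # with a first-iteration pre-forward order of [h0, h1, h2, h3] and
    # ``_forward_prefetch_limit=1``, calling the method above with
    # ``current_handle=h1`` returns ``h2``; with a limit of 2 it returns
    # ``h3``, i.e. the handle up to ``_forward_prefetch_limit`` positions
    # later in the pre-forward order, clamped at the end of the list.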

    def record_post_forward(self, handle: Optional[FlatParamHandle]) -> None:
        """
        Records ``handle`` in the post-forward order, where ``handle`` should
        be the handle used in the current module's forward. If ``handle`` is
        ``None``, then it is omitted.

        Unlike :meth:`record_pre_forward`, this records the order *every*
        iteration with the expectation that the recorded order is reset in
        :meth:`next_iter`.
        """
        if not handle:
            return
        # Only record the first usage of a handle
        if handle._post_forward_index:
            self.handles_post_forward_order.append(handle)
            return
        index = len(self.handles_post_forward_order)
        handle._post_forward_index = index
        self.handles_post_forward_order.append(handle)

    def record_pre_forward(
        self, handle: Optional[FlatParamHandle], is_training: bool
    ) -> None:
        """
        Records ``handle`` in the pre-forward order, where ``handle`` should
        be the handle used in the current module's forward. If ``handle`` is
        ``None``, then it is omitted.

        On the first iteration, this checks the execution order across ranks.
        See :meth:`_check_order` for details.
        """
        if not handle:
            return
        self._check_order(handle, is_training)
        # Fix the order after the first iteration and only record the first
        # usage of a handle
        if not self.is_first_iter or handle._pre_forward_order_index is not None:
            return
        index = len(self.handles_pre_forward_order)
        handle._pre_forward_order_index = index
        self.handles_pre_forward_order.append(handle)

    def _check_order(self, handle: FlatParamHandle, is_training: bool) -> None:
        """
        Checks the forward execution order as long as ``is_training`` is
        ``True`` since checking in eval mode is not supported. The check only
        runs when the distributed debug level is DETAIL.

        - On the first iteration, this uses all-gathers to check that all
          ranks are all-gathering the same handles and hence the same
          ``FlatParameter`` s, raising an error if not.
        - On subsequent iterations, this checks that each rank is locally
          consistent with its own forward order from the first iteration,
          issuing a warning if not. The warning is only issued on the first
          deviating iteration; checking stops thereafter.
        """
        # Do not check order in eval mode since the post-backward callback does
        # not run so it cannot be used to mark the end of an iteration
        if not is_training or not self._checking_order:
            return
        if self.is_first_iter:
            msg_prefix = "Forward order differs across ranks:"
            optional_local_indices: Tuple[
                Optional[int], ...
            ] = self._get_handle_indices(handle)
            device = handle.device  # guaranteed to be non-CPU
            num_valid_indices = sum(
                (index is not None) for index in optional_local_indices
            )
            tensor_kwargs: Dict[str, Union[torch.dtype, torch.device]] = {
                "dtype": torch.int32,
                "device": device,
            }
            world_num_valid_indices = torch.zeros(self.world_size, **tensor_kwargs)  # type: ignore[arg-type, call-overload]
            local_num_valid_indices = torch.tensor([num_valid_indices], **tensor_kwargs)  # type: ignore[arg-type, call-overload]
            dist.all_gather_into_tensor(
                world_num_valid_indices,
                local_num_valid_indices,
                group=self.process_group,
            )
            # Copy the entire tensor from device to host once to avoid
            # per-element D2H copies
            world_num_valid_indices = world_num_valid_indices.cpu()
            # Check that all ranks plan to all-gather the same number of
            # parameters
            # TODO (awgu): Since every module has at most one handle in the
            # current implementation, this should never raise the error.
            assert self.world_size is not None  # mypy
            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                # TODO(voz): Don't graph break on this - dynamo hates the
                # n1 != n2 tensor comparison control flow.
                # https://github.com/pytorch/pytorch/issues/107055
                for (r1, n1), (r2, n2) in itertools.combinations(
                    (
                        (rank, world_num_valid_indices[rank])
                        for rank in range(self.world_size)
                    ),
                    2,
                ):
                    if n1 != n2:
                        raise RuntimeError(
                            f"{msg_prefix} rank {r1} is all-gathering {n1} parameters "
                            f"while rank {r2} is all-gathering {n2} parameters"
                        )
            world_indices = torch.zeros(  # type: ignore[call-overload]
                self.world_size * num_valid_indices, **tensor_kwargs
            )
            local_indices = torch.tensor(optional_local_indices, **tensor_kwargs)  # type: ignore[arg-type]
            dist.all_gather_into_tensor(
                world_indices, local_indices, group=self.process_group
            )
            # Copy the entire tensor from device to host once to avoid
            # per-element D2H copies
            world_indices = world_indices.cpu()
            # Check that all ranks plan to all-gather the same index parameters
            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                # TODO(voz): Don't graph break on this - dynamo hates the
                # i1 != i2 tensor comparison control flow.
                # https://github.com/pytorch/pytorch/issues/107055
                for (r1, i1), (r2, i2) in itertools.combinations(
                    (
                        (
                            rank,
                            world_indices[
                                rank
                                * num_valid_indices : (rank + 1)
                                * num_valid_indices
                            ],
                        )
                        for rank in range(self.world_size)
                    ),
                    2,
                ):
                    if i1 != i2:
                        r1_param_names = self._get_names_from_handle_indices(i1)
                        r2_param_names = self._get_names_from_handle_indices(i2)
                        raise RuntimeError(
                            f"{msg_prefix} rank {r1} is all-gathering parameters "
                            f"for {r1_param_names} while rank {r2} is all-gathering "
                            f"parameters for {r2_param_names}"
                        )
        else:
            # Only issue warnings on the first deviating iteration and stop
            # checking thereafter to avoid flooding the console
            if self.warn_status == _ExecOrderWarnStatus.WARNED:
                return
            msg_prefix = None  # non-`None` means we should warn
            if self.current_order_index >= len(self.handles_pre_forward_order):
                # This iteration sees extra all-gather(s) compared to the first
                msg_prefix = (
                    "Expected to not all-gather any more parameters in the "
                    "forward but trying to all-gather parameters for "
                )
            else:
                expected_handle = self.handles_pre_forward_order[
                    self.current_order_index
                ]
                if expected_handle != handle:
                    expected_param_names = self._get_names_from_handles(expected_handle)
                    msg_prefix = (
                        f"Expected to all-gather for {expected_param_names} "
                        "but trying to all-gather parameters for "
                    )
            if msg_prefix is not None:
                param_names = self._get_names_from_handles(handle)
                msg_suffix = (
                    f"{param_names}"
                    if param_names
                    else "a newly-added parameter since construction time"
                )
                warnings.warn(
                    "Forward order differs from that of the first iteration "
                    f"on rank {self.rank}. Collectives are unchecked and may "
                    f"give incorrect results or hang.\n{msg_prefix}{msg_suffix}"
                )
                self.warn_status = _ExecOrderWarnStatus.WARNING
            self.current_order_index += 1
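
    # Illustrative failure cases for ``_check_order`` (hypothetical module
    # names; only active when the distributed debug level is DETAIL): on the
    # first iteration, if rank 0 all-gathers the handle for ``layer1`` while
    # rank 1 all-gathers the handle for ``layer2`` at the same step, the
    # pairwise index comparison above raises a RuntimeError on every rank. On
    # a later iteration, if a rank locally issues an extra or reordered
    # all-gather relative to its own first-iteration order, it only warns (on
    # the first deviating iteration) and continues running.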

    def _get_handle_indices(
        self,
        handle: FlatParamHandle,
    ) -> Tuple[Optional[int], ...]:
        """
        Returns the handle indices (i.e. indices into ``self.all_handles``)
        corresponding to ``handle``. An entry in the returned tuple is
        ``None`` if the handle is invalid.
        """
        indices: List[Optional[int]] = []
        if handle:
            indices.append(handle._handle_index)
        return tuple(indices)

    def _get_names_from_handle_indices(
        self,
        handle_indices: Tuple[int, ...],
    ) -> List[List[str]]:
        """
        Returns a list of FQNs for each handle index in ``handle_indices``.
        If a handle index is invalid, then its FQNs are omitted from the
        returned list.
        """
        fqns: List[List[str]] = []
        for index in handle_indices:
            if index is None or index < 0 or index >= len(self.all_handles):
                continue
            handle = self.all_handles[index]
            flat_param = handle.flat_param
            fqns.append(self.param_to_fqn[flat_param])
        return fqns

    def _get_names_from_handles(
        self,
        handle: FlatParamHandle,
    ) -> List[List[str]]:
        """
        Returns a list of FQNs for ``handle``. If the handle is invalid, then
        its FQNs are omitted from the returned list.
        """
        fqns: List[List[str]] = []
        if handle:
            flat_param = handle.flat_param
            if flat_param in self.param_to_fqn:
                fqns.append(self.param_to_fqn[flat_param])
        return fqns
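
    # Illustrative example (hypothetical model; assumes a single FSDP unit
    # whose ``FlatParameter`` flattens ``layer1.weight`` and ``layer1.bias``):
    # both FQN helpers above would return [["layer1.weight", "layer1.bias"]],
    # since a ``FlatParameter``'s entry in ``param_to_fqn`` lists the original
    # parameters' names prefixed from the root module.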

    def next_iter(self):
        """
        Advances the internal data structures per iteration. This should be
        called in the post-backward callback since that marks the true end of
        an iteration.
        """
        self._iter += 1
        self.handles_post_forward_order.clear()
        if self._checking_order:
            self.current_order_index = 0
            if self.warn_status == _ExecOrderWarnStatus.WARNING:
                self.warn_status = _ExecOrderWarnStatus.WARNED
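
# Rough per-iteration call sequence (a sketch inferred from the docstrings
# above; the actual driver lives in the FSDP runtime):
#   1. record_pre_forward(handle, is_training) in each module's pre-forward,
#      which also runs the execution order check under the DETAIL debug level;
#   2. get_handle_to_forward_prefetch(handle) to pick the next all-gather to
#      issue early when forward prefetching is enabled;
#   3. record_post_forward(handle) in each module's post-forward;
#   4. get_handle_to_backward_prefetch(handle) during backward to pick the
#      all-gather to issue early;
#   5. next_iter() in the post-backward callback to reset per-iteration state.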