# mypy: allow-untyped-defs
import dataclasses
import traceback
from typing import (
    Any,
    Callable,
    Container,
    Dict,
    List,
    Optional,
    OrderedDict,
    overload,
    Set,
    Tuple,
    TypeVar,
)

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel._functions import _get_stream
from torch.nn.parallel.scatter_gather import _is_namedtuple
from torch.nn.utils.rnn import PackedSequence


__all__ = []  # type: ignore[var-annotated]


def _pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
    """
    Turn an argument list into a flat value tuple and a kwarg-key tuple (``_unpack_kwargs`` does the opposite).

    Inspiration: https://github.com/facebookresearch/fairscale/blob/eeb6684/fairscale/internal/containers.py#L70
    Usage::

        flat_args, kwarg_keys = _pack_kwargs(1, 2, a=3, b=4)
        assert flat_args == (1, 2, 3, 4)
        assert kwarg_keys == ("a", "b")
        args, kwargs = _unpack_kwargs(flat_args, kwarg_keys)
        assert args == (1, 2)
        assert kwargs == {"a": 3, "b": 4}

    Returns:
        Tuple[Tuple[Any, ...], Tuple[str, ...]]: The first tuple element gives
        the positional args followed by the kwarg values, where the kwarg
        values are ordered consistently with the kwarg keys. The second tuple
        element gives the kwarg keys. The second tuple element's length is at
        most the first tuple element's length.
    """
    kwarg_keys: List[str] = []
    flat_args: List[Any] = list(args)
    for k, v in kwargs.items():
        kwarg_keys.append(k)
        flat_args.append(v)

    return tuple(flat_args), tuple(kwarg_keys)


def _cast_forward_inputs(
    dtype: Optional[torch.dtype],
    *args: Any,
    **kwargs: Any,
) -> Tuple[Any, Any]:
    """
    Cast floating point tensors in ``args`` and ``kwargs`` to ``dtype``.

    This respects the existing ``requires_grad`` on the tensors.
    """
    if dtype is None:
        return args, kwargs

    def cast_fn(x: torch.Tensor) -> torch.Tensor:
        if not torch.is_floating_point(x) or x.dtype == dtype:
            return x
        return x.to(dtype)

    return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))


def _unpack_kwargs(
    flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """See _pack_kwargs."""
    assert len(kwarg_keys) <= len(
        flat_args
    ), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}"
    if len(kwarg_keys) == 0:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :]))
    return args, kwargs
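

# Illustrative sketch (not exercised by this module): a wrapper's forward pre-processing
# might combine the helpers above roughly as follows. ``fwd``, ``args``, and ``kwargs``
# are hypothetical stand-ins, and float16 is just an example low-precision dtype:
#
#     args, kwargs = _cast_forward_inputs(torch.float16, *args, **kwargs)
#     flat_args, kwarg_keys = _pack_kwargs(*args, **kwargs)
#     ...  # e.g. carry ``flat_args`` across a boundary that only accepts flat args
#     args, kwargs = _unpack_kwargs(flat_args, kwarg_keys)
#     out = fwd(*args, **kwargs)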


S = TypeVar("S", dict, list, tuple)
T = TypeVar("T", torch.Tensor, PackedSequence)


@overload
def _recursive_to(
    inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool
) -> List[S]:
    ...


@overload
def _recursive_to(
    inputs: T, target_device: torch.device, use_side_stream_for_tensor_copies: bool
) -> Tuple[T]:
    ...


def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies):
    r"""Recursively move the inputs to ``target_device``."""

    def to_map(obj):
        if isinstance(obj, (torch.Tensor, PackedSequence)):
            device = obj.data.device if isinstance(obj, PackedSequence) else obj.device
            if device == target_device:
                return (obj,)
            if not use_side_stream_for_tensor_copies:
                return (obj.to(target_device),)
            else:
                # If no device module is registered on torch for this device type
                # (e.g. an unregistered custom backend), do not use a side stream
                # and fall back to a plain blocking copy.
                device_mod = getattr(torch, device.type, None)
                if device.type == "cpu" or device_mod is None:
                    return (obj.to(target_device),)
                # Perform CPU -> target_device copies in a background stream. This code is
                # motivated by similar logic in torch/nn/parallel/_functions.py
                stream = _get_stream(target_device)
                with device_mod.stream(stream):
                    output = obj.to(target_device)
                # Synchronize with the copy stream
                with device_mod.device(target_device.index):
                    current_stream = device_mod.current_stream()
                    # Sync the current stream with the copy stream
                    current_stream.wait_stream(stream)
                    # Ensure tensor memory is not reused until work on
                    # the main stream is complete
                    if isinstance(obj, PackedSequence):
                        output.data.record_stream(current_stream)  # type: ignore[arg-type]
                    else:
                        assert isinstance(output, torch.Tensor)
                        output.record_stream(current_stream)  # type: ignore[arg-type]
                return (output,)
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(to_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(to_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(to_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
        return [obj]

    # Avoid a reference cycle through the closure
    try:
        res = to_map(inputs)
    finally:
        to_map = None  # type: ignore[assignment]
    return res


def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
    """Alternative to ``assert`` for use in the backward context, where it prints the error message ``s`` that would otherwise be swallowed."""
    if not cond:
        print(s)
        traceback.print_stack()
        if raise_assertion_error:
            raise AssertionError(s)


def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None:
    """
    Allocate storage for ``tensor`` with the given size.

    This is a no-op if the storage is already allocated to the requested size.
    """
    with torch.no_grad():
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            already_allocated = tensor._typed_storage()._size() == size.numel()
            if not already_allocated:
                tensor_storage_size = tensor._typed_storage()._size()
                _p_assert(
                    tensor_storage_size == 0,
                    f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
                )
                tensor._typed_storage()._resize_(size.numel())


def _free_storage(tensor: torch.Tensor):
    """
    Free the underlying storage of ``tensor``.

    This is a no-op if the storage is already freed.
    """
    with torch.no_grad():
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            already_freed = tensor._typed_storage()._size() == 0
            if not already_freed:
                _p_assert(
                    tensor.storage_offset() == 0,
                    "Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
                    f"storage offset: {tensor.storage_offset()}\n"
                    f"storage size: {tensor._typed_storage()._size()}\n"
                    f"tensor shape: {tensor.shape}",
                )
                tensor._typed_storage()._resize_(0)
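

# Illustrative sketch (not exercised by this module): a flat-parameter style wrapper
# might release and later re-materialize a tensor's memory with the two helpers above.
# ``flat_param`` and ``orig_size`` are hypothetical names, and the tensor's values are
# not preserved across the round trip:
#
#     orig_size = flat_param.size()
#     _free_storage(flat_param)              # resize the underlying storage to 0 elements
#     ...                                    # run code that should not hold this memory
#     _alloc_storage(flat_param, orig_size)  # resize the storage back to numel() elements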
198 """ 199 with torch.no_grad(): 200 if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): 201 already_freed = tensor._typed_storage()._size() == 0 202 if not already_freed: 203 _p_assert( 204 tensor.storage_offset() == 0, 205 "Freeing a tensor's storage is unsafe when it is not the sole occupant\n" 206 f"storage offset: {tensor.storage_offset()}\n" 207 f"storage size: {tensor._typed_storage()._size()}\n" 208 f"tensor shape: {tensor.shape}", 209 ) 210 tensor._typed_storage()._resize_(0) 211 212 213Q = TypeVar("Q") 214R = TypeVar("R", dict, list, tuple, set, OrderedDict, PackedSequence, Any) 215 216 217@overload 218def _apply_to_tensors(fn: Callable[[torch.Tensor], Q], container: torch.Tensor) -> Q: 219 ... 220 221 222@overload 223def _apply_to_tensors(fn: Callable[[torch.Tensor], Any], container: R) -> R: 224 ... 225 226 227def _apply_to_tensors(fn, container): 228 """Recursively apply to all tensor in different kinds of container types.""" 229 230 def apply(x): 231 if isinstance(x, torch.Tensor): 232 return fn(x) 233 elif hasattr(x, "__dataclass_fields__"): 234 dc = dataclasses.replace(x) 235 changes = { 236 f.name: apply(getattr(dc, f.name)) for f in dataclasses.fields(dc) 237 } 238 return dataclasses.replace(dc, **changes) 239 elif isinstance(x, OrderedDict): 240 od = x.__class__() 241 for key, value in x.items(): 242 od[key] = apply(value) 243 return od 244 elif isinstance(x, PackedSequence): 245 apply(x.data) 246 return x 247 elif isinstance(x, dict): 248 return {key: apply(value) for key, value in x.items()} 249 elif _is_namedtuple(x): 250 res = (apply(el) for el in x) 251 return type(x)(*res) 252 elif isinstance(x, (list, tuple, set)): 253 return type(x)(apply(el) for el in x) 254 else: 255 return x 256 257 return apply(container) 258 259 260def _to_kwargs( 261 inputs: Tuple[Any, ...], 262 kwargs: Optional[Dict[str, Any]], 263 target_device: torch.device, 264 use_side_stream_for_tensor_copies: bool, 265) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]: 266 moved_inputs = ( 267 _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies) 268 if inputs 269 else [] 270 ) 271 moved_kwargs = ( 272 _recursive_to(kwargs, target_device, use_side_stream_for_tensor_copies) 273 if kwargs 274 else [] 275 ) 276 if len(moved_inputs) < len(moved_kwargs): 277 moved_inputs.extend([() for _ in range(len(moved_kwargs) - len(inputs))]) 278 elif len(moved_kwargs) < len(moved_inputs): 279 moved_kwargs.extend([{} for _ in range(len(moved_inputs) - len(moved_kwargs))]) 280 return tuple(moved_inputs), tuple(moved_kwargs) 281 282 283def _verify_param_shape_across_processes( 284 process_group: dist.ProcessGroup, 285 tensors: List[torch.Tensor], 286 logger: Optional["dist.Logger"] = None, 287): 288 return dist._verify_params_across_processes(process_group, tensors, logger) 289 290 291def _sync_module_states( 292 module: nn.Module, 293 process_group: dist.ProcessGroup, 294 broadcast_bucket_size: int, 295 src: int, 296 params_and_buffers_to_ignore: Container[str], 297 broadcast_buffers: bool = True, 298) -> None: 299 """ 300 Sync ``module``'s parameters and buffers state. 301 302 Syncs ``module``'s parameters and buffers state so that all ranks contain 303 the same module state across all ranks. Note that this API assumes that all 304 parameter shapes are consistent before running the synchronization. This can 305 be checked with ``_verify_param_shape_across_processes``. 
306 """ 307 module_states: List[torch.Tensor] = [] 308 for name, param in module.named_parameters(): 309 if name not in params_and_buffers_to_ignore: 310 module_states.append(param.detach()) 311 312 if broadcast_buffers: 313 for name, buffer in module.named_buffers(): 314 if name not in params_and_buffers_to_ignore: 315 module_states.append(buffer.detach()) 316 317 _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src) 318 319 320def _sync_params_and_buffers( 321 process_group: dist.ProcessGroup, 322 module_states: List[torch.Tensor], 323 broadcast_bucket_size: int, 324 src: int, 325) -> None: 326 """Synchronize ``module_states`` (list of tensors) across all processes by broadcasting them from rank 0.""" 327 if len(module_states) > 0: 328 dist._broadcast_coalesced( 329 process_group, module_states, broadcast_bucket_size, src 330 ) 331 332 333def _replace_by_prefix( 334 state_dict: Dict[str, Any], 335 old_prefix: str, 336 new_prefix: str, 337) -> None: 338 """ 339 Replace all keys that match a given old_prefix with a new_prefix (in-place). 340 341 Usage:: 342 343 state_dict = {"layer.xyz": torch.tensor(1)} 344 replace_by_prefix_(state_dict, "layer.", "module.layer.") 345 assert state_dict == {"module.layer.xyz": torch.tensor(1)} 346 """ 347 if old_prefix == new_prefix: 348 raise ValueError("old_prefix and new_prefix must be distinct") 349 for key in list(state_dict.keys()): 350 if not key.startswith(old_prefix): 351 continue 352 new_key = new_prefix + key[len(old_prefix) :] 353 state_dict[new_key] = state_dict[key] 354 del state_dict[key] 355 356 357def _data_ptr_allocated(tensor: torch.Tensor) -> bool: 358 return tensor.untyped_storage().data_ptr() > 0 359 360 361def _get_root_modules(modules: List[nn.Module]) -> List[nn.Module]: 362 """ 363 Returns the modules in ``modules`` that are root modules (i.e. 364 parent-less) with respect to the set ``modules``. In other words, these 365 are the modules in ``modules`` that are the not child of any other 366 module in ``modules``. 367 """ 368 root_modules: List[nn.Module] = [] 369 module_to_modules: Dict[nn.Module, Set[nn.Module]] = { 370 module: set(module.modules()) for module in modules 371 } 372 for candidate_module in modules: 373 is_root_module = True 374 for module, _modules in module_to_modules.items(): 375 is_child_module = ( 376 candidate_module is not module and candidate_module in _modules 377 ) 378 if is_child_module: 379 is_root_module = False 380 break 381 if is_root_module: 382 root_modules.append(candidate_module) 383 return root_modules 384