# mypy: allow-untyped-defs
import dataclasses
import traceback
from typing import (
    Any,
    Callable,
    Container,
    Dict,
    List,
    Optional,
    OrderedDict,
    overload,
    Set,
    Tuple,
    TypeVar,
)

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel._functions import _get_stream
from torch.nn.parallel.scatter_gather import _is_namedtuple
from torch.nn.utils.rnn import PackedSequence


__all__ = []  # type: ignore[var-annotated]


def _pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
    """
    Turn an argument list into a flat value tuple plus a kwarg-key tuple (``_unpack_kwargs`` does the opposite).

    Inspiration: https://github.com/facebookresearch/fairscale/blob/eeb6684/fairscale/internal/containers.py#L70
    Usage::

        flat_args, kwarg_keys = _pack_kwargs(1, 2, a=3, b=4)
        assert flat_args == (1, 2, 3, 4)
        assert kwarg_keys == ("a", "b")
        args, kwargs = _unpack_kwargs(flat_args, kwarg_keys)
        assert args == (1, 2)
        assert kwargs == {"a": 3, "b": 4}

    Returns:
        Tuple[Tuple[Any, ...], Tuple[str, ...]]: The first tuple element gives
        both positional args and kwarg values, where the positional args
        precede the kwarg values and the kwarg values are ordered consistently
        with the kwarg keys. The second tuple element gives the kwarg keys.
        The second tuple element's length is at most the first tuple element's length.
    """
    kwarg_keys: List[str] = []
    flat_args: List[Any] = list(args)
    # dict preserves insertion order, so keys and values stay aligned.
    for k, v in kwargs.items():
        kwarg_keys.append(k)
        flat_args.append(v)

    return tuple(flat_args), tuple(kwarg_keys)


def _cast_forward_inputs(
    dtype: Optional[torch.dtype],
    *args: Any,
    **kwargs: Any,
) -> Tuple[Any, Any]:
    """
    Cast floating point tensors in ``args`` and ``kwargs`` to ``dtype``.

    Non-floating-point tensors and tensors already of ``dtype`` are returned
    unchanged. This respects the existing ``requires_grad`` on the tensors.

    Returns:
        Tuple[Any, Any]: The (possibly cast) ``args`` and ``kwargs``. If
        ``dtype`` is ``None``, both are returned as-is.
    """
    if dtype is None:
        return args, kwargs

    def cast_fn(x: torch.Tensor) -> torch.Tensor:
        if not torch.is_floating_point(x) or x.dtype == dtype:
            return x
        return x.to(dtype)

    return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))


def _unpack_kwargs(
    flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """See _pack_kwargs."""
    assert len(kwarg_keys) <= len(
        flat_args
    ), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}"
    if len(kwarg_keys) == 0:
        return flat_args, {}
    # The kwarg values are the trailing entries of ``flat_args``.
    args = flat_args[: -len(kwarg_keys)]
    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :]))
    return args, kwargs


S = TypeVar("S", dict, list, tuple)
T = TypeVar("T", torch.Tensor, PackedSequence)


@overload
def _recursive_to(
    inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool
) -> List[S]:
    ...


@overload
def _recursive_to(
    inputs: T, target_device: torch.device, use_side_stream_for_tensor_copies: bool
) -> Tuple[T]:
    ...
109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker 111*da0073e9SAndroid Build Coastguard Workerdef _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies): 112*da0073e9SAndroid Build Coastguard Worker r"""Recursively moves input to the target_device.""" 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard Worker def to_map(obj): 115*da0073e9SAndroid Build Coastguard Worker if isinstance(obj, (torch.Tensor, PackedSequence)): 116*da0073e9SAndroid Build Coastguard Worker device = obj.data.device if isinstance(obj, PackedSequence) else obj.device 117*da0073e9SAndroid Build Coastguard Worker if device == target_device: 118*da0073e9SAndroid Build Coastguard Worker return (obj,) 119*da0073e9SAndroid Build Coastguard Worker if not use_side_stream_for_tensor_copies: 120*da0073e9SAndroid Build Coastguard Worker return (obj.to(target_device),) 121*da0073e9SAndroid Build Coastguard Worker else: 122*da0073e9SAndroid Build Coastguard Worker # If the custom module is not registered to torch, stream is not used for acceleration 123*da0073e9SAndroid Build Coastguard Worker device_mod = getattr(torch, device.type, None) 124*da0073e9SAndroid Build Coastguard Worker if device.type == "cpu" or device_mod is None: 125*da0073e9SAndroid Build Coastguard Worker return (obj.to(target_device),) 126*da0073e9SAndroid Build Coastguard Worker # Perform CPU -> target_device copies in a background stream. 
This code is 127*da0073e9SAndroid Build Coastguard Worker # motivated from similar logic in torch/nn/parallel/_functions.py 128*da0073e9SAndroid Build Coastguard Worker stream = _get_stream(target_device) 129*da0073e9SAndroid Build Coastguard Worker with device_mod.stream(stream): 130*da0073e9SAndroid Build Coastguard Worker output = obj.to(target_device) 131*da0073e9SAndroid Build Coastguard Worker # synchronize with the copy stream 132*da0073e9SAndroid Build Coastguard Worker with device_mod.device(target_device.index): 133*da0073e9SAndroid Build Coastguard Worker current_stream = device_mod.current_stream() 134*da0073e9SAndroid Build Coastguard Worker # Sync the current stream with the copy stream 135*da0073e9SAndroid Build Coastguard Worker current_stream.wait_stream(stream) 136*da0073e9SAndroid Build Coastguard Worker # Ensure tensor memory is not reused until work on 137*da0073e9SAndroid Build Coastguard Worker # main stream is complete 138*da0073e9SAndroid Build Coastguard Worker if isinstance(obj, PackedSequence): 139*da0073e9SAndroid Build Coastguard Worker output.data.record_stream(current_stream) # type: ignore[arg-type] 140*da0073e9SAndroid Build Coastguard Worker else: 141*da0073e9SAndroid Build Coastguard Worker assert isinstance(output, torch.Tensor) 142*da0073e9SAndroid Build Coastguard Worker output.record_stream(current_stream) # type: ignore[arg-type] 143*da0073e9SAndroid Build Coastguard Worker return (output,) 144*da0073e9SAndroid Build Coastguard Worker if _is_namedtuple(obj): 145*da0073e9SAndroid Build Coastguard Worker return [type(obj)(*args) for args in zip(*map(to_map, obj))] 146*da0073e9SAndroid Build Coastguard Worker if isinstance(obj, tuple) and len(obj) > 0: 147*da0073e9SAndroid Build Coastguard Worker return list(zip(*map(to_map, obj))) 148*da0073e9SAndroid Build Coastguard Worker if isinstance(obj, list) and len(obj) > 0: 149*da0073e9SAndroid Build Coastguard Worker return [list(i) for i in zip(*map(to_map, obj))] 
150*da0073e9SAndroid Build Coastguard Worker if isinstance(obj, dict) and len(obj) > 0: 151*da0073e9SAndroid Build Coastguard Worker return [type(obj)(i) for i in zip(*map(to_map, obj.items()))] 152*da0073e9SAndroid Build Coastguard Worker return [obj] 153*da0073e9SAndroid Build Coastguard Worker 154*da0073e9SAndroid Build Coastguard Worker # Avoid reference cycle 155*da0073e9SAndroid Build Coastguard Worker try: 156*da0073e9SAndroid Build Coastguard Worker res = to_map(inputs) 157*da0073e9SAndroid Build Coastguard Worker finally: 158*da0073e9SAndroid Build Coastguard Worker to_map = None # type: ignore[assignment] 159*da0073e9SAndroid Build Coastguard Worker return res 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker 162*da0073e9SAndroid Build Coastguard Workerdef _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None: 163*da0073e9SAndroid Build Coastguard Worker """Alternate to ``assert`` when in the backward context to print the error message ``s`` since otherwise, it is swallowed.""" 164*da0073e9SAndroid Build Coastguard Worker if not cond: 165*da0073e9SAndroid Build Coastguard Worker print(s) 166*da0073e9SAndroid Build Coastguard Worker traceback.print_stack() 167*da0073e9SAndroid Build Coastguard Worker if raise_assertion_error: 168*da0073e9SAndroid Build Coastguard Worker raise AssertionError(s) 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Workerdef _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None: 172*da0073e9SAndroid Build Coastguard Worker """ 173*da0073e9SAndroid Build Coastguard Worker Allocate storage for ``tensor`` with the given size. 
174*da0073e9SAndroid Build Coastguard Worker 175*da0073e9SAndroid Build Coastguard Worker Returns: 176*da0073e9SAndroid Build Coastguard Worker bool: ``True`` if this method allocated storage and ``False`` if the 177*da0073e9SAndroid Build Coastguard Worker storage was already allocated. 178*da0073e9SAndroid Build Coastguard Worker """ 179*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 180*da0073e9SAndroid Build Coastguard Worker if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): 181*da0073e9SAndroid Build Coastguard Worker already_allocated = tensor._typed_storage()._size() == size.numel() 182*da0073e9SAndroid Build Coastguard Worker if not already_allocated: 183*da0073e9SAndroid Build Coastguard Worker tensor_storage_size = tensor._typed_storage()._size() 184*da0073e9SAndroid Build Coastguard Worker _p_assert( 185*da0073e9SAndroid Build Coastguard Worker tensor_storage_size == 0, 186*da0073e9SAndroid Build Coastguard Worker "Tensor storage should have been resized to be 0 but got PLACEHOLDEr", 187*da0073e9SAndroid Build Coastguard Worker ) 188*da0073e9SAndroid Build Coastguard Worker tensor._typed_storage()._resize_(size.numel()) 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker 191*da0073e9SAndroid Build Coastguard Workerdef _free_storage(tensor: torch.Tensor): 192*da0073e9SAndroid Build Coastguard Worker """ 193*da0073e9SAndroid Build Coastguard Worker Frees the underlying storage of ``tensor``. 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker Returns: 196*da0073e9SAndroid Build Coastguard Worker bool: ``True`` if the method freed the storage and ``False`` if the 197*da0073e9SAndroid Build Coastguard Worker storage was already freed. 
198*da0073e9SAndroid Build Coastguard Worker """ 199*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 200*da0073e9SAndroid Build Coastguard Worker if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): 201*da0073e9SAndroid Build Coastguard Worker already_freed = tensor._typed_storage()._size() == 0 202*da0073e9SAndroid Build Coastguard Worker if not already_freed: 203*da0073e9SAndroid Build Coastguard Worker _p_assert( 204*da0073e9SAndroid Build Coastguard Worker tensor.storage_offset() == 0, 205*da0073e9SAndroid Build Coastguard Worker "Freeing a tensor's storage is unsafe when it is not the sole occupant\n" 206*da0073e9SAndroid Build Coastguard Worker f"storage offset: {tensor.storage_offset()}\n" 207*da0073e9SAndroid Build Coastguard Worker f"storage size: {tensor._typed_storage()._size()}\n" 208*da0073e9SAndroid Build Coastguard Worker f"tensor shape: {tensor.shape}", 209*da0073e9SAndroid Build Coastguard Worker ) 210*da0073e9SAndroid Build Coastguard Worker tensor._typed_storage()._resize_(0) 211*da0073e9SAndroid Build Coastguard Worker 212*da0073e9SAndroid Build Coastguard Worker 213*da0073e9SAndroid Build Coastguard WorkerQ = TypeVar("Q") 214*da0073e9SAndroid Build Coastguard WorkerR = TypeVar("R", dict, list, tuple, set, OrderedDict, PackedSequence, Any) 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker 217*da0073e9SAndroid Build Coastguard Worker@overload 218*da0073e9SAndroid Build Coastguard Workerdef _apply_to_tensors(fn: Callable[[torch.Tensor], Q], container: torch.Tensor) -> Q: 219*da0073e9SAndroid Build Coastguard Worker ... 220*da0073e9SAndroid Build Coastguard Worker 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Worker@overload 223*da0073e9SAndroid Build Coastguard Workerdef _apply_to_tensors(fn: Callable[[torch.Tensor], Any], container: R) -> R: 224*da0073e9SAndroid Build Coastguard Worker ... 
def _apply_to_tensors(fn, container):
    """Recursively apply ``fn`` to all tensors in different kinds of container types.

    Supported containers: dataclasses, ``OrderedDict``, ``PackedSequence``,
    ``dict``, namedtuples, ``list``, ``tuple``, and ``set``. Containers are
    rebuilt with the same type; non-tensor leaves are returned unchanged.
    """

    def apply(x):
        if isinstance(x, torch.Tensor):
            return fn(x)
        elif hasattr(x, "__dataclass_fields__"):
            # Rebuild the dataclass with fn applied to every field value.
            changes = {
                f.name: apply(getattr(x, f.name)) for f in dataclasses.fields(x)
            }
            return dataclasses.replace(x, **changes)
        elif isinstance(x, OrderedDict):
            od = x.__class__()
            for key, value in x.items():
                od[key] = apply(value)
            return od
        elif isinstance(x, PackedSequence):
            # NOTE(review): fn's return value is discarded here; the original
            # PackedSequence is returned unchanged, so only in-place fns have
            # an effect on this branch.
            apply(x.data)
            return x
        elif isinstance(x, dict):
            return {key: apply(value) for key, value in x.items()}
        elif _is_namedtuple(x):
            res = (apply(el) for el in x)
            return type(x)(*res)
        elif isinstance(x, (list, tuple, set)):
            return type(x)(apply(el) for el in x)
        else:
            return x

    return apply(container)


def _to_kwargs(
    inputs: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    target_device: torch.device,
    use_side_stream_for_tensor_copies: bool,
) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]:
    """Move ``inputs`` and ``kwargs`` to ``target_device`` and return them as
    equal-length tuples, padding the shorter side with empty ``()``/``{}``."""
    moved_inputs = (
        _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies)
        if inputs
        else []
    )
    moved_kwargs = (
        _recursive_to(kwargs, target_device, use_side_stream_for_tensor_copies)
        if kwargs
        else []
    )
    if len(moved_inputs) < len(moved_kwargs):
        # Pad based on the moved lists' lengths (was ``len(inputs)``, which is
        # inconsistent with the symmetric branch below).
        moved_inputs.extend([() for _ in range(len(moved_kwargs) - len(moved_inputs))])
    elif len(moved_kwargs) < len(moved_inputs):
        moved_kwargs.extend([{} for _ in range(len(moved_inputs) - len(moved_kwargs))])
    return tuple(moved_inputs), tuple(moved_kwargs)


def _verify_param_shape_across_processes(
    process_group: dist.ProcessGroup,
    tensors: List[torch.Tensor],
    logger: Optional["dist.Logger"] = None,
):
    """Verify that ``tensors`` have consistent shapes across all processes in ``process_group``."""
    return dist._verify_params_across_processes(process_group, tensors, logger)


def _sync_module_states(
    module: nn.Module,
    process_group: dist.ProcessGroup,
    broadcast_bucket_size: int,
    src: int,
    params_and_buffers_to_ignore: Container[str],
    broadcast_buffers: bool = True,
) -> None:
    """
    Sync ``module``'s parameters and buffers state.

    Syncs ``module``'s parameters and buffers state so that all ranks contain
    the same module state across all ranks. Note that this API assumes that all
    parameter shapes are consistent before running the synchronization. This can
    be checked with ``_verify_param_shape_across_processes``.
    """
    module_states: List[torch.Tensor] = []
    for name, param in module.named_parameters():
        if name not in params_and_buffers_to_ignore:
            module_states.append(param.detach())

    if broadcast_buffers:
        for name, buffer in module.named_buffers():
            if name not in params_and_buffers_to_ignore:
                module_states.append(buffer.detach())

    _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src)


def _sync_params_and_buffers(
    process_group: dist.ProcessGroup,
    module_states: List[torch.Tensor],
    broadcast_bucket_size: int,
    src: int,
) -> None:
    """Synchronize ``module_states`` (list of tensors) across all processes by broadcasting them from rank 0."""
    if len(module_states) > 0:
        dist._broadcast_coalesced(
            process_group, module_states, broadcast_bucket_size, src
        )
def _replace_by_prefix(
    state_dict: Dict[str, Any],
    old_prefix: str,
    new_prefix: str,
) -> None:
    """
    Replace all keys that match a given old_prefix with a new_prefix (in-place).

    Usage::

        state_dict = {"layer.xyz": torch.tensor(1)}
        _replace_by_prefix(state_dict, "layer.", "module.layer.")
        assert state_dict == {"module.layer.xyz": torch.tensor(1)}

    Raises:
        ValueError: If ``old_prefix`` equals ``new_prefix`` (the in-place
            rewrite would otherwise delete the entry it just wrote).
    """
    if old_prefix == new_prefix:
        raise ValueError("old_prefix and new_prefix must be distinct")
    # Snapshot the keys since the dict is mutated during iteration.
    for key in list(state_dict.keys()):
        if not key.startswith(old_prefix):
            continue
        new_key = new_prefix + key[len(old_prefix) :]
        state_dict[new_key] = state_dict[key]
        del state_dict[key]


def _data_ptr_allocated(tensor: torch.Tensor) -> bool:
    """Return ``True`` if ``tensor``'s underlying storage has a non-null data pointer."""
    return tensor.untyped_storage().data_ptr() > 0
def _get_root_modules(modules: List[nn.Module]) -> List[nn.Module]:
    """
    Return the modules in ``modules`` that are root modules (i.e.
    parent-less) with respect to the set ``modules``. In other words, these
    are the modules in ``modules`` that are not a child of any other
    module in ``modules``.
    """
    # Precompute each module's descendant set (includes the module itself).
    descendants: Dict[nn.Module, Set[nn.Module]] = {
        m: set(m.modules()) for m in modules
    }
    root_modules: List[nn.Module] = []
    for candidate in modules:
        has_parent = any(
            candidate is not other and candidate in members
            for other, members in descendants.items()
        )
        if not has_parent:
            root_modules.append(candidate)
    return root_modules