# mypy: allow-untyped-defs
import dataclasses
import traceback
from typing import (
    Any,
    Callable,
    Container,
    Dict,
    List,
    Optional,
    OrderedDict,
    overload,
    Set,
    Tuple,
    TypeVar,
)

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel._functions import _get_stream
from torch.nn.parallel.scatter_gather import _is_namedtuple
from torch.nn.utils.rnn import PackedSequence


__all__ = []  # type: ignore[var-annotated]


def _pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
    """
    Turn an argument list into separate value and key lists (``_unpack_kwargs`` does the opposite).

    Inspiration: https://github.com/facebookresearch/fairscale/blob/eeb6684/fairscale/internal/containers.py#L70
    Usage::

        flat_args, kwarg_keys = _pack_kwargs(1, 2, a=3, b=4)
        assert flat_args == (1, 2, 3, 4)
        assert kwarg_keys == ("a", "b")
        args, kwargs = _unpack_kwargs(flat_args, kwarg_keys)
        assert args == (1, 2)
        assert kwargs == {"a": 3, "b": 4}

    Returns:
        Tuple[Tuple[Any, ...], Tuple[str, ...]]: The first tuple element gives
        both the positional args and the kwarg values, where the positional
        args precede the kwarg values and the kwarg values are ordered
        consistently with the kwarg keys. The second tuple element gives the
        kwarg keys; its length is at most the first tuple element's length.
    """
    kwarg_keys: List[str] = []
    flat_args: List[Any] = list(args)
    for k, v in kwargs.items():
        kwarg_keys.append(k)
        flat_args.append(v)

    return tuple(flat_args), tuple(kwarg_keys)


def _cast_forward_inputs(
    dtype: Optional[torch.dtype],
    *args: Any,
    **kwargs: Any,
) -> Tuple[Any, Any]:
    """
    Cast floating point tensors in ``args`` and ``kwargs`` to ``dtype``.

    This respects the existing ``requires_grad`` on the tensors.
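
    A minimal illustration (the tensors and the ``scale`` kwarg here are made up)::

        args, kwargs = _cast_forward_inputs(
            torch.float16, torch.randn(4), scale=torch.ones(4, dtype=torch.float64)
        )
        # Both floating point tensors are now torch.float16; non-floating-point
        # inputs are returned unchanged, as is everything when dtype is None.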
    """
    if dtype is None:
        return args, kwargs

    def cast_fn(x: torch.Tensor) -> torch.Tensor:
        if not torch.is_floating_point(x) or x.dtype == dtype:
            return x
        return x.to(dtype)

    return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))


def _unpack_kwargs(
    flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """See _pack_kwargs."""
    assert len(kwarg_keys) <= len(
        flat_args
    ), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}"
    if len(kwarg_keys) == 0:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :]))
    return args, kwargs


S = TypeVar("S", dict, list, tuple)
T = TypeVar("T", torch.Tensor, PackedSequence)


@overload
def _recursive_to(
    inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool
) -> List[S]:
    ...


@overload
def _recursive_to(
    inputs: T, target_device: torch.device, use_side_stream_for_tensor_copies: bool
) -> Tuple[T]:
    ...


def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies):
    r"""Recursively move ``inputs`` to ``target_device``."""

    def to_map(obj):
        if isinstance(obj, (torch.Tensor, PackedSequence)):
            device = obj.data.device if isinstance(obj, PackedSequence) else obj.device
            if device == target_device:
                return (obj,)
            if not use_side_stream_for_tensor_copies:
                return (obj.to(target_device),)
            else:
                # If the device module is not registered with torch, a side
                # stream cannot be used to accelerate the copy.
                device_mod = getattr(torch, device.type, None)
                if device.type == "cpu" or device_mod is None:
                    return (obj.to(target_device),)
                # Perform CPU -> target_device copies in a background stream.
                # This code is motivated by similar logic in
                # torch/nn/parallel/_functions.py
                stream = _get_stream(target_device)
                with device_mod.stream(stream):
                    output = obj.to(target_device)
                # synchronize with the copy stream
                with device_mod.device(target_device.index):
                    current_stream = device_mod.current_stream()
                    # Sync the current stream with the copy stream
                    current_stream.wait_stream(stream)
                    # Ensure tensor memory is not reused until work on
                    # main stream is complete
                    if isinstance(obj, PackedSequence):
                        output.data.record_stream(current_stream)  # type: ignore[arg-type]
                    else:
                        assert isinstance(output, torch.Tensor)
                        output.record_stream(current_stream)  # type: ignore[arg-type]
                return (output,)
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(to_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(to_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(to_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
        return [obj]

    # Avoid reference cycle
    try:
        res = to_map(inputs)
    finally:
        to_map = None  # type: ignore[assignment]
    return res

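
# Illustrative sketch, not part of the original module: ``_recursive_to``
# preserves the container structure of ``inputs`` and wraps the result in a
# one-element sequence. Assuming a CUDA device is available:
#
#     inputs = (torch.randn(2), {"mask": torch.zeros(2, dtype=torch.bool)})
#     (moved,) = _recursive_to(inputs, torch.device("cuda:0"), False)
#     # ``moved`` mirrors the nesting of ``inputs`` with every tensor on cuda:0.
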

def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
    """
    Alternative to ``assert`` for use in the backward context, where assertion
    errors would otherwise be swallowed: print the error message ``s`` with a
    stack trace and, by default, raise an ``AssertionError``.
    """
    if not cond:
        print(s)
        traceback.print_stack()
        if raise_assertion_error:
            raise AssertionError(s)


def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None:
    """
    Allocate storage for ``tensor`` with the given ``size``.

    This is a no-op if the storage is already allocated to ``size`` (and when
    compiling with TorchDynamo); otherwise the storage must currently be empty.
    """
    with torch.no_grad():
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            already_allocated = tensor._typed_storage()._size() == size.numel()
            if not already_allocated:
                tensor_storage_size = tensor._typed_storage()._size()
                _p_assert(
                    tensor_storage_size == 0,
                    f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
                )
                tensor._typed_storage()._resize_(size.numel())


def _free_storage(tensor: torch.Tensor):
    """
    Free the underlying storage of ``tensor``.

    This is a no-op if the storage is already freed (and when compiling with
    TorchDynamo). Freeing is only safe when ``tensor`` is the sole occupant of
    its storage, which is checked via its storage offset.
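
    A sketch of the intended pairing with ``_alloc_storage`` (FSDP-style flat
    parameter management; ``flat_param`` and ``original_size`` are illustrative
    names, not part of this module)::

        _free_storage(flat_param)                  # release memory after use
        _alloc_storage(flat_param, original_size)  # re-allocate before reuse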
    """
    with torch.no_grad():
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            already_freed = tensor._typed_storage()._size() == 0
            if not already_freed:
                _p_assert(
                    tensor.storage_offset() == 0,
                    "Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
                    f"storage offset: {tensor.storage_offset()}\n"
                    f"storage size: {tensor._typed_storage()._size()}\n"
                    f"tensor shape: {tensor.shape}",
                )
                tensor._typed_storage()._resize_(0)


Q = TypeVar("Q")
R = TypeVar("R", dict, list, tuple, set, OrderedDict, PackedSequence, Any)


@overload
def _apply_to_tensors(fn: Callable[[torch.Tensor], Q], container: torch.Tensor) -> Q:
    ...


@overload
def _apply_to_tensors(fn: Callable[[torch.Tensor], Any], container: R) -> R:
    ...


def _apply_to_tensors(fn, container):
    """Recursively apply ``fn`` to all tensors in different kinds of container types."""

    def apply(x):
        if isinstance(x, torch.Tensor):
            return fn(x)
        elif hasattr(x, "__dataclass_fields__"):
            dc = dataclasses.replace(x)
            changes = {
                f.name: apply(getattr(dc, f.name)) for f in dataclasses.fields(dc)
            }
            return dataclasses.replace(dc, **changes)
        elif isinstance(x, OrderedDict):
            od = x.__class__()
            for key, value in x.items():
                od[key] = apply(value)
            return od
        elif isinstance(x, PackedSequence):
            # PackedSequence is treated as a leaf: ``fn`` is applied to its
            # ``.data`` for its side effects, but the sequence itself is
            # returned unchanged.
            apply(x.data)
            return x
        elif isinstance(x, dict):
            return {key: apply(value) for key, value in x.items()}
        elif _is_namedtuple(x):
            res = (apply(el) for el in x)
            return type(x)(*res)
        elif isinstance(x, (list, tuple, set)):
            return type(x)(apply(el) for el in x)
        else:
            return x

    return apply(container)

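
# Illustrative sketch, not part of the original module: ``_apply_to_tensors``
# rebuilds nested containers while transforming only the tensor leaves, e.g.
# detaching every tensor in a mixed structure:
#
#     nested = {"a": [torch.ones(2, requires_grad=True)], "b": (3, "str")}
#     detached = _apply_to_tensors(lambda t: t.detach(), nested)
#     # ``detached`` mirrors ``nested``; only the tensors were replaced.
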

def _to_kwargs(
    inputs: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    target_device: torch.device,
    use_side_stream_for_tensor_copies: bool,
) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]:
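    """
    Move ``inputs`` and ``kwargs`` to ``target_device`` via ``_recursive_to``.

    The shorter of the two results is padded with empty tuples/dicts so that the
    returned ``inputs`` and ``kwargs`` tuples have equal length.
    """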
    moved_inputs = (
        _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies)
        if inputs
        else []
    )
    moved_kwargs = (
        _recursive_to(kwargs, target_device, use_side_stream_for_tensor_copies)
        if kwargs
        else []
    )
    if len(moved_inputs) < len(moved_kwargs):
        moved_inputs.extend([() for _ in range(len(moved_kwargs) - len(moved_inputs))])
    elif len(moved_kwargs) < len(moved_inputs):
        moved_kwargs.extend([{} for _ in range(len(moved_inputs) - len(moved_kwargs))])
    return tuple(moved_inputs), tuple(moved_kwargs)


def _verify_param_shape_across_processes(
    process_group: dist.ProcessGroup,
    tensors: List[torch.Tensor],
    logger: Optional["dist.Logger"] = None,
):
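    """
    Check that ``tensors`` have consistent shapes across all ranks in
    ``process_group``; thin wrapper over ``dist._verify_params_across_processes``.
    """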
    return dist._verify_params_across_processes(process_group, tensors, logger)


def _sync_module_states(
    module: nn.Module,
    process_group: dist.ProcessGroup,
    broadcast_bucket_size: int,
    src: int,
    params_and_buffers_to_ignore: Container[str],
    broadcast_buffers: bool = True,
) -> None:
    """
    Sync ``module``'s parameter and buffer state across ranks.

    Broadcasts ``module``'s parameters and buffers from rank ``src`` so that
    every rank ends up with the same module state. Note that this API assumes
    that all parameter shapes are consistent before running the
    synchronization. This can be checked with
    ``_verify_param_shape_across_processes``.
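
    A minimal sketch, assuming ``model`` is an ``nn.Module``, ``pg`` is an
    already-initialized process group, and a commonly used 250 MB broadcast
    bucket size::

        _sync_module_states(
            model,
            pg,
            broadcast_bucket_size=int(250 * 1024 * 1024),
            src=0,
            params_and_buffers_to_ignore=set(),
        )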
    """
    module_states: List[torch.Tensor] = []
    for name, param in module.named_parameters():
        if name not in params_and_buffers_to_ignore:
            module_states.append(param.detach())

    if broadcast_buffers:
        for name, buffer in module.named_buffers():
            if name not in params_and_buffers_to_ignore:
                module_states.append(buffer.detach())

    _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src)


def _sync_params_and_buffers(
    process_group: dist.ProcessGroup,
    module_states: List[torch.Tensor],
    broadcast_bucket_size: int,
    src: int,
) -> None:
    """Synchronize ``module_states`` (list of tensors) across all processes by broadcasting them from rank ``src``."""
    if len(module_states) > 0:
        dist._broadcast_coalesced(
            process_group, module_states, broadcast_bucket_size, src
        )


def _replace_by_prefix(
    state_dict: Dict[str, Any],
    old_prefix: str,
    new_prefix: str,
) -> None:
    """
    Replace all keys that match a given old_prefix with a new_prefix (in-place).

    Usage::

        state_dict = {"layer.xyz": torch.tensor(1)}
        _replace_by_prefix(state_dict, "layer.", "module.layer.")
        assert state_dict == {"module.layer.xyz": torch.tensor(1)}
    """
    if old_prefix == new_prefix:
        raise ValueError("old_prefix and new_prefix must be distinct")
    for key in list(state_dict.keys()):
        if not key.startswith(old_prefix):
            continue
        new_key = new_prefix + key[len(old_prefix) :]
        state_dict[new_key] = state_dict[key]
        del state_dict[key]


def _data_ptr_allocated(tensor: torch.Tensor) -> bool:
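    """Return ``True`` if ``tensor``'s underlying storage has a nonzero data pointer."""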
    return tensor.untyped_storage().data_ptr() > 0


def _get_root_modules(modules: List[nn.Module]) -> List[nn.Module]:
    """
    Return the modules in ``modules`` that are root modules (i.e.
    parent-less) with respect to the set ``modules``. In other words, these
    are the modules in ``modules`` that are not a child of any other
    module in ``modules``.
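
    For example (illustrative)::

        outer = nn.Sequential(nn.Linear(2, 2))
        inner = outer[0]
        assert _get_root_modules([outer, inner]) == [outer]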
    """
    root_modules: List[nn.Module] = []
    module_to_modules: Dict[nn.Module, Set[nn.Module]] = {
        module: set(module.modules()) for module in modules
    }
    for candidate_module in modules:
        is_root_module = True
        for module, _modules in module_to_modules.items():
            is_child_module = (
                candidate_module is not module and candidate_module in _modules
            )
            if is_child_module:
                is_root_module = False
                break
        if is_root_module:
            root_modules.append(candidate_module)
    return root_modules