# mypy: allow-untyped-defs
import dataclasses
import traceback
from typing import (
    Any,
    Callable,
    Container,
    Dict,
    List,
    Optional,
    OrderedDict,
    overload,
    Set,
    Tuple,
    TypeVar,
)

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel._functions import _get_stream
from torch.nn.parallel.scatter_gather import _is_namedtuple
from torch.nn.utils.rnn import PackedSequence


__all__ = []  # type: ignore[var-annotated]


def _pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
    """
    Turn an argument list into separate key list and value list (``_unpack_kwargs`` does the opposite).

    Inspiration: https://github.com/facebookresearch/fairscale/blob/eeb6684/fairscale/internal/containers.py#L70
    Usage::

        kwarg_keys, flat_args = _pack_kwargs(1, 2, a=3, b=4)
        assert kwarg_keys == ("a", "b")
        assert flat_args == (1, 2, 3, 4)
        args, kwargs = _unpack_kwargs(kwarg_keys, flat_args)
        assert args == (1, 2)
        assert kwargs == {"a": 3, "b": 4}

    Returns:
        Tuple[Tuple[Any, ...], Tuple[str, ...]]: The first tuple element gives
        both positional args and kwarg values, where the positional args
        precede the kwarg values and the kwarg values are ordered consistently
        with the kwarg keys. The second tuple element gives the kwarg keys.
        The second tuple element's length is at most the first tuple element's length.
    """
    kwarg_keys: List[str] = []
    flat_args: List[Any] = list(args)
    for k, v in kwargs.items():
        kwarg_keys.append(k)
        flat_args.append(v)

    return tuple(flat_args), tuple(kwarg_keys)


def _cast_forward_inputs(
    dtype: Optional[torch.dtype],
    *args: Any,
    **kwargs: Any,
) -> Tuple[Any, Any]:
    """
    Cast floating point tensors in ``args`` and ``kwargs`` to ``dtype``.

    This respects the existing ``requires_grad`` on the tensors.
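
    Usage (illustrative sketch)::

        args, kwargs = _cast_forward_inputs(
            torch.float16, torch.randn(2), mask=torch.tensor([1])
        )
        # args[0] is now float16; the integer ``mask`` tensor is left unchanged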
    """
    if dtype is None:
        return args, kwargs

    def cast_fn(x: torch.Tensor) -> torch.Tensor:
        if not torch.is_floating_point(x) or x.dtype == dtype:
            return x
        return x.to(dtype)

    return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))


def _unpack_kwargs(
    flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """See _pack_kwargs."""
    assert len(kwarg_keys) <= len(
        flat_args
    ), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}"
    if len(kwarg_keys) == 0:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :]))
    return args, kwargs


S = TypeVar("S", dict, list, tuple)
T = TypeVar("T", torch.Tensor, PackedSequence)


@overload
def _recursive_to(
    inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool
) -> List[S]:
    ...


@overload
def _recursive_to(
    inputs: T, target_device: torch.device, use_side_stream_for_tensor_copies: bool
) -> Tuple[T]:
    ...


def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies):
    r"""
    Recursively move ``inputs`` to ``target_device``.
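
    Usage (illustrative sketch; note that a leaf input comes back wrapped in a
    1-element container, mirroring the scatter-style output shape)::

        device = torch.device("cpu")
        (moved,) = _recursive_to(torch.randn(2), device, False)
        assert moved.device == device
    """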

    def to_map(obj):
        if isinstance(obj, (torch.Tensor, PackedSequence)):
            device = obj.data.device if isinstance(obj, PackedSequence) else obj.device
            if device == target_device:
                return (obj,)
            if not use_side_stream_for_tensor_copies:
                return (obj.to(target_device),)
            else:
                # If the device has no module registered on ``torch`` (e.g. an
                # unregistered custom backend), copy without a side stream.
                device_mod = getattr(torch, device.type, None)
                if device.type == "cpu" or device_mod is None:
                    return (obj.to(target_device),)
                # Perform CPU -> target_device copies in a background stream.
                # This mirrors similar logic in torch/nn/parallel/_functions.py.
                stream = _get_stream(target_device)
                with device_mod.stream(stream):
                    output = obj.to(target_device)
                # synchronize with the copy stream
                with device_mod.device(target_device.index):
                    current_stream = device_mod.current_stream()
                    # Sync the current stream with the copy stream
                    current_stream.wait_stream(stream)
                    # Ensure tensor memory is not reused until work on
                    # main stream is complete
                    if isinstance(obj, PackedSequence):
                        output.data.record_stream(current_stream)  # type: ignore[arg-type]
                    else:
                        assert isinstance(output, torch.Tensor)
                        output.record_stream(current_stream)  # type: ignore[arg-type]
                return (output,)
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(to_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(to_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(to_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
        return [obj]

    # Avoid reference cycle
    try:
        res = to_map(inputs)
    finally:
        to_map = None  # type: ignore[assignment]
    return res


def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
    """
    Alternative to ``assert`` for use in the backward context, where an
    assertion's error message would otherwise be swallowed: print ``s`` and a
    stack trace before (optionally) raising.
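
    Usage::

        x = torch.ones(2, 2)
        _p_assert(x.dim() == 2, "expected a 2D tensor")
    """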
    if not cond:
        print(s)
        traceback.print_stack()
        if raise_assertion_error:
            raise AssertionError(s)


def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None:
    """
    Allocate storage for ``tensor`` with the given size.

    This is a no-op under torchdynamo compilation and when the storage is
    already allocated (i.e. its size already matches ``size.numel()``).
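
    Usage (illustrative sketch, using the same private storage API as the body)::

        t = torch.empty(0)
        _alloc_storage(t, torch.Size([4]))
        assert t._typed_storage()._size() == 4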
    """
    with torch.no_grad():
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            already_allocated = tensor._typed_storage()._size() == size.numel()
            if not already_allocated:
                tensor_storage_size = tensor._typed_storage()._size()
                _p_assert(
                    tensor_storage_size == 0,
                    f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
                )
                tensor._typed_storage()._resize_(size.numel())


def _free_storage(tensor: torch.Tensor):
    """
    Frees the underlying storage of ``tensor``.

    This is a no-op under torchdynamo compilation and when the storage is
    already freed (size 0); it asserts that ``tensor`` is the sole occupant of
    its storage (storage offset 0) before freeing.
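
    Usage (illustrative sketch)::

        t = torch.randn(4)
        _free_storage(t)
        assert t._typed_storage()._size() == 0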
    """
    with torch.no_grad():
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            already_freed = tensor._typed_storage()._size() == 0
            if not already_freed:
                _p_assert(
                    tensor.storage_offset() == 0,
                    "Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
                    f"storage offset: {tensor.storage_offset()}\n"
                    f"storage size: {tensor._typed_storage()._size()}\n"
                    f"tensor shape: {tensor.shape}",
                )
                tensor._typed_storage()._resize_(0)


Q = TypeVar("Q")
R = TypeVar("R", dict, list, tuple, set, OrderedDict, PackedSequence, Any)


@overload
def _apply_to_tensors(fn: Callable[[torch.Tensor], Q], container: torch.Tensor) -> Q:
    ...


@overload
def _apply_to_tensors(fn: Callable[[torch.Tensor], Any], container: R) -> R:
    ...


def _apply_to_tensors(fn, container):
    """
    Recursively apply ``fn`` to all tensors found in supported container types.
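
    Usage (illustrative sketch)::

        nested = {"w": [torch.ones(2), (torch.zeros(1), "tag")]}
        doubled = _apply_to_tensors(lambda t: t * 2, nested)
        # tensors are doubled; non-tensor leaves such as "tag" pass through
    """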

    def apply(x):
        if isinstance(x, torch.Tensor):
            return fn(x)
        elif hasattr(x, "__dataclass_fields__"):
            dc = dataclasses.replace(x)
            changes = {
                f.name: apply(getattr(dc, f.name)) for f in dataclasses.fields(dc)
            }
            return dataclasses.replace(dc, **changes)
        elif isinstance(x, OrderedDict):
            od = x.__class__()
            for key, value in x.items():
                od[key] = apply(value)
            return od
        elif isinstance(x, PackedSequence):
            # NOTE: the result of ``apply`` is discarded here; the original
            # sequence is returned as-is.
            apply(x.data)
            return x
        elif isinstance(x, dict):
            return {key: apply(value) for key, value in x.items()}
        elif _is_namedtuple(x):
            res = (apply(el) for el in x)
            return type(x)(*res)
        elif isinstance(x, (list, tuple, set)):
            return type(x)(apply(el) for el in x)
        else:
            return x

    return apply(container)


def _to_kwargs(
    inputs: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    target_device: torch.device,
    use_side_stream_for_tensor_copies: bool,
) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]:
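    """Move ``inputs`` and ``kwargs`` to ``target_device``, returning equal-length tuples (the shorter result is padded with empty containers)."""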
    moved_inputs = (
        _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies)
        if inputs
        else []
    )
    moved_kwargs = (
        _recursive_to(kwargs, target_device, use_side_stream_for_tensor_copies)
        if kwargs
        else []
    )
    if len(moved_inputs) < len(moved_kwargs):
        moved_inputs.extend([() for _ in range(len(moved_kwargs) - len(moved_inputs))])
    elif len(moved_kwargs) < len(moved_inputs):
        moved_kwargs.extend([{} for _ in range(len(moved_inputs) - len(moved_kwargs))])
    return tuple(moved_inputs), tuple(moved_kwargs)


def _verify_param_shape_across_processes(
    process_group: dist.ProcessGroup,
    tensors: List[torch.Tensor],
    logger: Optional["dist.Logger"] = None,
):
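    """Verify that each tensor in ``tensors`` has a consistent shape across all ranks in ``process_group`` (thin wrapper over ``dist._verify_params_across_processes``)."""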
    return dist._verify_params_across_processes(process_group, tensors, logger)


def _sync_module_states(
    module: nn.Module,
    process_group: dist.ProcessGroup,
    broadcast_bucket_size: int,
    src: int,
    params_and_buffers_to_ignore: Container[str],
    broadcast_buffers: bool = True,
) -> None:
    """
    Sync ``module``'s parameters and buffers state.

    Syncs ``module``'s parameters and buffers so that all ranks hold the same
    module state. Note that this API assumes that all parameter shapes are
    consistent across ranks before running the synchronization; this can be
    checked with ``_verify_param_shape_across_processes``.
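
    Usage (illustrative sketch; assumes an initialized default process group,
    with ``model`` and the bucket size chosen only for the example)::

        _sync_module_states(
            module=model,
            process_group=dist.group.WORLD,
            broadcast_bucket_size=250 * 1024 * 1024,
            src=0,
            params_and_buffers_to_ignore=set(),
        )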
    """
    module_states: List[torch.Tensor] = []
    for name, param in module.named_parameters():
        if name not in params_and_buffers_to_ignore:
            module_states.append(param.detach())

    if broadcast_buffers:
        for name, buffer in module.named_buffers():
            if name not in params_and_buffers_to_ignore:
                module_states.append(buffer.detach())

    _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src)


def _sync_params_and_buffers(
    process_group: dist.ProcessGroup,
    module_states: List[torch.Tensor],
    broadcast_bucket_size: int,
    src: int,
) -> None:
    """Synchronize ``module_states`` (list of tensors) across all processes by broadcasting them from rank ``src``."""
    if len(module_states) > 0:
        dist._broadcast_coalesced(
            process_group, module_states, broadcast_bucket_size, src
        )


def _replace_by_prefix(
    state_dict: Dict[str, Any],
    old_prefix: str,
    new_prefix: str,
) -> None:
    """
    Replace all keys that match a given old_prefix with a new_prefix (in-place).

    Usage::

        state_dict = {"layer.xyz": torch.tensor(1)}
        _replace_by_prefix(state_dict, "layer.", "module.layer.")
        assert state_dict == {"module.layer.xyz": torch.tensor(1)}
    """
    if old_prefix == new_prefix:
        raise ValueError("old_prefix and new_prefix must be distinct")
    for key in list(state_dict.keys()):
        if not key.startswith(old_prefix):
            continue
        new_key = new_prefix + key[len(old_prefix) :]
        state_dict[new_key] = state_dict[key]
        del state_dict[key]


def _data_ptr_allocated(tensor: torch.Tensor) -> bool:
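    """Return ``True`` if ``tensor``'s underlying storage points to allocated memory."""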
    return tensor.untyped_storage().data_ptr() > 0


def _get_root_modules(modules: List[nn.Module]) -> List[nn.Module]:
    """
    Return the modules in ``modules`` that are root modules (i.e.
    parent-less) with respect to the set ``modules``. In other words, these
    are the modules in ``modules`` that are not a child of any other
    module in ``modules``.
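
    Usage (illustrative sketch)::

        inner = nn.Linear(2, 2)
        outer = nn.Sequential(inner)
        assert _get_root_modules([outer, inner]) == [outer]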
    """
    root_modules: List[nn.Module] = []
    module_to_modules: Dict[nn.Module, Set[nn.Module]] = {
        module: set(module.modules()) for module in modules
    }
    for candidate_module in modules:
        is_root_module = True
        for module, _modules in module_to_modules.items():
            is_child_module = (
                candidate_module is not module and candidate_module in _modules
            )
            if is_child_module:
                is_root_module = False
                break
        if is_root_module:
            root_modules.append(candidate_module)
    return root_modules