# xref: /aosp_15_r20/external/pytorch/torch/nn/parallel/scatter_gather.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional, overload, Sequence, Tuple, TypeVar, Union
from typing_extensions import deprecated

import torch
from torch.nn.parallel._functions import Gather, Scatter


__all__ = ["scatter", "scatter_kwargs", "gather"]


@deprecated(
    "`is_namedtuple` is deprecated, please use the python checks instead",
    category=FutureWarning,
)
def is_namedtuple(obj: Any) -> bool:
    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
    return _is_namedtuple(obj)


def _is_namedtuple(obj: Any) -> bool:
    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
    return (
        isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
    )

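# A minimal doctest-style sketch of the check above; `Point` is a hypothetical
# namedtuple used only for illustration:
#
#     >>> from collections import namedtuple
#     >>> Point = namedtuple("Point", ["x", "y"])
#     >>> _is_namedtuple(Point(1, 2)), _is_namedtuple((1, 2))
#     (True, False)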

T = TypeVar("T", dict, list, tuple)


# For some reason, 'scatter' returns a tuple when given a single Tensor input but a list otherwise.
@overload
def scatter(
    inputs: torch.Tensor,
    target_gpus: Sequence[Union[int, torch.device]],
    dim: int = ...,
) -> Tuple[torch.Tensor, ...]:
    ...


@overload
def scatter(
    inputs: T,
    target_gpus: Sequence[Union[int, torch.device]],
    dim: int = ...,
) -> List[T]:
    ...

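# A minimal doctest-style sketch of the two overloads above, assuming two
# visible CUDA devices; the values shown are illustrative:
#
#     >>> t = torch.arange(4)
#     >>> scatter(t, target_gpus=[0, 1])    # Tensor input -> tuple of chunks
#     (tensor([0, 1], device='cuda:0'), tensor([2, 3], device='cuda:1'))
#     >>> scatter([t], target_gpus=[0, 1])  # list input -> list of per-device lists
#     [[tensor([0, 1], device='cuda:0')], [tensor([2, 3], device='cuda:1')]]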

def scatter(inputs, target_gpus, dim=0):
    r"""Slice tensors into approximately equal chunks and distribute them across the given GPUs.

    Duplicates references to objects that are not tensors.
    """

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
        return [obj for _ in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell.
    try:
        res = scatter_map(inputs)
    finally:
        scatter_map = None  # type: ignore[assignment]
    return res

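# A minimal sketch of scattering a nested structure, assuming two visible CUDA
# devices; the dict below is hypothetical and only for illustration:
#
#     >>> batch = {"x": torch.randn(8, 4), "step": 7}
#     >>> shards = scatter(batch, target_gpus=[0, 1], dim=0)
#     >>> [s["x"].shape for s in shards]
#     [torch.Size([4, 4]), torch.Size([4, 4])]
#     >>> [s["step"] for s in shards]  # non-tensor values are repeated, not sliced
#     [7, 7]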

def scatter_kwargs(
    inputs: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    target_gpus: Sequence[Union[int, torch.device]],
    dim: int = 0,
) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]:
    r"""Scatter with support for kwargs dictionary."""
    scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else []
    scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    if len(scattered_inputs) < len(scattered_kwargs):
        scattered_inputs.extend(
            () for _ in range(len(scattered_kwargs) - len(scattered_inputs))
        )
    elif len(scattered_kwargs) < len(scattered_inputs):
        scattered_kwargs.extend(
            {} for _ in range(len(scattered_inputs) - len(scattered_kwargs))
        )
    return tuple(scattered_inputs), tuple(scattered_kwargs)

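# A minimal sketch of how scatter_kwargs pairs positional and keyword arguments
# per device, assuming two visible CUDA devices; the names are illustrative:
#
#     >>> x = torch.randn(6, 3)
#     >>> ins, kws = scatter_kwargs((x,), {"scale": 2.0}, target_gpus=[0, 1], dim=0)
#     >>> [i[0].shape for i in ins]
#     [torch.Size([3, 3]), torch.Size([3, 3])]
#     >>> kws
#     ({'scale': 2.0}, {'scale': 2.0})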

def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0) -> Any:
    r"""Gather tensors from different GPUs on a specified device.

    This function is useful for gathering the results of a distributed computation.
    It takes a sequence of objects, one for each GPU, and returns a single object
    on the specified device.

    Args:
        outputs (Any): A sequence of objects (potentially tensors) to gather.
        target_device (Union[int, torch.device]): The device to gather the tensors to.
            Use 'cpu' for CPU to avoid a deprecation warning.
        dim (int, optional): The dimension along which to gather. Default: 0.

    Returns:
        Any: A gathered object (potentially tensor) on the specified device.
    """

    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, dict):
            if not all(len(out) == len(d) for d in outputs):
                raise ValueError("All dicts must have the same number of keys")
            return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
        if _is_namedtuple(out):
            return type(out)._make(map(gather_map, zip(*outputs)))
        return type(out)(map(gather_map, zip(*outputs)))

    # Recursive function calls like this create reference cycles.
    # Setting the function to None clears the refcycle.
    try:
        res = gather_map(outputs)
    finally:
        gather_map = None  # type: ignore[assignment]
    return res
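
# A minimal sketch of gathering per-device outputs back onto one device,
# assuming the two tensors below live on 'cuda:0' and 'cuda:1':
#
#     >>> y0 = torch.ones(3, 2, device="cuda:0")
#     >>> y1 = torch.zeros(3, 2, device="cuda:1")
#     >>> gathered = gather([{"out": y0}, {"out": y1}], target_device=0, dim=0)
#     >>> gathered["out"].shape
#     torch.Size([6, 2])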