# mypy: allow-untyped-defs
import operator
import warnings
from itertools import chain
from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar, Union

import torch
from torch._utils import (
    _get_all_device_indices,
    _get_available_device_type,
    _get_device_index,
    _get_devices_properties,
)
from torch.nn.modules import Module
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import gather, scatter_kwargs


__all__ = ["DataParallel", "data_parallel"]


def _check_balance(device_ids: Sequence[Union[int, torch.device]]) -> None:
    imbalance_warn = """
    There is an imbalance between your GPUs. You may want to exclude GPU {} which
    has less than 75% of the memory or cores of GPU {}. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable."""
    device_ids = [_get_device_index(x, True) for x in device_ids]
    dev_props = _get_devices_properties(device_ids)

    def warn_imbalance(get_prop):
        values = [get_prop(props) for props in dev_props]
        min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
        max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
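        # Warn when the weakest device has less than 75% of the capacity of the
        # strongest one, e.g. a 12 GB card next to a 24 GB card: 12 / 24 = 0.5 < 0.75.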
        if min_val / max_val < 0.75:
            warnings.warn(
                imbalance_warn.format(device_ids[min_pos], device_ids[max_pos])
            )
            return True
        return False

    if warn_imbalance(lambda props: props.total_memory):
        return
    if warn_imbalance(lambda props: props.multi_processor_count):
        return


T = TypeVar("T", bound=Module)


class DataParallel(Module, Generic[T]):
    r"""Implements data parallelism at the module level.

    This container parallelizes the application of the given :attr:`module` by
    splitting the input across the specified devices by chunking in the batch
    dimension (other objects will be copied once per device). In the forward
    pass, the module is replicated on each device, and each replica handles a
    portion of the input. During the backwards pass, gradients from each replica
    are summed into the original module.

    The batch size should be larger than the number of GPUs used.

    .. warning::
        It is recommended to use :class:`~torch.nn.parallel.DistributedDataParallel`,
        instead of this class, to do multi-GPU training, even if there is only a single
        node. See: :ref:`cuda-nn-ddp-instead` and :ref:`ddp`.

    Arbitrary positional and keyword inputs may be passed into DataParallel,
    but some types are specially handled. Tensors will be **scattered** along
    the specified dim (default 0). Tuple, list, and dict types will be shallow
    copied. All other types will be shared among the different threads and can
    be corrupted if written to in the model's forward pass.

    The parallelized :attr:`module` must have its parameters and buffers on
    ``device_ids[0]`` before running this :class:`~torch.nn.DataParallel`
    module.

    .. warning::
        In each forward, :attr:`module` is **replicated** on each device, so any
        updates to the running module in ``forward`` will be lost. For example,
        if :attr:`module` has a counter attribute that is incremented in each
        ``forward``, it will always stay at the initial value because the update
        is done on the replicas which are destroyed after ``forward``. However,
        :class:`~torch.nn.DataParallel` guarantees that the replica on
        ``device[0]`` will have its parameters and buffers sharing storage with
        the base parallelized :attr:`module`. So **in-place** updates to the
        parameters or buffers on ``device[0]`` will be recorded. E.g.,
        :class:`~torch.nn.BatchNorm2d` and :func:`~torch.nn.utils.spectral_norm`
        rely on this behavior to update the buffers.

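    For example, a minimal illustrative sketch of the behavior described above
    (the ``Counter`` module below is hypothetical and assumes at least two
    visible CUDA devices)::

        >>> # xdoctest: +SKIP
        >>> class Counter(torch.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.calls = 0  # plain attribute: updated only on replicas
        ...         self.register_buffer("seen", torch.zeros(1))
        ...     def forward(self, x):
        ...         self.calls += 1    # lost: the replica is discarded after forward
        ...         self.seen.add_(1)  # in-place: the update from the replica on device[0] persists
        ...         return x
        >>> net = torch.nn.DataParallel(Counter().to("cuda:0"), device_ids=[0, 1])
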
    .. warning::
        Forward and backward hooks defined on :attr:`module` and its submodules
        will be invoked ``len(device_ids)`` times, each with inputs located on
        a particular device. In particular, the hooks are only guaranteed to be
        executed in the correct order with respect to operations on their
        corresponding device. For example, it is not guaranteed that hooks set via
        :meth:`~torch.nn.Module.register_forward_pre_hook` are executed before
        `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but
        only that each such hook is executed before the corresponding
        :meth:`~torch.nn.Module.forward` call of that device.

    .. warning::
        When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
        :func:`forward`, this wrapper will return a vector of length equal to
        the number of devices used in data parallelism, containing the result
        from each device.

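    For example (illustrative), if ``forward`` returns a scalar loss and two
    devices are used, the wrapped call returns a tensor of shape ``(2,)``; a
    common pattern is to reduce it, e.g. ``loss = net(x).mean()``, before
    calling ``backward()``.
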
    .. note::
        There is a subtlety in using the
        ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
        :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
        See the :ref:`pack-rnn-unpack-with-data-parallelism` section in the FAQ
        for details.


    Args:
        module (Module): module to be parallelized
        device_ids (list of int or torch.device): CUDA devices (default: all devices)
        output_device (int or torch.device): device location of output (default: device_ids[0])

    Attributes:
        module (Module): the module to be parallelized

    Example::

        >>> # xdoctest: +SKIP
        >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
        >>> output = net(input_var)  # input_var can be on any device, including CPU
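        >>> # (illustrative) the gathered output lands on ``output_device``, which defaults to device_ids[0]
        >>> output.device
        device(type='cuda', index=0)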
130    """
131
132    # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
133
134    def __init__(
135        self,
136        module: T,
137        device_ids: Optional[Sequence[Union[int, torch.device]]] = None,
138        output_device: Optional[Union[int, torch.device]] = None,
139        dim: int = 0,
140    ) -> None:
141        super().__init__()
142        torch._C._log_api_usage_once("torch.nn.parallel.DataParallel")
143        device_type = _get_available_device_type()
144        if device_type is None:
145            self.module = module
146            self.device_ids = []
147            return
148
149        if device_ids is None:
150            device_ids = _get_all_device_indices()
151
152        if device_ids is None:
153            raise RuntimeError("no available devices were found")
154
155        if output_device is None:
156            output_device = device_ids[0]
157
158        self.dim = dim
159        self.module = module
160        self.device_ids = [_get_device_index(x, True) for x in device_ids]
161        self.output_device = _get_device_index(output_device, True)
162        self.src_device_obj = torch.device(device_type, self.device_ids[0])
163
164        if device_type == "cuda":
165            _check_balance(self.device_ids)
166
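        # With a single device there is no replication to do; just make sure
        # the module already lives on that device.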
        if len(self.device_ids) == 1:
            self.module.to(self.src_device_obj)

    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
        with torch.autograd.profiler.record_function("DataParallel.forward"):
            if not self.device_ids:
                return self.module(*inputs, **kwargs)

            for t in chain(self.module.parameters(), self.module.buffers()):
                if t.device != self.src_device_obj:
                    raise RuntimeError(
                        "module must have its parameters and buffers "
                        f"on device {self.src_device_obj} (device_ids[0]) but found one of "
                        f"them on device: {t.device}"
                    )

            inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids)
            # if forward was called with no inputs, create a single empty args tuple
            # and kwargs dict so the module can still run on the first device in device_ids
            if not inputs and not module_kwargs:
                inputs = ((),)
                module_kwargs = ({},)

            if len(self.device_ids) == 1:
                return self.module(*inputs[0], **module_kwargs[0])
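            # Replicate onto at most as many devices as there are scattered
            # input chunks (the scatter may produce fewer chunks than devices).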
            replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
            outputs = self.parallel_apply(replicas, inputs, module_kwargs)
            return self.gather(outputs, self.output_device)

    def replicate(
        self, module: T, device_ids: Sequence[Union[int, torch.device]]
    ) -> List[T]:
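        # Detach the replicas when autograd is disabled (e.g. inside
        # ``torch.no_grad()``), since no gradients will flow back to the
        # original parameters in that case.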
        return replicate(module, device_ids, not torch.is_grad_enabled())

    def scatter(
        self,
        inputs: Tuple[Any, ...],
        kwargs: Optional[Dict[str, Any]],
        device_ids: Sequence[Union[int, torch.device]],
    ) -> Any:
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(
        self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any
    ) -> List[Any]:
        return parallel_apply(
            replicas, inputs, kwargs, self.device_ids[: len(replicas)]
        )

    def gather(self, outputs: Any, output_device: Union[int, torch.device]) -> Any:
        return gather(outputs, output_device, dim=self.dim)


def data_parallel(
    module: Module,
    inputs: Any,
    device_ids: Optional[Sequence[Union[int, torch.device]]] = None,
    output_device: Optional[Union[int, torch.device]] = None,
    dim: int = 0,
    module_kwargs: Optional[Any] = None,
) -> torch.Tensor:
    r"""Evaluate module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module (Module): the module to evaluate in parallel
        inputs (Tensor): inputs to the module
        device_ids (list of int or torch.device): GPU ids on which to replicate module
        output_device (int or torch.device): device location of the output. Use -1 to indicate the CPU.
            (default: device_ids[0])
    Returns:
        a Tensor containing the result of module(input) located on
        output_device
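
    Example::

        >>> # xdoctest: +SKIP
        >>> # A minimal usage sketch; ``model`` and ``input_var`` are placeholder names,
        >>> # and ``model`` is assumed to already be on ``device_ids[0]``.
        >>> output = torch.nn.parallel.data_parallel(model, input_var, device_ids=[0, 1])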
241    """
242    if not isinstance(inputs, tuple):
243        inputs = (inputs,) if inputs is not None else ()
244
245    device_type = _get_available_device_type()
246
247    if device_type is None:
248        raise RuntimeError("device type could not be determined")
249
250    if device_ids is None:
251        device_ids = _get_all_device_indices()
252
253    if device_ids is None:
254        raise RuntimeError("no available devices were found")
255
256    if output_device is None:
257        output_device = device_ids[0]
258
259    device_ids = [_get_device_index(x, True) for x in device_ids]
260    output_device = _get_device_index(output_device, True)
261    src_device_obj = torch.device(device_type, device_ids[0])
262
263    for t in chain(module.parameters(), module.buffers()):
264        if t.device != src_device_obj:
265            raise RuntimeError(
266                "module must have its parameters and buffers "
267                f"on device {src_device_obj} (device_ids[0]) but found one of "
268                f"them on device: {t.device}"
269            )
270
271    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    # if the module was called with no inputs, create a single empty args tuple
    # and kwargs dict so it can still be executed on the first device in device_ids
    if not inputs and not module_kwargs:
        inputs = ((),)
        module_kwargs = ({},)

    assert module_kwargs is not None

    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[: len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
286