# mypy: allow-untyped-defs
import operator
import warnings
from itertools import chain
from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar, Union

import torch
from torch._utils import (
    _get_all_device_indices,
    _get_available_device_type,
    _get_device_index,
    _get_devices_properties,
)
from torch.nn.modules import Module
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import gather, scatter_kwargs


__all__ = ["DataParallel", "data_parallel"]


def _check_balance(device_ids: Sequence[Union[int, torch.device]]) -> None:
    imbalance_warn = """
    There is an imbalance between your GPUs. You may want to exclude GPU {} which
    has less than 75% of the memory or cores of GPU {}. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable."""
    device_ids = [_get_device_index(x, True) for x in device_ids]
    dev_props = _get_devices_properties(device_ids)

    def warn_imbalance(get_prop):
        values = [get_prop(props) for props in dev_props]
        min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
        max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
        if min_val / max_val < 0.75:
            warnings.warn(
                imbalance_warn.format(device_ids[min_pos], device_ids[max_pos])
            )
            return True
        return False

    if warn_imbalance(lambda props: props.total_memory):
        return
    if warn_imbalance(lambda props: props.multi_processor_count):
        return


T = TypeVar("T", bound=Module)


class DataParallel(Module, Generic[T]):
    r"""Implements data parallelism at the module level.

    This container parallelizes the application of the given :attr:`module` by
    splitting the input across the specified devices by chunking in the batch
    dimension (other objects will be copied once per device). In the forward
    pass, the module is replicated on each device, and each replica handles a
    portion of the input. During the backwards pass, gradients from each replica
    are summed into the original module.

    The batch size should be larger than the number of GPUs used.

    .. warning::
        It is recommended to use :class:`~torch.nn.parallel.DistributedDataParallel`,
        instead of this class, to do multi-GPU training, even if there is only a single
        node. See: :ref:`cuda-nn-ddp-instead` and :ref:`ddp`.

    Arbitrary positional and keyword inputs are allowed to be passed into
    DataParallel, but some types are specially handled. Tensors will be
    **scattered** along the dim specified (default 0). Tuple, list and dict types
    will be shallow copied. Other types will be shared among different threads
    and can be corrupted if written to in the model's forward pass.
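
    For example (a sketch; ``module``, ``x`` and ``mask`` below are illustrative
    names, with ``mask`` assumed to be a tensor keyword argument)::

        >>> # xdoctest: +SKIP
        >>> net = torch.nn.DataParallel(module, device_ids=[0, 1])
        >>> out = net(x, mask=mask)  # x and mask are each chunked along dim 0,
        >>> # one chunk per replica; the outputs are gathered onto device_ids[0]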

    The parallelized :attr:`module` must have its parameters and buffers on
    ``device_ids[0]`` before running this :class:`~torch.nn.DataParallel`
    module.

    .. warning::
        In each forward, :attr:`module` is **replicated** on each device, so any
        updates to the running module in ``forward`` will be lost. For example,
        if :attr:`module` has a counter attribute that is incremented in each
        ``forward``, it will always stay at the initial value because the update
        is done on the replicas which are destroyed after ``forward``. However,
        :class:`~torch.nn.DataParallel` guarantees that the replica on
        ``device_ids[0]`` will have its parameters and buffers sharing storage with
        the base parallelized :attr:`module`. So **in-place** updates to the
        parameters or buffers on ``device_ids[0]`` will be recorded. E.g.,
        :class:`~torch.nn.BatchNorm2d` and :func:`~torch.nn.utils.spectral_norm`
        rely on this behavior to update the buffers.

    .. warning::
        Forward and backward hooks defined on :attr:`module` and its submodules
        will be invoked ``len(device_ids)`` times, each with inputs located on
        a particular device. In particular, the hooks are only guaranteed to be
        executed in the correct order with respect to operations on their
        corresponding devices. For example, it is not guaranteed that hooks set
        via :meth:`~torch.nn.Module.register_forward_pre_hook` are executed before
        `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, only
        that each such hook is executed before the corresponding
        :meth:`~torch.nn.Module.forward` call of that device.

    .. warning::
        When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
        :func:`forward`, this wrapper will return a vector of length equal to
        the number of devices used in data parallelism, containing the result
        from each device.

    .. note::
        There is a subtlety in using the
        ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
        :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
        See the :ref:`pack-rnn-unpack-with-data-parallelism` section in the FAQ
        for details.

    Args:
        module (Module): module to be parallelized
        device_ids (list of int or torch.device): CUDA devices (default: all devices)
        output_device (int or torch.device): device location of output (default: device_ids[0])

    Attributes:
        module (Module): the module to be parallelized

    Example::

        >>> # xdoctest: +SKIP
        >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
        >>> output = net(input_var)  # input_var can be on any device, including CPU
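        >>> # A sketch of the 0-dim output case from the warning above: if the
        >>> # wrapped module returns a scalar (e.g. an already-reduced loss), the
        >>> # gathered result is a 1-dim tensor with one entry per device used,
        >>> # so it is typically reduced again:
        >>> loss = net(input_var).mean()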
    """

    # TODO: update notes/cuda.rst when this class handles 8+ GPUs well

    def __init__(
        self,
        module: T,
        device_ids: Optional[Sequence[Union[int, torch.device]]] = None,
        output_device: Optional[Union[int, torch.device]] = None,
        dim: int = 0,
    ) -> None:
        super().__init__()
        torch._C._log_api_usage_once("torch.nn.parallel.DataParallel")
        device_type = _get_available_device_type()
        if device_type is None:
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = _get_all_device_indices()

        if device_ids is None:
            raise RuntimeError("no available devices were found")

        if output_device is None:
            output_device = device_ids[0]

        self.dim = dim
        self.module = module
        self.device_ids = [_get_device_index(x, True) for x in device_ids]
        self.output_device = _get_device_index(output_device, True)
        self.src_device_obj = torch.device(device_type, self.device_ids[0])

        if device_type == "cuda":
            _check_balance(self.device_ids)

        if len(self.device_ids) == 1:
            self.module.to(self.src_device_obj)

    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
        with torch.autograd.profiler.record_function("DataParallel.forward"):
            if not self.device_ids:
                return self.module(*inputs, **kwargs)

            for t in chain(self.module.parameters(), self.module.buffers()):
                if t.device != self.src_device_obj:
                    raise RuntimeError(
                        "module must have its parameters and buffers "
                        f"on device {self.src_device_obj} (device_ids[0]) but found one of "
                        f"them on device: {t.device}"
                    )

            inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids)
            # If forward was called without any inputs, create an empty argument
            # tuple and kwargs dict so the module can still be executed on one
            # device (the first one in device_ids).
            if not inputs and not module_kwargs:
                inputs = ((),)
                module_kwargs = ({},)

            if len(self.device_ids) == 1:
                return self.module(*inputs[0], **module_kwargs[0])
            replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
            outputs = self.parallel_apply(replicas, inputs, module_kwargs)
            return self.gather(outputs, self.output_device)

    def replicate(
        self, module: T, device_ids: Sequence[Union[int, torch.device]]
    ) -> List[T]:
        return replicate(module, device_ids, not torch.is_grad_enabled())

    def scatter(
        self,
        inputs: Tuple[Any, ...],
        kwargs: Optional[Dict[str, Any]],
        device_ids: Sequence[Union[int, torch.device]],
    ) -> Any:
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(
        self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any
    ) -> List[Any]:
        return parallel_apply(
            replicas, inputs, kwargs, self.device_ids[: len(replicas)]
        )

    def gather(self, outputs: Any, output_device: Union[int, torch.device]) -> Any:
        return gather(outputs, output_device, dim=self.dim)


def data_parallel(
    module: Module,
    inputs: Any,
    device_ids: Optional[Sequence[Union[int, torch.device]]] = None,
    output_device: Optional[Union[int, torch.device]] = None,
    dim: int = 0,
    module_kwargs: Optional[Any] = None,
) -> torch.Tensor:
    r"""Evaluate module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module (Module): the module to evaluate in parallel
        inputs (Tensor): inputs to the module
        device_ids (list of int or torch.device): GPU ids on which to replicate module
        output_device (int or torch.device): device location of the output; use -1
            to indicate the CPU (default: device_ids[0])

    Returns:
        a Tensor containing the result of module(input) located on
        output_device
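
    Example::

        >>> # xdoctest: +SKIP
        >>> # A minimal sketch; ``model`` and ``input_var`` are illustrative names.
        >>> from torch.nn.parallel import data_parallel
        >>> output = data_parallel(model, input_var, device_ids=[0, 1])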
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs,) if inputs is not None else ()

    device_type = _get_available_device_type()

    if device_type is None:
        raise RuntimeError("device type could not be determined")

    if device_ids is None:
        device_ids = _get_all_device_indices()

    if device_ids is None:
        raise RuntimeError("no available devices were found")

    if output_device is None:
        output_device = device_ids[0]

    device_ids = [_get_device_index(x, True) for x in device_ids]
    output_device = _get_device_index(output_device, True)
    src_device_obj = torch.device(device_type, device_ids[0])

    for t in chain(module.parameters(), module.buffers()):
        if t.device != src_device_obj:
            raise RuntimeError(
                "module must have its parameters and buffers "
                f"on device {src_device_obj} (device_ids[0]) but found one of "
                f"them on device: {t.device}"
            )

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    # If the module takes no inputs, create an empty argument tuple and kwargs
    # dict so it can still be executed on one device (the first one in device_ids).
    if not inputs and not module_kwargs:
        inputs = ((),)
        module_kwargs = ({},)

    assert module_kwargs is not None

    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[: len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)