# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional, overload, Sequence, Tuple, TypeVar, Union
from typing_extensions import deprecated

import torch
from torch.nn.parallel._functions import Gather, Scatter


__all__ = ["scatter", "scatter_kwargs", "gather"]


@deprecated(
    "`is_namedtuple` is deprecated, please use the python checks instead",
    category=FutureWarning,
)
def is_namedtuple(obj: Any) -> bool:
    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
    return _is_namedtuple(obj)


def _is_namedtuple(obj: Any) -> bool:
    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
    return (
        isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
    )


T = TypeVar("T", dict, list, tuple)


# For some reason, 'scatter' returns a tuple when given a single Tensor input but a list otherwise.
@overload
def scatter(
    inputs: torch.Tensor,
    target_gpus: Sequence[Union[int, torch.device]],
    dim: int = ...,
) -> Tuple[torch.Tensor, ...]:
    ...


@overload
def scatter(
    inputs: T,
    target_gpus: Sequence[Union[int, torch.device]],
    dim: int = ...,
) -> List[T]:
    ...


def scatter(inputs, target_gpus, dim=0):
    r"""Slice tensors into approximately equal chunks and distribute them across given GPUs.

    Duplicates references to objects that are not tensors.
    """

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
        return [obj for _ in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        res = scatter_map(inputs)
    finally:
        scatter_map = None  # type: ignore[assignment]
    return res


def scatter_kwargs(
    inputs: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    target_gpus: Sequence[Union[int, torch.device]],
    dim: int = 0,
) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]:
    r"""Scatter with support for kwargs dictionary."""
    scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else []
    scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    if len(scattered_inputs) < len(scattered_kwargs):
        scattered_inputs.extend(
            () for _ in range(len(scattered_kwargs) - len(scattered_inputs))
        )
    elif len(scattered_kwargs) < len(inputs):
        scattered_kwargs.extend(
            {} for _ in range(len(scattered_inputs) - len(scattered_kwargs))
        )
    return tuple(scattered_inputs), tuple(scattered_kwargs)


def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0) -> Any:
    r"""Gather tensors from different GPUs on a specified device.

    This function is useful for gathering the results of a distributed computation.
    It takes a sequence of objects, one for each GPU, and returns a single object
    on the specified device.

    Args:
        outputs (Any): A sequence of objects (potentially tensors) to gather.
        target_device (Union[int, torch.device]): The device to gather the tensors to.
            Use 'cpu' for CPU to avoid a deprecation warning.
        dim (int, optional): The dimension along which to gather. Default: 0.

    Returns:
        Any: A gathered object (potentially tensor) on the specified device.
    """

    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, dict):
            if not all(len(out) == len(d) for d in outputs):
                raise ValueError("All dicts must have the same number of keys")
            return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
        if _is_namedtuple(out):
            return type(out)._make(map(gather_map, zip(*outputs)))
        return type(out)(map(gather_map, zip(*outputs)))

    # Recursive function calls like this create reference cycles.
    # Setting the function to None clears the refcycle.
    try:
        res = gather_map(outputs)
    finally:
        gather_map = None  # type: ignore[assignment]
    return res
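

# Minimal usage sketch, not part of the module's public surface: it assumes a
# machine with at least two CUDA devices. ``scatter_kwargs`` splits positional
# and keyword arguments into one chunk per device, and ``gather`` collects the
# per-device results back onto a single device.
if __name__ == "__main__":
    if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
        devices = [0, 1]
        batch = torch.randn(8, 4)
        # The tensor is split along dim 0 into one chunk per device; the float
        # kwarg is not a tensor, so every device receives the same reference.
        per_device_args, per_device_kwargs = scatter_kwargs(
            (batch,), {"scale": 2.0}, devices, dim=0
        )
        # Simulate per-device work, then merge the results onto device 0.
        per_device_out = [
            args[0] * kwargs["scale"]
            for args, kwargs in zip(per_device_args, per_device_kwargs)
        ]
        merged = gather(per_device_out, target_device=0, dim=0)
        print(merged.shape)  # torch.Size([8, 4])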