1# mypy: allow-untyped-defs 2import collections 3import warnings 4from typing import Optional, Sequence, Union 5 6import torch.cuda 7 8 9__all__ = ["all_reduce", "reduce", "broadcast", "all_gather", "reduce_scatter"] 10 11SUM = 0 # ncclRedOp_t 12 13 14def is_available(tensors): 15 if not hasattr(torch._C, "_nccl_all_reduce"): 16 warnings.warn("PyTorch is not compiled with NCCL support") 17 return False 18 19 devices = set() 20 for tensor in tensors: 21 if tensor.is_sparse: 22 return False 23 if not tensor.is_contiguous(): 24 return False 25 if not tensor.is_cuda: 26 return False 27 device = tensor.get_device() 28 if device in devices: 29 return False 30 devices.add(device) 31 32 return True 33 34 35def version(): 36 """ 37 Returns the version of the NCCL. 38 39 40 This function returns a tuple containing the major, minor, and patch version numbers of the NCCL. 41 The suffix is also included in the tuple if a version suffix exists. 42 Returns: 43 tuple: The version information of the NCCL. 44 """ 45 ver = torch._C._nccl_version() 46 major = ver >> 32 47 minor = (ver >> 16) & 65535 48 patch = ver & 65535 49 suffix = torch._C._nccl_version_suffix().decode("utf-8") 50 if suffix == "": 51 return (major, minor, patch) 52 else: 53 return (major, minor, patch, suffix) 54 55 56def unique_id(): 57 return torch._C._nccl_unique_id() 58 59 60def init_rank(num_ranks, uid, rank): 61 return torch._C._nccl_init_rank(num_ranks, uid, rank) 62 63 64def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None: 65 if not isinstance(inputs, collections.abc.Container) or isinstance( 66 inputs, torch.Tensor 67 ): 68 raise TypeError("Inputs should be a collection of tensors") 69 70 71def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None): 72 _check_sequence_type(inputs) 73 if outputs is None: 74 outputs = inputs 75 _check_sequence_type(outputs) 76 torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms) 77 78 79# `output` used to be `outputs`, taking in a list of tensors. So we have two 80# arguments for BC reasons. 81def reduce( 82 inputs: Sequence[torch.Tensor], 83 output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None, 84 root: int = 0, 85 op: int = SUM, 86 streams: Optional[Sequence[torch.cuda.Stream]] = None, 87 comms=None, 88 *, 89 outputs: Optional[Sequence[torch.Tensor]] = None, 90) -> None: 91 _check_sequence_type(inputs) 92 _output: torch.Tensor 93 if outputs is not None: 94 if output is not None: 95 raise ValueError( 96 "'output' and 'outputs' can not be both specified. 'outputs' is deprecated in " 97 "favor of 'output', taking in a single output tensor. The signature of reduce is: " 98 "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)." 99 ) 100 else: 101 warnings.warn( 102 "`nccl.reduce` with an output tensor list is deprecated. " 103 "Please specify a single output tensor with argument 'output' instead instead.", 104 FutureWarning, 105 stacklevel=2, 106 ) 107 _output = outputs[root] 108 elif not isinstance(output, torch.Tensor) and isinstance( 109 output, collections.abc.Sequence 110 ): 111 # User called old API with positional arguments of list of output tensors. 112 warnings.warn( 113 "nccl.reduce with an output tensor list is deprecated. " 114 "Please specify a single output tensor.", 115 FutureWarning, 116 stacklevel=2, 117 ) 118 _output = output[root] 119 else: 120 _output = inputs[root] if output is None else output 121 torch._C._nccl_reduce(inputs, _output, root, op, streams, comms) 122 123 124def broadcast( 125 inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None 126) -> None: 127 _check_sequence_type(inputs) 128 torch._C._nccl_broadcast(inputs, root, streams, comms) 129 130 131def all_gather( 132 inputs: Sequence[torch.Tensor], 133 outputs: Sequence[torch.Tensor], 134 streams=None, 135 comms=None, 136) -> None: 137 _check_sequence_type(inputs) 138 _check_sequence_type(outputs) 139 torch._C._nccl_all_gather(inputs, outputs, streams, comms) 140 141 142def reduce_scatter( 143 inputs: Sequence[torch.Tensor], 144 outputs: Sequence[torch.Tensor], 145 op: int = SUM, 146 streams=None, 147 comms=None, 148) -> None: 149 _check_sequence_type(inputs) 150 _check_sequence_type(outputs) 151 torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms) 152