# mypy: allow-untyped-defs
import collections
import warnings
from typing import Optional, Sequence, Union

import torch.cuda


__all__ = ["all_reduce", "reduce", "broadcast", "all_gather", "reduce_scatter"]

SUM = 0  # ncclRedOp_t


def is_available(tensors):
    """Return ``True`` if these NCCL bindings can operate on ``tensors``.

    This requires PyTorch to be built with NCCL support and every tensor to be
    a dense, contiguous CUDA tensor, with each tensor on a distinct device.
    """
    if not hasattr(torch._C, "_nccl_all_reduce"):
        warnings.warn("PyTorch is not compiled with NCCL support")
        return False

    devices = set()
    for tensor in tensors:
        if tensor.is_sparse:
            return False
        if not tensor.is_contiguous():
            return False
        if not tensor.is_cuda:
            return False
        device = tensor.get_device()
        if device in devices:
            return False
        devices.add(device)

    return True
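
# Example (illustrative sketch): ``is_available`` can be used to check, before
# calling into the raw NCCL bindings, that a set of tensors is usable. Assumes
# at least two CUDA devices; not executed at import time.
#
#   import torch
#   from torch.cuda import nccl
#
#   xs = [torch.randn(4, device=f"cuda:{i}") for i in range(2)]
#   if nccl.is_available(xs):
#       nccl.all_reduce(xs)  # in-place elementwise sum across the two devices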


def version():
36    """
37    Returns the version of the NCCL.
38
39
40    This function returns a tuple containing the major, minor, and patch version numbers of the NCCL.
41    The suffix is also included in the tuple if a version suffix exists.
42    Returns:
43        tuple: The version information of the NCCL.
44    """
    ver = torch._C._nccl_version()
    major = ver >> 32
    minor = (ver >> 16) & 65535
    patch = ver & 65535
    suffix = torch._C._nccl_version_suffix().decode("utf-8")
    if suffix == "":
        return (major, minor, patch)
    else:
        return (major, minor, patch, suffix)
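
# Example (illustrative): ``version()`` returns a 3- or 4-element tuple; the
# numbers shown below are placeholders, not a real NCCL version.
#
#   from torch.cuda import nccl
#
#   major, minor, patch = nccl.version()[:3]  # e.g. (2, 18, 1)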


def unique_id():
    """Return a new NCCL unique ID, used to initialize a communicator."""
    return torch._C._nccl_unique_id()


def init_rank(num_ranks, uid, rank):
    """Initialize an NCCL communicator for ``rank`` of ``num_ranks`` using ``uid``."""
    return torch._C._nccl_init_rank(num_ranks, uid, rank)
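
# Example (illustrative sketch): following standard NCCL semantics, one process
# generates a unique ID, shares it with all participants out of band (e.g. over
# a socket; the transport is an assumption here), and every rank then calls
# ``init_rank`` with the same ID and its own rank index.
#
#   from torch.cuda import nccl
#
#   uid = nccl.unique_id()            # generated once, then shared with all ranks
#   comm = nccl.init_rank(2, uid, 0)  # this process acts as rank 0 of 2
#   # ... the other process calls nccl.init_rank(2, uid, 1) with the same uid ...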


def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
    if not isinstance(inputs, collections.abc.Container) or isinstance(
        inputs, torch.Tensor
    ):
        raise TypeError("Inputs should be a collection of tensors")


def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
    _check_sequence_type(inputs)
    if outputs is None:
        outputs = inputs
    _check_sequence_type(outputs)
    torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms)
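
# Example (illustrative sketch, assumes at least two CUDA devices): with the
# default ``outputs=None`` the reduction is done in place, so every tensor ends
# up holding the elementwise sum across devices (op=SUM by default).
#
#   import torch
#   from torch.cuda import nccl
#
#   xs = [torch.ones(4, device=f"cuda:{i}") for i in range(2)]
#   nccl.all_reduce(xs)  # each xs[i] now holds 2.0 in every element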


# `output` used to be `outputs`, taking in a list of tensors. So we have two
# arguments for BC reasons.
def reduce(
    inputs: Sequence[torch.Tensor],
    output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
    root: int = 0,
    op: int = SUM,
    streams: Optional[Sequence[torch.cuda.Stream]] = None,
    comms=None,
    *,
    outputs: Optional[Sequence[torch.Tensor]] = None,
) -> None:
    _check_sequence_type(inputs)
    _output: torch.Tensor
    if outputs is not None:
        if output is not None:
            raise ValueError(
                "'output' and 'outputs' cannot both be specified. 'outputs' is deprecated in "
                "favor of 'output', taking in a single output tensor. The signature of reduce is: "
                "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)."
            )
        else:
            warnings.warn(
                "`nccl.reduce` with an output tensor list is deprecated. "
                "Please specify a single output tensor with argument 'output' instead.",
                FutureWarning,
                stacklevel=2,
            )
            _output = outputs[root]
    elif not isinstance(output, torch.Tensor) and isinstance(
        output, collections.abc.Sequence
    ):
        # The user called the old API, passing a list of output tensors positionally.
        warnings.warn(
            "nccl.reduce with an output tensor list is deprecated. "
            "Please specify a single output tensor.",
            FutureWarning,
            stacklevel=2,
        )
        _output = output[root]
    else:
        _output = inputs[root] if output is None else output
    torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)
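
# Example (illustrative sketch, assumes at least two CUDA devices): reduce the
# per-device tensors into a single tensor on the root device. With the default
# ``output=None`` the result is written into ``inputs[root]``.
#
#   import torch
#   from torch.cuda import nccl
#
#   xs = [torch.ones(4, device=f"cuda:{i}") for i in range(2)]
#   nccl.reduce(xs, root=0)  # xs[0] now holds the sum; the other tensors are untouched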


def broadcast(
    inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None
) -> None:
    _check_sequence_type(inputs)
    torch._C._nccl_broadcast(inputs, root, streams, comms)
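
# Example (illustrative sketch, assumes at least two CUDA devices): copy the
# tensor on the root device into the corresponding tensors on all other devices.
#
#   import torch
#   from torch.cuda import nccl
#
#   xs = [torch.full((4,), float(i), device=f"cuda:{i}") for i in range(2)]
#   nccl.broadcast(xs)  # every xs[i] is now a copy of the original xs[0]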


def all_gather(
    inputs: Sequence[torch.Tensor],
    outputs: Sequence[torch.Tensor],
    streams=None,
    comms=None,
) -> None:
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_all_gather(inputs, outputs, streams, comms)
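
# Example (illustrative sketch, assumes at least two CUDA devices): each output
# tensor must be large enough to hold the concatenation of all inputs; after the
# call every device holds the gathered data.
#
#   import torch
#   from torch.cuda import nccl
#
#   n = 2
#   ins = [torch.randn(4, device=f"cuda:{i}") for i in range(n)]
#   outs = [torch.empty(4 * n, device=f"cuda:{i}") for i in range(n)]
#   nccl.all_gather(ins, outs)  # each outs[i] holds the concatenation of all ins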


def reduce_scatter(
    inputs: Sequence[torch.Tensor],
    outputs: Sequence[torch.Tensor],
    op: int = SUM,
    streams=None,
    comms=None,
) -> None:
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)
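
# Example (illustrative sketch, assumes at least two CUDA devices): each input is
# treated as one chunk per device; chunk ``i`` is summed across devices and the
# result lands in the output tensor on device ``i``.
#
#   import torch
#   from torch.cuda import nccl
#
#   n, chunk = 2, 4
#   ins = [torch.randn(n * chunk, device=f"cuda:{i}") for i in range(n)]
#   outs = [torch.empty(chunk, device=f"cuda:{i}") for i in range(n)]
#   nccl.reduce_scatter(ins, outs)  # outs[i] == sum over devices of chunk i of ins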