# mypy: allow-untyped-defs
import collections
import warnings
from typing import Optional, Sequence, Union

import torch.cuda


__all__ = ["all_reduce", "reduce", "broadcast", "all_gather", "reduce_scatter"]

SUM = 0  # ncclRedOp_t


def is_available(tensors):
    # NCCL collectives require dense, contiguous CUDA tensors, each on a
    # distinct device.
    if not hasattr(torch._C, "_nccl_all_reduce"):
        warnings.warn("PyTorch is not compiled with NCCL support")
        return False

    devices = set()
    for tensor in tensors:
        if tensor.is_sparse:
            return False
        if not tensor.is_contiguous():
            return False
        if not tensor.is_cuda:
            return False
        device = tensor.get_device()
        if device in devices:
            return False
        devices.add(device)

    return True


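# Usage sketch (illustrative; not part of the original module). Assumes a
# NCCL-enabled build and at least two CUDA devices; `_example_is_available`
# is a hypothetical helper added for illustration only.
def _example_is_available():
    tensors = [torch.ones(4, device=f"cuda:{i}") for i in range(2)]
    return is_available(tensors)  # True on a dual-GPU, NCCL-enabled setup

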
def version():
    """
    Return the version of NCCL.

    This function returns a tuple containing the major, minor, and patch
    version numbers of NCCL. If a version suffix exists, it is included as a
    fourth element.

    Returns:
        tuple: The NCCL version information.
    """
    ver = torch._C._nccl_version()
    # The version is packed into a single integer as
    # (major << 32) | (minor << 16) | patch.
    major = ver >> 32
    minor = (ver >> 16) & 65535
    patch = ver & 65535
    suffix = torch._C._nccl_version_suffix().decode("utf-8")
    if suffix == "":
        return (major, minor, patch)
    else:
        return (major, minor, patch, suffix)


def unique_id():
    return torch._C._nccl_unique_id()


def init_rank(num_ranks, uid, rank):
    return torch._C._nccl_init_rank(num_ranks, uid, rank)


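# Rendezvous sketch (illustrative; not part of the original module): one rank
# obtains an id via unique_id() and shares it with the other ranks out of
# band; each rank then calls init_rank(num_ranks, uid, rank) to create its
# communicator, which the collectives below accept through `comms`.

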
def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
    if not isinstance(inputs, collections.abc.Container) or isinstance(
        inputs, torch.Tensor
    ):
        raise TypeError("Inputs should be a collection of tensors")


def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
    _check_sequence_type(inputs)
    if outputs is None:
        outputs = inputs
    _check_sequence_type(outputs)
    torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms)


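# Usage sketch (illustrative; not part of the original module). With the
# default outputs=None the reduction happens in place: after the call every
# tensor holds the elementwise sum across devices. `_example_all_reduce` is a
# hypothetical helper added for illustration only.
def _example_all_reduce():
    tensors = [
        torch.full((4,), float(i + 1), device=f"cuda:{i}") for i in range(2)
    ]
    if is_available(tensors):
        all_reduce(tensors)  # each tensor now holds 1.0 + 2.0 == 3.0

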
# `output` used to be `outputs`, taking in a list of tensors. So we have two
# arguments for BC reasons.
def reduce(
    inputs: Sequence[torch.Tensor],
    output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
    root: int = 0,
    op: int = SUM,
    streams: Optional[Sequence[torch.cuda.Stream]] = None,
    comms=None,
    *,
    outputs: Optional[Sequence[torch.Tensor]] = None,
) -> None:
    _check_sequence_type(inputs)
    _output: torch.Tensor
    if outputs is not None:
        if output is not None:
            raise ValueError(
                "'output' and 'outputs' cannot both be specified. 'outputs' is deprecated in "
                "favor of 'output', taking in a single output tensor. The signature of reduce is: "
                "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)."
            )
        else:
            warnings.warn(
                "`nccl.reduce` with an output tensor list is deprecated. "
                "Please specify a single output tensor with argument 'output' instead.",
                FutureWarning,
                stacklevel=2,
            )
            _output = outputs[root]
    elif not isinstance(output, torch.Tensor) and isinstance(
        output, collections.abc.Sequence
    ):
        # User called the old API with a positional list of output tensors.
        warnings.warn(
            "nccl.reduce with an output tensor list is deprecated. "
            "Please specify a single output tensor.",
            FutureWarning,
            stacklevel=2,
        )
        _output = output[root]
    else:
        _output = inputs[root] if output is None else output
    torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)


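# Usage sketch of the non-deprecated signature (illustrative; not part of the
# original module): pass a single output tensor that lives on the root
# device. `_example_reduce` is a hypothetical helper added for illustration.
def _example_reduce():
    tensors = [torch.ones(4, device=f"cuda:{i}") for i in range(2)]
    out = torch.empty(4, device="cuda:0")
    reduce(tensors, output=out, root=0)  # out holds the sum on cuda:0

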
def broadcast(
    inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None
) -> None:
    _check_sequence_type(inputs)
    torch._C._nccl_broadcast(inputs, root, streams, comms)


def all_gather(
    inputs: Sequence[torch.Tensor],
    outputs: Sequence[torch.Tensor],
    streams=None,
    comms=None,
) -> None:
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_all_gather(inputs, outputs, streams, comms)


def reduce_scatter(
    inputs: Sequence[torch.Tensor],
    outputs: Sequence[torch.Tensor],
    op: int = SUM,
    streams=None,
    comms=None,
) -> None:
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)
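

# Shape sketch for all_gather / reduce_scatter (illustrative; not part of the
# original module): with N participating devices, all_gather expects each
# output to be N times the size of an input, while reduce_scatter expects
# each output to be 1/N the size of an input. `_example_gather_scatter` is a
# hypothetical helper added for illustration only.
def _example_gather_scatter(n: int = 2):
    inputs = [torch.ones(4, device=f"cuda:{i}") for i in range(n)]
    gathered = [torch.empty(4 * n, device=f"cuda:{i}") for i in range(n)]
    all_gather(inputs, gathered)  # every device gets all inputs concatenated

    big = [torch.ones(4 * n, device=f"cuda:{i}") for i in range(n)]
    shards = [torch.empty(4, device=f"cuda:{i}") for i in range(n)]
    reduce_scatter(big, shards)  # each device gets one reduced shard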