xref: /aosp_15_r20/external/pytorch/torch/csrc/cuda/nccl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <ATen/cuda/CUDAContext.h>
5 
6 #include <cstddef>
7 #include <optional>
8 #include <vector>
9 
// NCCL BFloat16 is enabled only for CUDA 11+ and NCCL versions 2.10+, or for
// HIP 3.1+
#if defined(__CUDA_BF16_TYPES_EXIST__)
// Expands at the point of use, where NCCL_MAJOR/NCCL_MINOR come from <nccl.h>.
// Explicit parentheses make the intended "major > 2, or major == 2 with
// minor >= 10" grouping unambiguous (and silence -Wparentheses).
#define HAS_NCCL_BF16_DATATYPE \
  ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 10)))
#elif defined(USE_ROCM) && (TORCH_HIP_VERSION >= 301)
#define HAS_NCCL_BF16_DATATYPE 1
#else
#define HAS_NCCL_BF16_DATATYPE 0
#endif
20 
namespace torch::cuda::nccl {

/* The following are copied from <nccl.h> and redefined in torch::cuda::nccl
 * namespace */
/* pytorch should only use the following definition within pytorch scope */

/* Opaque handle to a communicator (the real ncclComm*); nccl.cpp
 * reinterpret-casts this back to ncclComm when calling into NCCL. */
typedef void* ncclComm_t;

/** redefine nccl unique ID in torch scope. this should be identical to the
 * native nccl implementation. */
#define NCCL_UNIQUE_ID_BYTES 128
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
typedef struct {
  char internal[NCCL_UNIQUE_ID_BYTES]; // opaque bytes, layout must match nccl's ncclUniqueId
} ncclUniqueId;
38 
/* Error type. Mirrors ncclResult_t from <nccl.h>; the numeric values must
 * stay in sync with the native enum, since nccl.cpp converts between the
 * two representations. */
enum class ncclResult {
  Success = 0,
  UnhandledCudaError = 1,
  SystemError = 2,
  InternalError = 3,
  InvalidArgument = 4,
  InvalidUsage = 5,
  RemoteError = 6,
  InProgress = 7, // nonblocking operation not yet complete
  NumResults = 8  // sentinel: number of result codes, not an error itself
};
51 
/* Reduction operation selector. Mirrors ncclRedOp_t from <nccl.h>;
 * NumOps is a sentinel count, not a valid op. */
enum class ncclRedOp { Sum = 0, Prod = 1, Max = 2, Min = 3, NumOps = 4 };
54 
/* Data types. Mirrors ncclDataType_t from <nccl.h>. Several enumerators
 * deliberately alias the same value (Char/Int8, Int/Int32, Half/Float16,
 * Float/Float32, Double/Float64). NumTypes is a sentinel count. */
enum class ncclDataType {
  Int8 = 0,
  Char = 0,
  Uint8 = 1,
  Int32 = 2,
  Int = 2,
  Uint32 = 3,
  Int64 = 4,
  Uint64 = 5,
  Float16 = 6,
  Half = 6,
  Float32 = 7,
  Float = 7,
  Float64 = 8,
  Double = 8,
  Bfloat16 = 9, // only usable when HAS_NCCL_BF16_DATATYPE is nonzero
  NumTypes = 10
};
74 
// RAII helper class to manage NCCL group API and CUDA free mutex.
// The destructor is allowed to throw since this helper class only
// manages group and lock lifetimes.
struct AutoNcclGroup {
  // Begin a group without an associated communicator.
  AutoNcclGroup();
  // Begin a group scoped to `comm`; `comm_nonblocking` presumably indicates
  // the communicator was created in nonblocking mode — confirm in nccl.cpp.
  AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking);
  // Ends the group; declared noexcept(false) so NCCL errors can propagate.
  ~AutoNcclGroup() noexcept(false);
  ncclComm_t comm_;       // communicator passed at construction (if any)
  bool comm_nonblocking_; // value of comm_nonblocking passed at construction
};
85 
// NOTE: this is exposed only so that python_nccl.cpp can use some of these
// helpers. Don't use them outside of these files.
namespace detail {

// Throws an exception describing `status` (implementation in nccl.cpp).
TORCH_CUDA_CPP_API void throw_nccl_error(ncclResult status);
NCCL_CHECK(ncclResult status)92 inline void NCCL_CHECK(ncclResult status) {
93   if (status != ncclResult::Success) {
94     throw_nccl_error(status);
95   }
96 }
97 
// Returns the communicators associated with `inputs` — presumably one per
// participating device, created/cached by nccl.cpp — confirm there.
TORCH_CUDA_CPP_API at::ArrayRef<ncclComm_t> get_communicators(
    at::TensorList inputs);
// Validates inputs/outputs for a collective; the multipliers presumably
// express the expected output-size-per-input ratio (e.g. gather-like ops) —
// see the implementation in nccl.cpp.
TORCH_CUDA_CPP_API void check_inputs(
    at::TensorList inputs,
    at::TensorList outputs,
    int input_multiplier,
    int output_multiplier);
// Overload for rooted collectives with a single output tensor on `root`.
TORCH_CUDA_CPP_API void check_inputs(
    at::TensorList inputs,
    const at::Tensor& output,
    int root,
    int input_multiplier,
    int output_multiplier);

} // namespace detail
113 
using comm_list = std::vector<ncclComm_t>;
// One optional stream per participant; std::nullopt presumably means "use the
// current stream" — confirm in nccl.cpp.
using stream_list = std::vector<std::optional<at::cuda::CUDAStream>>;

// NCCL version, packed into an integer, and its suffix string (nccl.cpp).
TORCH_CUDA_CPP_API std::uint64_t version();
TORCH_CUDA_CPP_API const char* version_suffix();

// Whether NCCL collectives can run on `tensors` (see nccl.cpp for criteria).
bool is_available(at::TensorList tensors);

// Wrappers over NCCL communicator lifecycle (get id / init by rank / destroy).
TORCH_CUDA_CPP_API void get_unique_id(ncclUniqueId& id);
TORCH_CUDA_CPP_API ncclComm_t
comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank);
TORCH_CUDA_CPP_API void comm_destroy(ncclComm_t comm);
126 
// In-place broadcast across `tensors` (one per device). Which tensor acts as
// the source is decided by the implementation — see nccl.cpp. Empty
// `streams`/`user_comms` fall back to internally managed defaults.
TORCH_CUDA_CPP_API void broadcast(
    at::TensorList tensors,
    const stream_list& streams = {},
    const comm_list& user_comms = {});

// Maximum element count supported per NCCL call — presumably bounded by the
// count type NCCL uses; confirm in nccl.cpp.
size_t get_max_count();
133 
// Reduces `inputs` into `output` on rank `root`. `op` is a ncclRedOp value
// passed as int32_t (defaults to Sum).
TORCH_CUDA_CPP_API void reduce(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& output,
    int32_t root = 0,
    int32_t op = static_cast<int>(ncclRedOp::Sum),
    const stream_list& streams = {},
    const comm_list& user_comms = {});

// Overload without an explicit output — presumably reduces in place into
// inputs[root]; confirm in nccl.cpp.
TORCH_CUDA_CPP_API void reduce(
    std::vector<at::Tensor>& inputs,
    int32_t root = 0,
    int32_t op = static_cast<int>(ncclRedOp::Sum),
    const stream_list& streams = {},
    const comm_list& user_comms = {});
148 
// All-reduce: reduces `inputs` element-wise with `op` (a ncclRedOp as
// int32_t, default Sum) and writes the result into every `outputs` tensor.
TORCH_CUDA_CPP_API void all_reduce(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op = static_cast<int>(ncclRedOp::Sum),
    const stream_list& streams = {},
    const comm_list& user_comms = {});

// Reduce-scatter: reduces across `inputs` and scatters disjoint chunks of the
// result into `outputs` (one chunk per participant).
TORCH_CUDA_CPP_API void reduce_scatter(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op = static_cast<int>(ncclRedOp::Sum),
    const stream_list& streams = {},
    const comm_list& user_comms = {});
162 
// Scatter: rank `root` distributes `inputs` so each rank receives its piece
// in `outputs`. Note `outputs` is a single tensor here (this rank's share)
// despite the plural name.
TORCH_CUDA_CPP_API void scatter(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& outputs,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream,
    int32_t root = 0);

// All-gather: every rank contributes its tensor from `inputs` and receives
// the concatenated/gathered results in `outputs`.
TORCH_CUDA_CPP_API void all_gather(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    const stream_list& streams = {},
    const comm_list& user_comms = {});

// Gather: rank `root` collects each rank's `inputs` tensor into `outputs`.
// Note `inputs` is a single tensor here (this rank's contribution) despite
// the plural name.
TORCH_CUDA_CPP_API void gather(
    const at::Tensor& inputs,
    std::vector<at::Tensor>& outputs,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream,
    int32_t root = 0);
182 
// All-to-all with equal-sized splits: `input`/`output` are exchanged across
// `size` participants on `stream` using `comm`.
TORCH_CUDA_CPP_API void all2all_single_equal_split(
    at::Tensor& input,
    at::Tensor& output,
    int size,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream);

// All-to-all with per-rank counts and displacements, expressed over raw
// buffers. Counts/displacements are presumably in elements of `type` (not
// bytes) — confirm in nccl.cpp.
TORCH_CUDA_CPP_API void all2all_single_unequal_split(
    void* sendbuff,
    const size_t* sendcounts,
    const size_t* senddispls,
    void* recvbuff,
    const size_t* recvcounts,
    const size_t* recvdispls,
    size_t size,
    c10::ScalarType type,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream);

// All-to-all over tensor lists: inputTensors[i] is sent to rank i and
// outputTensors[i] receives from rank i.
TORCH_CUDA_CPP_API void all2all(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream);
207 
// Point-to-point send of `input` to rank `dst` over `comm` on `stream`.
TORCH_CUDA_CPP_API void send(
    const at::Tensor& input,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int dst);

// Point-to-point receive into `output` from rank `src` over `comm` on
// `stream`.
TORCH_CUDA_CPP_API void recv(
    at::Tensor& output,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int src);
} // namespace torch::cuda::nccl