#pragma once

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <cstddef>
#include <optional>
#include <vector>

// NCCL BFloat16 is enabled only for CUDA 11+ and NCCL versions 2.10+, or for
// HIP 3.1+
#if defined(__CUDA_BF16_TYPES_EXIST__)
#define HAS_NCCL_BF16_DATATYPE \
  ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 10)))
#elif defined(USE_ROCM) && (TORCH_HIP_VERSION >= 301)
#define HAS_NCCL_BF16_DATATYPE 1
#else
#define HAS_NCCL_BF16_DATATYPE 0
#endif
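
// Illustrative sketch (not part of the declarations below): one way this flag
// might gate a BFloat16 mapping from ATen scalar types to the ncclDataType
// enum declared further down. The helper name `to_nccl_data_type_sketch` is
// hypothetical.
//
//   ncclDataType to_nccl_data_type_sketch(at::ScalarType t) {
//     switch (t) {
//       case at::kFloat: return ncclDataType::Float;
//       case at::kHalf: return ncclDataType::Half;
// #if HAS_NCCL_BF16_DATATYPE
//       case at::kBFloat16: return ncclDataType::Bfloat16;
// #endif
//       default: TORCH_CHECK(false, "unsupported dtype for NCCL");
//     }
//   }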

namespace torch::cuda::nccl {

/* The following are copied from <nccl.h> and redefined in the
 * torch::cuda::nccl namespace. */
/* PyTorch should only use these definitions within its own scope. */

/* Opaque handle to a communicator (ncclComm*); this is reinterpreted as
 * ncclComm in nccl.cpp. */
typedef void* ncclComm_t;

/** Redefinition of the NCCL unique ID in torch scope. This should be identical
 * to the native NCCL implementation. */
#define NCCL_UNIQUE_ID_BYTES 128
typedef struct {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  char internal[NCCL_UNIQUE_ID_BYTES];
} ncclUniqueId;

/* Error type */
enum class ncclResult {
  Success = 0,
  UnhandledCudaError = 1,
  SystemError = 2,
  InternalError = 3,
  InvalidArgument = 4,
  InvalidUsage = 5,
  RemoteError = 6,
  InProgress = 7,
  NumResults = 8
};

/* Reduction operation selector */
enum class ncclRedOp { Sum = 0, Prod = 1, Max = 2, Min = 3, NumOps = 4 };

/* Data types */
enum class ncclDataType {
  Int8 = 0,
  Char = 0,
  Uint8 = 1,
  Int32 = 2,
  Int = 2,
  Uint32 = 3,
  Int64 = 4,
  Uint64 = 5,
  Float16 = 6,
  Half = 6,
  Float32 = 7,
  Float = 7,
  Float64 = 8,
  Double = 8,
  Bfloat16 = 9,
  NumTypes = 10
};

// RAII helper class to manage NCCL group API and CUDA free mutex.
// The destructor is allowed to throw since this helper class only
// manages group and lock lifetimes.
struct AutoNcclGroup {
  AutoNcclGroup();
  AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking);
  ~AutoNcclGroup() noexcept(false);
  ncclComm_t comm_;
  bool comm_nonblocking_;
};
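
// Usage sketch (illustrative only): the guard brackets a batch of NCCL calls
// so they are submitted as one group, releasing the group and the lock when
// the scope ends; the collectives referenced here are declared further below.
//
//   {
//     torch::cuda::nccl::AutoNcclGroup guard;
//     // ... enqueue several collectives (broadcast, all_reduce, ...) ...
//   } // group closed and lock released when `guard` is destroyed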

// NOTE: this is exposed only so that python_nccl.cpp can use some of these
// helpers. Don't use them outside of these files.
namespace detail {

TORCH_CUDA_CPP_API void throw_nccl_error(ncclResult status);

inline void NCCL_CHECK(ncclResult status) {
  if (status != ncclResult::Success) {
    throw_nccl_error(status);
  }
}
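
// Example (illustrative): callers in nccl.cpp pass the torch-scoped result of
// a raw NCCL call to NCCL_CHECK; a non-Success status is turned into a C++
// exception via throw_nccl_error. The exact conversion from the library's
// ncclResult_t may differ from the cast shown here.
//
//   NCCL_CHECK(static_cast<ncclResult>(/* raw result of an NCCL call */));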

TORCH_CUDA_CPP_API at::ArrayRef<ncclComm_t> get_communicators(
    at::TensorList inputs);
TORCH_CUDA_CPP_API void check_inputs(
    at::TensorList inputs,
    at::TensorList outputs,
    int input_multiplier,
    int output_multiplier);
TORCH_CUDA_CPP_API void check_inputs(
    at::TensorList inputs,
    const at::Tensor& output,
    int root,
    int input_multiplier,
    int output_multiplier);

} // namespace detail

using comm_list = std::vector<ncclComm_t>;
using stream_list = std::vector<std::optional<at::cuda::CUDAStream>>;

TORCH_CUDA_CPP_API std::uint64_t version();
TORCH_CUDA_CPP_API const char* version_suffix();

bool is_available(at::TensorList tensors);

TORCH_CUDA_CPP_API void get_unique_id(ncclUniqueId& id);
TORCH_CUDA_CPP_API ncclComm_t
comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank);
TORCH_CUDA_CPP_API void comm_destroy(ncclComm_t comm);
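
// Usage sketch (illustrative only; `rank` and `world_size` stand for the
// caller's process index and count, and the out-of-band exchange of the
// unique id is omitted):
//
//   ncclUniqueId id{};
//   if (rank == 0) {
//     torch::cuda::nccl::get_unique_id(id); // generated once, on one rank
//   }
//   // ... share `id` with every participating rank (e.g. via a store) ...
//   ncclComm_t comm = torch::cuda::nccl::comm_init_rank(world_size, id, rank);
//   // ... issue collectives on `comm` ...
//   torch::cuda::nccl::comm_destroy(comm);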

TORCH_CUDA_CPP_API void broadcast(
    at::TensorList tensors,
    const stream_list& streams = {},
    const comm_list& user_comms = {});
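
// Usage sketch (illustrative, single-process multi-GPU use): `tensors` is
// assumed to hold one same-shaped tensor per device; the source tensor's
// contents are broadcast to the others. Cached streams and communicators are
// used when the optional arguments are left empty.
//
//   std::vector<at::Tensor> tensors = {t_gpu0, t_gpu1}; // assumed setup
//   torch::cuda::nccl::broadcast(tensors);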

size_t get_max_count();

TORCH_CUDA_CPP_API void reduce(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& output,
    int32_t root = 0,
    int32_t op = static_cast<int>(ncclRedOp::Sum),
    const stream_list& streams = {},
    const comm_list& user_comms = {});

TORCH_CUDA_CPP_API void reduce(
    std::vector<at::Tensor>& inputs,
    int32_t root = 0,
    int32_t op = static_cast<int>(ncclRedOp::Sum),
    const stream_list& streams = {},
    const comm_list& user_comms = {});

TORCH_CUDA_CPP_API void all_reduce(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op = static_cast<int>(ncclRedOp::Sum),
    const stream_list& streams = {},
    const comm_list& user_comms = {});
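
// Usage sketch (illustrative): one tensor per device, reduced with the default
// Sum op. Passing the same vector as inputs and outputs performs the reduction
// in place (an assumption based on the typical use of this API).
//
//   std::vector<at::Tensor> bufs = {t_gpu0, t_gpu1}; // assumed setup
//   torch::cuda::nccl::all_reduce(bufs, bufs);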

TORCH_CUDA_CPP_API void reduce_scatter(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op = static_cast<int>(ncclRedOp::Sum),
    const stream_list& streams = {},
    const comm_list& user_comms = {});

TORCH_CUDA_CPP_API void scatter(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& outputs,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream,
    int32_t root = 0);

TORCH_CUDA_CPP_API void all_gather(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    const stream_list& streams = {},
    const comm_list& user_comms = {});

TORCH_CUDA_CPP_API void gather(
    const at::Tensor& inputs,
    std::vector<at::Tensor>& outputs,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream,
    int32_t root = 0);
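
// Usage sketch (illustrative): unlike the collectives above, scatter/gather
// take an explicit communicator and stream. `comm`, `local_tensor`, and
// `gathered` are caller-supplied placeholders; on the root rank `gathered`
// holds one destination tensor per rank.
//
//   auto stream = at::cuda::getCurrentCUDAStream();
//   torch::cuda::nccl::gather(local_tensor, gathered, comm, stream, /*root=*/0);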

TORCH_CUDA_CPP_API void all2all_single_equal_split(
    at::Tensor& input,
    at::Tensor& output,
    int size,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream);

TORCH_CUDA_CPP_API void all2all_single_unequal_split(
    void* sendbuff,
    const size_t* sendcounts,
    const size_t* senddispls,
    void* recvbuff,
    const size_t* recvcounts,
    const size_t* recvdispls,
    size_t size,
    c10::ScalarType type,
    ncclComm_t comm,
    at::cuda::CUDAStream& stream);

TORCH_CUDA_CPP_API void all2all(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream);
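
// Usage sketch (illustrative): by the usual all-to-all convention,
// `inputTensors[i]` is sent to rank i and the tensor received from rank i is
// written to `outputTensors[i]`; note the output vector comes first. `comm`,
// `recv_chunks`, and `send_chunks` are caller-supplied placeholders.
//
//   auto stream = at::cuda::getCurrentCUDAStream();
//   torch::cuda::nccl::all2all(recv_chunks, send_chunks, comm, stream);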

TORCH_CUDA_CPP_API void send(
    const at::Tensor& input,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int dst);

TORCH_CUDA_CPP_API void recv(
    at::Tensor& output,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int src);
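
// Usage sketch (illustrative): a matched point-to-point pair on the same
// communicator; the sending rank calls send while the destination rank calls
// recv. `rank`, `payload`, `comm`, `src`, and `dst` are placeholders.
//
//   auto stream = at::cuda::getCurrentCUDAStream();
//   if (rank == src) {
//     torch::cuda::nccl::send(payload, comm, stream, dst);
//   } else if (rank == dst) {
//     torch::cuda::nccl::recv(payload, comm, stream, src);
//   }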
} // namespace torch::cuda::nccl