#include <torch/csrc/cuda/python_nccl.h>

#include <ATen/core/functional.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/Types.h>
#include <torch/csrc/cuda/THCP.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/utils/pybind.h>

#include <c10/cuda/CUDAGuard.h>
#include <c10/util/irange.h>

using namespace at;
using namespace torch;
using namespace torch::cuda::nccl;
using namespace torch::cuda::nccl::detail;

static const char* COMM_CAPSULE_NAME = "torch.cuda.nccl.Communicator";

PyObject* THCPModule_nccl_version(PyObject* self, PyObject* args) {
  return PyLong_FromUnsignedLongLong(version());
}

PyObject* THCPModule_nccl_version_suffix(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  return PyBytes_FromString(version_suffix());
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_nccl_unique_id(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  ncclUniqueId id;
  get_unique_id(id);
  return PyBytes_FromStringAndSize((char*)&id, NCCL_UNIQUE_ID_BYTES);
  END_HANDLE_TH_ERRORS
}

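// NCCL communicators cross into Python as PyCapsule objects tagged with
// COMM_CAPSULE_NAME. unpack_nccl_comm() recovers the raw ncclComm_t from such
// a capsule (raising a Python error if the capsule is of the wrong kind), and
// destroy_nccl_comm() below is installed as the capsule destructor so the
// communicator is released when the capsule is garbage collected.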
static ncclComm_t unpack_nccl_comm(PyObject* capsule) {
  ncclComm_t comm =
      (ncclComm_t)PyCapsule_GetPointer(capsule, COMM_CAPSULE_NAME);
  if (!comm)
    throw python_error();
  return comm;
}

static void destroy_nccl_comm(PyObject* capsule) {
  HANDLE_TH_ERRORS
  ncclComm_t comm = unpack_nccl_comm(capsule);
  {
    pybind11::gil_scoped_release no_gil;
    comm_destroy(comm);
  }
  END_HANDLE_TH_ERRORS_RET()
}

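// Converts the Python `streams` argument into one optional CUDAStream per
// input tensor. None yields `size` empty optionals (no explicit stream for
// any input); otherwise the object must be a sequence of torch.cuda.Stream
// (or None) entries whose length matches the number of inputs.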
static std::vector<std::optional<at::cuda::CUDAStream>> unpack_streams(
    PyObject* obj,
    size_t size) {
  if (obj == Py_None) {
    return std::vector<std::optional<at::cuda::CUDAStream>>(size, std::nullopt);
  }
  auto streams = THPUtils_PySequence_to_CUDAStreamList(obj);
  if (streams.size() != size) {
    throw std::runtime_error(
        "number of streams is not equal to number of inputs");
  }
  return streams;
}

static inline at::Tensor extract_tensor(PyObject* obj);
static inline std::vector<at::Tensor> extract_tensors(PyObject* obj);

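// Converts the Python `comms` argument into a vector of raw ncclComm_t
// handles. None yields an empty vector (the underlying collective then falls
// back to its own default communicators); otherwise a single capsule or a
// sequence of capsules is accepted, and the resulting count must equal the
// number of input tensors.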
static std::vector<ncclComm_t> unpack_comms(PyObject* obj, size_t size) {
  if (obj == Py_None) {
    return std::vector<ncclComm_t>();
  }
  std::vector<ncclComm_t> comms;
  if (PyCapsule_CheckExact(obj)) {
    comms = {unpack_nccl_comm(obj)};
  } else {
    auto seq = THPObjectPtr(PySequence_Fast(obj, "comm is not a sequence"));
    if (!seq)
      throw python_error();
    auto size = PySequence_Fast_GET_SIZE(seq.get());
    comms = std::vector<ncclComm_t>(size);
    for (const auto i : c10::irange(size)) {
      comms[i] = unpack_nccl_comm(PySequence_Fast_GET_ITEM(seq.get(), i));
    }
  }
  if (comms.size() != size) {
    throw std::runtime_error(
        "number of communicators is not equal to number of inputs");
  }
  return comms;
}

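// Binding for rank-based initialization. Per the "is#i" format string the
// expected arguments are (int nranks, bytes unique_id, int rank), where the
// unique_id must be exactly NCCL_UNIQUE_ID_BYTES long (as produced by the
// unique_id binding above). The communicator is created with the GIL
// released and returned wrapped in a capsule whose destructor destroys it.
// A rough sketch of the Python-side call sequence, assuming these functions
// are exposed on torch._C under the usual _nccl_* names:
//   uid = torch._C._nccl_unique_id()
//   comm = torch._C._nccl_init_rank(nranks, uid, rank)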
PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  int nranks = 0;
  const char* id = nullptr;
  Py_ssize_t id_len = 0;
  int rank = 0;

  if (!PyArg_ParseTuple(
          args, "is#i:nccl_init_rank", &nranks, &id, &id_len, &rank)) {
    return nullptr;
  }
  TORCH_CHECK(
      id_len == NCCL_UNIQUE_ID_BYTES,
      "invalid unique_id (expected ",
      NCCL_UNIQUE_ID_BYTES,
      " bytes, got ",
      id_len,
      ")");

  ncclUniqueId commId;
  memcpy(&commId, id, NCCL_UNIQUE_ID_BYTES);
  ncclComm_t comm = nullptr;
  {
    pybind11::gil_scoped_release no_gil;
    comm = comm_init_rank(nranks, commId, rank);
  }
  return PyCapsule_New(comm, COMM_CAPSULE_NAME, &destroy_nccl_comm);
  END_HANDLE_TH_ERRORS
}

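// The collective bindings below all follow the same pattern: parse the
// argument tuple, unpack tensors/streams/communicators with the helpers
// above, then release the GIL and dispatch to the matching
// torch::cuda::nccl:: routine. For nccl_reduce the tuple layout, per the
// "OOiiOO" format string, is (inputs, output, root, op, streams, comms).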
PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs = nullptr, *_output = nullptr, *_streams = nullptr,
           *_comms = nullptr;
  int root = 0, op = 0;

  if (!PyArg_ParseTuple(
          args, "OOiiOO", &_inputs, &_output, &root, &op, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_reduce",
        1,
        "(sequence[Tensor] inputs, Tensor output, int root,"
        " int op, sequence[torch.cuda.Stream or None] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  auto output = extract_tensor(_output);
  std::vector<std::optional<at::cuda::CUDAStream>> streams =
      unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  {
    pybind11::gil_scoped_release no_gil;
    torch::cuda::nccl::reduce(inputs, output, root, op, streams, user_comms);
  }

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs = nullptr, *_outputs = nullptr, *_streams = nullptr,
           *_comms = nullptr;
  int op = 0;

  if (!PyArg_ParseTuple(
          args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_all_reduce",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  {
    pybind11::gil_scoped_release no_gil;
    all_reduce(inputs, outputs, op, streams, user_comms);
  }

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs = nullptr, *_streams = nullptr, *_comms = nullptr;
  int root = 0;

  if (!PyArg_ParseTuple(args, "OiOO", &_inputs, &root, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_broadcast",
        1,
        "(sequence[Tensor] inputs, int root,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  TORCH_CHECK(root >= 0 && (size_t)root < inputs.size(), "invalid root");
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  {
    pybind11::gil_scoped_release no_gil;
    torch::cuda::nccl::broadcast(inputs, streams, user_comms);
  }

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs = nullptr, *_outputs = nullptr, *_streams = nullptr,
           *_comms = nullptr;

  if (!PyArg_ParseTuple(
          args, "OOOO", &_inputs, &_outputs, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_all_gather",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  {
    pybind11::gil_scoped_release no_gil;
    all_gather(inputs, outputs, streams, user_comms);
  }

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs = nullptr, *_outputs = nullptr, *_streams = nullptr,
           *_comms = nullptr;
  int op = 0;

  if (!PyArg_ParseTuple(
          args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_reduce_scatter",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  {
    pybind11::gil_scoped_release no_gil;
    reduce_scatter(inputs, outputs, op, streams, user_comms);
  }

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

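// Helpers for pulling tensors out of the Python arguments: extract_tensor()
// checks that a single object is a torch.Tensor and unwraps it, while
// extract_tensors() does the same for every element of a sequence, reporting
// the offending element's index and type on mismatch.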
static inline at::Tensor extract_tensor(PyObject* obj) {
  TORCH_CHECK_TYPE(
      THPVariable_Check(obj),
      "expected Tensor (got ",
      Py_TYPE(obj)->tp_name,
      ")");
  return THPVariable_Unpack(obj);
}

static inline std::vector<at::Tensor> extract_tensors(PyObject* obj) {
  auto seq = THPObjectPtr(PySequence_Fast(obj, "expected a sequence"));
  if (!seq)
    throw python_error();

  const Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
  std::vector<at::Tensor> list;
  if (length >= 0) {
    list.reserve(length);
  }
  for (Py_ssize_t i = 0; i < length; i++) {
    PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
    TORCH_CHECK_TYPE(
        THPVariable_Check(item),
        "expected Tensor at ",
        i,
        " (got ",
        Py_TYPE(item)->tp_name,
        ")");
    list.emplace_back(THPVariable_Unpack(item));
  }
  return list;
}