xref: /aosp_15_r20/external/pytorch/torch/csrc/cuda/python_nccl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/cuda/python_nccl.h>
2 
3 #include <ATen/core/functional.h>
4 #include <pybind11/pybind11.h>
5 #include <torch/csrc/DynamicTypes.h>
6 #include <torch/csrc/Exceptions.h>
7 #include <torch/csrc/THP.h>
8 #include <torch/csrc/Types.h>
9 #include <torch/csrc/cuda/THCP.h>
10 #include <torch/csrc/cuda/nccl.h>
11 #include <torch/csrc/utils/pybind.h>
12 
13 #include <c10/cuda/CUDAGuard.h>
14 #include <c10/util/irange.h>
15 
16 using namespace at;
17 using namespace torch;
18 using namespace torch::cuda::nccl;
19 using namespace torch::cuda::nccl::detail;
20 
// Name tag attached to every PyCapsule in this file that wraps an ncclComm_t.
// The same tag is given to PyCapsule_New when a communicator is created and to
// PyCapsule_GetPointer when one is unpacked. Declared constexpr so the pointer
// itself cannot be reassigned at runtime.
static constexpr const char* COMM_CAPSULE_NAME = "torch.cuda.nccl.Communicator";
22 
THCPModule_nccl_version(PyObject * self,PyObject * args)23 PyObject* THCPModule_nccl_version(PyObject* self, PyObject* args) {
24   return PyLong_FromUnsignedLongLong(version());
25 }
26 
THCPModule_nccl_version_suffix(PyObject * self,PyObject * args)27 PyObject* THCPModule_nccl_version_suffix(PyObject* self, PyObject* args) {
28   HANDLE_TH_ERRORS
29   return PyBytes_FromString(version_suffix());
30   END_HANDLE_TH_ERRORS
31 }
32 
THCPModule_nccl_unique_id(PyObject * self,PyObject * args)33 PyObject* THCPModule_nccl_unique_id(PyObject* self, PyObject* args) {
34   HANDLE_TH_ERRORS
35   ncclUniqueId id;
36   get_unique_id(id);
37   return PyBytes_FromStringAndSize((char*)&id, NCCL_UNIQUE_ID_BYTES);
38   END_HANDLE_TH_ERRORS
39 }
40 
unpack_nccl_comm(PyObject * capsule)41 static ncclComm_t unpack_nccl_comm(PyObject* capsule) {
42   ncclComm_t comm =
43       (ncclComm_t)PyCapsule_GetPointer(capsule, COMM_CAPSULE_NAME);
44   if (!comm)
45     throw python_error();
46   return comm;
47 }
48 
destroy_nccl_comm(PyObject * capsule)49 static void destroy_nccl_comm(PyObject* capsule) {
50   HANDLE_TH_ERRORS
51   ncclComm_t comm = unpack_nccl_comm(capsule);
52   {
53     pybind11::gil_scoped_release no_gil;
54     comm_destroy(comm);
55   }
56   END_HANDLE_TH_ERRORS_RET()
57 }
58 
unpack_streams(PyObject * obj,size_t size)59 static std::vector<std::optional<at::cuda::CUDAStream>> unpack_streams(
60     PyObject* obj,
61     size_t size) {
62   if (obj == Py_None) {
63     return std::vector<std::optional<at::cuda::CUDAStream>>(size, std::nullopt);
64   }
65   auto streams = THPUtils_PySequence_to_CUDAStreamList(obj);
66   if (streams.size() != size) {
67     throw std::runtime_error(
68         "number of streams is not equal to number of inputs");
69   }
70   return streams;
71 }
72 
73 static inline at::Tensor extract_tensor(PyObject* obj);
74 static inline std::vector<at::Tensor> extract_tensors(PyObject* obj);
75 
unpack_comms(PyObject * obj,size_t size)76 static std::vector<ncclComm_t> unpack_comms(PyObject* obj, size_t size) {
77   if (obj == Py_None) {
78     return std::vector<ncclComm_t>();
79   }
80   std::vector<ncclComm_t> comms;
81   if (PyCapsule_CheckExact(obj)) {
82     comms = {unpack_nccl_comm(obj)};
83   } else {
84     auto seq = THPObjectPtr(PySequence_Fast(obj, "comm is not a sequence"));
85     if (!seq)
86       throw python_error();
87     auto size = PySequence_Fast_GET_SIZE(seq.get());
88     comms = std::vector<ncclComm_t>(size);
89     for (const auto i : c10::irange(size)) {
90       comms[i] = unpack_nccl_comm(PySequence_Fast_GET_ITEM(seq.get(), i));
91     }
92   }
93   if (comms.size() != size) {
94     throw std::runtime_error(
95         "number of communicators is not equal to number of inputs");
96   }
97   return comms;
98 }
99 
THCPModule_nccl_init_rank(PyObject * self,PyObject * args)100 PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) {
101   HANDLE_TH_ERRORS
102   int nranks = 0;
103   const char* id = nullptr;
104   Py_ssize_t id_len = 0;
105   int rank = 0;
106 
107   if (!PyArg_ParseTuple(
108           args, "is#i:nccl_init_rank", &nranks, &id, &id_len, &rank)) {
109     return nullptr;
110   }
111   TORCH_CHECK(
112       id_len == NCCL_UNIQUE_ID_BYTES,
113       "invalid unqiue_id (expected ",
114       NCCL_UNIQUE_ID_BYTES,
115       " bytes, got ",
116       id_len,
117       ")");
118 
119   ncclUniqueId commId;
120   memcpy(&commId, id, NCCL_UNIQUE_ID_BYTES);
121   ncclComm_t comm = nullptr;
122   {
123     pybind11::gil_scoped_release no_gil;
124     comm = comm_init_rank(nranks, commId, rank);
125   }
126   return PyCapsule_New(comm, COMM_CAPSULE_NAME, &destroy_nccl_comm);
127   END_HANDLE_TH_ERRORS
128 }
129 
THCPModule_nccl_reduce(PyObject * self,PyObject * args)130 PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args) {
131   HANDLE_TH_ERRORS
132   PyObject *_inputs = nullptr, *_output = nullptr, *_streams = nullptr,
133            *_comms = nullptr;
134   int root = 0, op = 0;
135 
136   if (!PyArg_ParseTuple(
137           args, "OOiiOO", &_inputs, &_output, &root, &op, &_streams, &_comms)) {
138     THPUtils_invalidArguments(
139         args,
140         nullptr,
141         "nccl_reduce",
142         1,
143         "(sequence[Tensor] inputs, Tensor output, int root,"
144         " int op, sequence[torch.cuda.Stream or None]");
145     return nullptr;
146   }
147 
148   std::vector<at::Tensor> inputs = extract_tensors(_inputs);
149   auto output = extract_tensor(_output);
150   std::vector<std::optional<at::cuda::CUDAStream>> streams =
151       unpack_streams(_streams, inputs.size());
152   auto user_comms = unpack_comms(_comms, inputs.size());
153 
154   {
155     pybind11::gil_scoped_release no_gil;
156     torch::cuda::nccl::reduce(inputs, output, root, op, streams, user_comms);
157   }
158 
159   Py_RETURN_NONE;
160   END_HANDLE_TH_ERRORS
161 }
162 
THCPModule_nccl_all_reduce(PyObject * self,PyObject * args)163 PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) {
164   HANDLE_TH_ERRORS
165   PyObject *_inputs = nullptr, *_outputs = nullptr, *_streams = nullptr,
166            *_comms = nullptr;
167   int op = 0;
168 
169   if (!PyArg_ParseTuple(
170           args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) {
171     THPUtils_invalidArguments(
172         args,
173         nullptr,
174         "nccl_all_reduce",
175         1,
176         "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op,"
177         " sequence[torch.cuda.Stream] streams,"
178         " sequence[torch.cuda.nccl.Communicator] comms)");
179     return nullptr;
180   }
181 
182   std::vector<at::Tensor> inputs = extract_tensors(_inputs);
183   std::vector<at::Tensor> outputs = extract_tensors(_outputs);
184   auto streams = unpack_streams(_streams, inputs.size());
185   auto user_comms = unpack_comms(_comms, inputs.size());
186 
187   {
188     pybind11::gil_scoped_release no_gil;
189     all_reduce(inputs, outputs, op, streams, user_comms);
190   }
191 
192   Py_RETURN_NONE;
193   END_HANDLE_TH_ERRORS
194 }
195 
THCPModule_nccl_broadcast(PyObject * self,PyObject * args)196 PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args) {
197   HANDLE_TH_ERRORS
198   PyObject *_inputs = nullptr, *_streams = nullptr, *_comms = nullptr;
199   int root = 0;
200 
201   if (!PyArg_ParseTuple(args, "OiOO", &_inputs, &root, &_streams, &_comms)) {
202     THPUtils_invalidArguments(
203         args,
204         nullptr,
205         "nccl_broadcast",
206         1,
207         "(sequence[Tensor] inputs, int root"
208         " sequence[torch.cuda.Stream] streams,"
209         " sequence[torch.cuda.nccl.Communicator] comms)");
210     return nullptr;
211   }
212 
213   std::vector<at::Tensor> inputs = extract_tensors(_inputs);
214   TORCH_CHECK(root >= 0 && (size_t)root < inputs.size(), "invalid root");
215   auto streams = unpack_streams(_streams, inputs.size());
216   auto user_comms = unpack_comms(_comms, inputs.size());
217 
218   {
219     pybind11::gil_scoped_release no_gil;
220     torch::cuda::nccl::broadcast(inputs, streams, user_comms);
221   }
222 
223   Py_RETURN_NONE;
224   END_HANDLE_TH_ERRORS
225 }
226 
THCPModule_nccl_all_gather(PyObject * self,PyObject * args)227 PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {
228   HANDLE_TH_ERRORS
229   PyObject *_inputs = nullptr, *_outputs = nullptr, *_streams = nullptr,
230            *_comms = nullptr;
231 
232   if (!PyArg_ParseTuple(
233           args, "OOOO", &_inputs, &_outputs, &_streams, &_comms)) {
234     THPUtils_invalidArguments(
235         args,
236         nullptr,
237         "nccl_all_gather",
238         1,
239         "(sequence[Tensor] inputs, sequence[Tensor] outputs"
240         " sequence[torch.cuda.Stream] streams,"
241         " sequence[torch.cuda.nccl.Communicator] comms)");
242     return nullptr;
243   }
244 
245   std::vector<at::Tensor> inputs = extract_tensors(_inputs);
246   std::vector<at::Tensor> outputs = extract_tensors(_outputs);
247   auto streams = unpack_streams(_streams, inputs.size());
248   auto user_comms = unpack_comms(_comms, inputs.size());
249 
250   {
251     pybind11::gil_scoped_release no_gil;
252     all_gather(inputs, outputs, streams, user_comms);
253   }
254 
255   Py_RETURN_NONE;
256   END_HANDLE_TH_ERRORS
257 }
258 
THCPModule_nccl_reduce_scatter(PyObject * self,PyObject * args)259 PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) {
260   HANDLE_TH_ERRORS
261   PyObject *_inputs = nullptr, *_outputs = nullptr, *_streams = nullptr,
262            *_comms = nullptr;
263   int op = 0;
264 
265   if (!PyArg_ParseTuple(
266           args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) {
267     THPUtils_invalidArguments(
268         args,
269         nullptr,
270         "nccl_reduce_scatter",
271         1,
272         "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op"
273         " sequence[torch.cuda.Stream] streams,"
274         " sequence[torch.cuda.nccl.Communicator] comms)");
275     return nullptr;
276   }
277 
278   std::vector<at::Tensor> inputs = extract_tensors(_inputs);
279   std::vector<at::Tensor> outputs = extract_tensors(_outputs);
280   auto streams = unpack_streams(_streams, inputs.size());
281   auto user_comms = unpack_comms(_comms, inputs.size());
282 
283   {
284     pybind11::gil_scoped_release no_gil;
285     reduce_scatter(inputs, outputs, op, streams, user_comms);
286   }
287 
288   Py_RETURN_NONE;
289   END_HANDLE_TH_ERRORS
290 }
291 
extract_tensor(PyObject * obj)292 static inline at::Tensor extract_tensor(PyObject* obj) {
293   TORCH_CHECK_TYPE(
294       THPVariable_Check(obj),
295       "expected Tensor (got ",
296       Py_TYPE(obj)->tp_name,
297       ")");
298   return THPVariable_Unpack(obj);
299 }
300 
extract_tensors(PyObject * obj)301 static inline std::vector<at::Tensor> extract_tensors(PyObject* obj) {
302   auto seq = THPObjectPtr(PySequence_Fast(obj, "expected a sequence"));
303   if (!seq)
304     throw python_error();
305 
306   const Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
307   std::vector<at::Tensor> list;
308   if (length >= 0) {
309     list.reserve(length);
310   }
311   for (Py_ssize_t i = 0; i < length; i++) {
312     PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
313     TORCH_CHECK_TYPE(
314         THPVariable_Check(item),
315         "expected Tensor at ",
316         i,
317         " (got ",
318         Py_TYPE(item)->tp_name,
319         ")");
320     list.emplace_back(THPVariable_Unpack(item));
321   }
322   return list;
323 }
324