// xref: /aosp_15_r20/external/pytorch/torch/csrc/cuda/python_comm.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/core/functional.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/cuda/Stream.h>
#include <torch/csrc/cuda/THCP.h>
#include <torch/csrc/cuda/comm.h>
#include <torch/csrc/utils/pybind.h>

#include <ATen/ATen.h>

#include <cstddef>
#include <vector>

#include <torch/csrc/profiler/unwind/unwind.h>

namespace torch::cuda::python {
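
// Python bindings for the CUDA communication primitives declared in
// torch/csrc/cuda/comm.h: broadcast, scatter, and gather across devices.
// initCommMethods attaches them to the given module as private helpers; in
// PyTorch they are installed on torch._C and back torch.cuda.comm /
// torch.nn.parallel.
//
// A minimal usage sketch from Python (assuming the bindings have been
// installed on torch._C, as PyTorch's module initialization does):
//
//   import torch
//   t = torch.arange(4, device="cuda:0")
//   copies = torch._C._broadcast(t, [0, 1])  # one copy of t per listed device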
void initCommMethods(PyObject* module) {
  auto m = py::cast<py::module>(module);
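  // _broadcast_coalesced(tensors, devices, buffer_size): broadcast a list of
  // tensors to every device in `devices`, coalescing small tensors into
  // buffers of up to `buffer_size` bytes to reduce the number of copies.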
  m.def(
       "_broadcast_coalesced",
       [](std::vector<at::Tensor>& tensors,
          const std::vector<int64_t>& devices,
          size_t buffer_size) {
         return broadcast_coalesced(tensors, devices, buffer_size);
       },
       py::arg("tensors"),
       py::arg("devices"),
       py::arg("buffer_size"),
       py::call_guard<py::gil_scoped_release>())
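      // _broadcast(tensor, devices): copy a single tensor to every device in
      // `devices`, returning one copy per device.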
      .def(
          "_broadcast",
          [](at::Tensor& tensor, std::vector<int64_t> devices) {
            return broadcast(tensor, devices);
          },
          py::call_guard<py::gil_scoped_release>(),
          py::arg("tensor"),
          py::arg("devices"))
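      // _broadcast_out(tensor, out): like _broadcast, but write the copies
      // into the preallocated tensors in `out` instead of allocating them.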
      .def(
          "_broadcast_out",
          [](at::Tensor& tensor, std::vector<at::Tensor>& out_tensors) {
            return broadcast_out(tensor, out_tensors);
          },
          py::call_guard<py::gil_scoped_release>(),
          py::arg("tensor"),
          py::arg("out"))
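      // _scatter(tensor, devices, chunk_sizes, dim, streams): split `tensor`
      // along `dim` (optionally into chunks of the given sizes) and copy one
      // chunk to each device, optionally on caller-provided CUDA streams.
      // Unlike the bindings above, the GIL cannot be released via a call
      // guard here, because the Python stream objects must first be converted
      // while the GIL is held.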
      .def(
          "_scatter",
          [](at::Tensor& tensor,
             std::vector<int64_t>& devices,
             std::optional<std::vector<int64_t>> chunk_sizes,
             int64_t dim,
             std::optional<py::object> py_streams) {
            std::optional<std::vector<std::optional<at::cuda::CUDAStream>>>
                streams;
            if (py_streams) {
              py::handle handle = *py_streams;
              streams = THPUtils_PySequence_to_CUDAStreamList(handle.ptr());
            }
            // Note: We're holding the GIL up to here.
            pybind11::gil_scoped_release no_gil;
            return scatter(tensor, devices, chunk_sizes, dim, streams);
          },
          py::arg("tensor"),
          py::arg("devices"),
          py::arg("chunk_sizes"),
          py::arg("dim"),
          py::arg("streams"))
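      // _scatter_out(tensor, out, dim, streams): scatter variant that writes
      // the chunks into the preallocated tensors in `out`; the chunk sizes
      // are implied by the shapes of the output tensors. Same GIL handling
      // as _scatter.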
      .def(
          "_scatter_out",
          [](at::Tensor& tensor,
             std::vector<at::Tensor>& out_tensors,
             int64_t dim,
             std::optional<py::object> py_streams) {
            std::optional<std::vector<std::optional<at::cuda::CUDAStream>>>
                streams;
            if (py_streams) {
              py::handle handle = *py_streams;
              streams = THPUtils_PySequence_to_CUDAStreamList(handle.ptr());
            }
            // Note: We're holding the GIL up to here.
            pybind11::gil_scoped_release no_gil;
            return scatter_out(tensor, out_tensors, dim, streams);
          },
          py::arg("tensor"),
          py::arg("out"),
          py::arg("dim"),
          py::arg("streams"))
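      // _gather(tensors, dim, destination_index): concatenate tensors from
      // multiple devices along `dim` onto the device identified by
      // `destination_index`.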
      .def(
          "_gather",
          [](std::vector<at::Tensor>& tensors,
             int64_t dim,
             std::optional<int32_t> destination_index) {
            return gather(tensors, dim, destination_index);
          },
          py::arg("tensors"),
          py::arg("dim"),
          py::arg("destination_index"),
          py::call_guard<py::gil_scoped_release>())
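      // _gather_out(tensors, out, dim): gather variant that concatenates into
      // the preallocated tensor `out`; the destination device is implied by
      // `out`.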
      .def(
          "_gather_out",
          [](std::vector<at::Tensor>& tensors,
             at::Tensor& out_tensor,
             int64_t dim) { return gather_out(tensors, out_tensor, dim); },
          py::arg("tensors"),
          py::arg("out"),
          py::arg("dim"),
          py::call_guard<py::gil_scoped_release>());
}
} // namespace torch::cuda::python