xref: /aosp_15_r20/external/pytorch/torch/csrc/xpu/Stream.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <pybind11/pybind11.h>
2 #include <torch/csrc/Device.h>
3 #include <torch/csrc/THP.h>
4 #include <torch/csrc/utils/pybind.h>
5 #include <torch/csrc/utils/python_numbers.h>
6 #include <torch/csrc/xpu/Module.h>
7 #include <torch/csrc/xpu/Stream.h>
8 
9 #include <structmember.h>
10 
11 PyObject* THXPStreamClass = nullptr;
12 
THXPStream_pynew(PyTypeObject * type,PyObject * args,PyObject * kwargs)13 static PyObject* THXPStream_pynew(
14     PyTypeObject* type,
15     PyObject* args,
16     PyObject* kwargs) {
17   HANDLE_TH_ERRORS
18 
19   const auto current_device = c10::xpu::current_device();
20 
21   int32_t priority = 0;
22   int64_t stream_id = 0;
23   int64_t device_index = 0;
24   int64_t device_type = 0;
25 
26   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
27   constexpr const char* kwlist[] = {
28       "priority", "stream_id", "device_index", "device_type", nullptr};
29   if (!PyArg_ParseTupleAndKeywords(
30           args,
31           kwargs,
32           "|iLLL",
33           // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
34           const_cast<char**>(kwlist),
35           &priority,
36           &stream_id,
37           &device_index,
38           &device_type)) {
39     return nullptr;
40   }
41 
42   THPObjectPtr ptr(type->tp_alloc(type, 0));
43   if (!ptr) {
44     return nullptr;
45   }
46 
47   at::xpu::XPUStream stream = (stream_id || device_index || device_type)
48       ? at::xpu::XPUStream::unpack3(
49             stream_id,
50             static_cast<c10::DeviceIndex>(device_index),
51             static_cast<c10::DeviceType>(device_type))
52       : at::xpu::getStreamFromPool(priority, current_device);
53 
54   THXPStream* self = (THXPStream*)ptr.get();
55   self->stream_id = static_cast<int64_t>(stream.id());
56   // NOLINTNEXTLINE(bugprone-signed-char-misuse)
57   self->device_index = static_cast<int64_t>(stream.device_index());
58   self->device_type = static_cast<int64_t>(stream.device_type());
59   new (&self->xpu_stream) at::xpu::XPUStream(stream);
60 
61   return (PyObject*)ptr.release();
62   END_HANDLE_TH_ERRORS
63 }
64 
THXPStream_dealloc(THXPStream * self)65 static void THXPStream_dealloc(THXPStream* self) {
66   self->xpu_stream.~XPUStream();
67   Py_TYPE(self)->tp_free((PyObject*)self);
68 }
69 
THXPStream_get_sycl_queue(THXPStream * self,void * unused)70 static PyObject* THXPStream_get_sycl_queue(THXPStream* self, void* unused) {
71   HANDLE_TH_ERRORS
72   return PyLong_FromVoidPtr(&self->xpu_stream.queue());
73   END_HANDLE_TH_ERRORS
74 }
75 
THXPStream_get_priority(THXPStream * self,void * unused)76 static PyObject* THXPStream_get_priority(THXPStream* self, void* unused) {
77   HANDLE_TH_ERRORS
78   return THPUtils_packInt64(self->xpu_stream.priority());
79   END_HANDLE_TH_ERRORS
80 }
81 
THXPStream_priority_range(PyObject * _unused,PyObject * noargs)82 static PyObject* THXPStream_priority_range(
83     PyObject* _unused,
84     PyObject* noargs) {
85   HANDLE_TH_ERRORS
86   auto [least_priority, greatest_priority] =
87       at::xpu::XPUStream::priority_range();
88   return Py_BuildValue("(ii)", least_priority, greatest_priority);
89   END_HANDLE_TH_ERRORS
90 }
91 
THXPStream_query(PyObject * _self,PyObject * noargs)92 static PyObject* THXPStream_query(PyObject* _self, PyObject* noargs) {
93   HANDLE_TH_ERRORS
94   auto* self = (THXPStream*)_self;
95   return PyBool_FromLong(self->xpu_stream.query());
96   END_HANDLE_TH_ERRORS
97 }
98 
THXPStream_synchronize(PyObject * _self,PyObject * noargs)99 static PyObject* THXPStream_synchronize(PyObject* _self, PyObject* noargs) {
100   HANDLE_TH_ERRORS {
101     pybind11::gil_scoped_release no_gil;
102     auto* self = (THXPStream*)_self;
103     self->xpu_stream.synchronize();
104   }
105   Py_RETURN_NONE;
106   END_HANDLE_TH_ERRORS
107 }
108 
THXPStream_eq(PyObject * _self,PyObject * _other)109 static PyObject* THXPStream_eq(PyObject* _self, PyObject* _other) {
110   HANDLE_TH_ERRORS
111   auto* self = (THXPStream*)_self;
112   auto* other = (THXPStream*)_other;
113   return PyBool_FromLong(self->xpu_stream == other->xpu_stream);
114   END_HANDLE_TH_ERRORS
115 }
116 
117 // NOLINTNEXTLINE(*-c-arrays*, *-global-variables)
118 static struct PyMemberDef THXPStream_members[] = {{nullptr}};
119 
120 // NOLINTNEXTLINE(*-c-arrays*, *-global-variables)
121 static struct PyGetSetDef THXPStream_properties[] = {
122     {"sycl_queue",
123      (getter)THXPStream_get_sycl_queue,
124      nullptr,
125      nullptr,
126      nullptr},
127     {"priority", (getter)THXPStream_get_priority, nullptr, nullptr, nullptr},
128     {nullptr}};
129 
130 // NOLINTNEXTLINE(*-c-arrays*, *-global-variables)
131 static PyMethodDef THXPStream_methods[] = {
132     {"query", THXPStream_query, METH_NOARGS, nullptr},
133     {"synchronize", THXPStream_synchronize, METH_NOARGS, nullptr},
134     {"priority_range",
135      THXPStream_priority_range,
136      METH_STATIC | METH_NOARGS,
137      nullptr},
138     {"__eq__", THXPStream_eq, METH_O, nullptr},
139     {nullptr}};
140 
141 PyTypeObject THXPStreamType = {
142     PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._XpuStreamBase", /* tp_name */
143     sizeof(THXPStream), /* tp_basicsize */
144     0, /* tp_itemsize */
145     (destructor)THXPStream_dealloc, /* tp_dealloc */
146     0, /* tp_vectorcall_offset */
147     nullptr, /* tp_getattr */
148     nullptr, /* tp_setattr */
149     nullptr, /* tp_reserved */
150     nullptr, /* tp_repr */
151     nullptr, /* tp_as_number */
152     nullptr, /* tp_as_sequence */
153     nullptr, /* tp_as_mapping */
154     nullptr, /* tp_hash  */
155     nullptr, /* tp_call */
156     nullptr, /* tp_str */
157     nullptr, /* tp_getattro */
158     nullptr, /* tp_setattro */
159     nullptr, /* tp_as_buffer */
160     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
161     nullptr, /* tp_doc */
162     nullptr, /* tp_traverse */
163     nullptr, /* tp_clear */
164     nullptr, /* tp_richcompare */
165     0, /* tp_weaklistoffset */
166     nullptr, /* tp_iter */
167     nullptr, /* tp_iternext */
168     THXPStream_methods, /* tp_methods */
169     THXPStream_members, /* tp_members */
170     THXPStream_properties, /* tp_getset */
171     nullptr, /* tp_base */
172     nullptr, /* tp_dict */
173     nullptr, /* tp_descr_get */
174     nullptr, /* tp_descr_set */
175     0, /* tp_dictoffset */
176     nullptr, /* tp_init */
177     nullptr, /* tp_alloc */
178     THXPStream_pynew, /* tp_new */
179 };
180 
THXPStream_init(PyObject * module)181 void THXPStream_init(PyObject* module) {
182   Py_INCREF(THPStreamClass);
183   THXPStreamType.tp_base = THPStreamClass;
184   THXPStreamClass = (PyObject*)&THXPStreamType;
185   if (PyType_Ready(&THXPStreamType) < 0) {
186     throw python_error();
187   }
188   Py_INCREF(&THXPStreamType);
189   if (PyModule_AddObject(module, "_XpuStreamBase", (PyObject*)&THXPStreamType) <
190       0) {
191     throw python_error();
192   }
193 }
194