1 #include <pybind11/pybind11.h>
2 #include <torch/csrc/Device.h>
3 #include <torch/csrc/THP.h>
4 #include <torch/csrc/utils/pybind.h>
5 #include <torch/csrc/utils/python_numbers.h>
6 #include <torch/csrc/xpu/Module.h>
7 #include <torch/csrc/xpu/Stream.h>
8
9 #include <structmember.h>
10
11 PyObject* THXPStreamClass = nullptr;
12
THXPStream_pynew(PyTypeObject * type,PyObject * args,PyObject * kwargs)13 static PyObject* THXPStream_pynew(
14 PyTypeObject* type,
15 PyObject* args,
16 PyObject* kwargs) {
17 HANDLE_TH_ERRORS
18
19 const auto current_device = c10::xpu::current_device();
20
21 int32_t priority = 0;
22 int64_t stream_id = 0;
23 int64_t device_index = 0;
24 int64_t device_type = 0;
25
26 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
27 constexpr const char* kwlist[] = {
28 "priority", "stream_id", "device_index", "device_type", nullptr};
29 if (!PyArg_ParseTupleAndKeywords(
30 args,
31 kwargs,
32 "|iLLL",
33 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
34 const_cast<char**>(kwlist),
35 &priority,
36 &stream_id,
37 &device_index,
38 &device_type)) {
39 return nullptr;
40 }
41
42 THPObjectPtr ptr(type->tp_alloc(type, 0));
43 if (!ptr) {
44 return nullptr;
45 }
46
47 at::xpu::XPUStream stream = (stream_id || device_index || device_type)
48 ? at::xpu::XPUStream::unpack3(
49 stream_id,
50 static_cast<c10::DeviceIndex>(device_index),
51 static_cast<c10::DeviceType>(device_type))
52 : at::xpu::getStreamFromPool(priority, current_device);
53
54 THXPStream* self = (THXPStream*)ptr.get();
55 self->stream_id = static_cast<int64_t>(stream.id());
56 // NOLINTNEXTLINE(bugprone-signed-char-misuse)
57 self->device_index = static_cast<int64_t>(stream.device_index());
58 self->device_type = static_cast<int64_t>(stream.device_type());
59 new (&self->xpu_stream) at::xpu::XPUStream(stream);
60
61 return (PyObject*)ptr.release();
62 END_HANDLE_TH_ERRORS
63 }
64
THXPStream_dealloc(THXPStream * self)65 static void THXPStream_dealloc(THXPStream* self) {
66 self->xpu_stream.~XPUStream();
67 Py_TYPE(self)->tp_free((PyObject*)self);
68 }
69
THXPStream_get_sycl_queue(THXPStream * self,void * unused)70 static PyObject* THXPStream_get_sycl_queue(THXPStream* self, void* unused) {
71 HANDLE_TH_ERRORS
72 return PyLong_FromVoidPtr(&self->xpu_stream.queue());
73 END_HANDLE_TH_ERRORS
74 }
75
THXPStream_get_priority(THXPStream * self,void * unused)76 static PyObject* THXPStream_get_priority(THXPStream* self, void* unused) {
77 HANDLE_TH_ERRORS
78 return THPUtils_packInt64(self->xpu_stream.priority());
79 END_HANDLE_TH_ERRORS
80 }
81
THXPStream_priority_range(PyObject * _unused,PyObject * noargs)82 static PyObject* THXPStream_priority_range(
83 PyObject* _unused,
84 PyObject* noargs) {
85 HANDLE_TH_ERRORS
86 auto [least_priority, greatest_priority] =
87 at::xpu::XPUStream::priority_range();
88 return Py_BuildValue("(ii)", least_priority, greatest_priority);
89 END_HANDLE_TH_ERRORS
90 }
91
THXPStream_query(PyObject * _self,PyObject * noargs)92 static PyObject* THXPStream_query(PyObject* _self, PyObject* noargs) {
93 HANDLE_TH_ERRORS
94 auto* self = (THXPStream*)_self;
95 return PyBool_FromLong(self->xpu_stream.query());
96 END_HANDLE_TH_ERRORS
97 }
98
THXPStream_synchronize(PyObject * _self,PyObject * noargs)99 static PyObject* THXPStream_synchronize(PyObject* _self, PyObject* noargs) {
100 HANDLE_TH_ERRORS {
101 pybind11::gil_scoped_release no_gil;
102 auto* self = (THXPStream*)_self;
103 self->xpu_stream.synchronize();
104 }
105 Py_RETURN_NONE;
106 END_HANDLE_TH_ERRORS
107 }
108
THXPStream_eq(PyObject * _self,PyObject * _other)109 static PyObject* THXPStream_eq(PyObject* _self, PyObject* _other) {
110 HANDLE_TH_ERRORS
111 auto* self = (THXPStream*)_self;
112 auto* other = (THXPStream*)_other;
113 return PyBool_FromLong(self->xpu_stream == other->xpu_stream);
114 END_HANDLE_TH_ERRORS
115 }
116
117 // NOLINTNEXTLINE(*-c-arrays*, *-global-variables)
118 static struct PyMemberDef THXPStream_members[] = {{nullptr}};
119
120 // NOLINTNEXTLINE(*-c-arrays*, *-global-variables)
121 static struct PyGetSetDef THXPStream_properties[] = {
122 {"sycl_queue",
123 (getter)THXPStream_get_sycl_queue,
124 nullptr,
125 nullptr,
126 nullptr},
127 {"priority", (getter)THXPStream_get_priority, nullptr, nullptr, nullptr},
128 {nullptr}};
129
130 // NOLINTNEXTLINE(*-c-arrays*, *-global-variables)
131 static PyMethodDef THXPStream_methods[] = {
132 {"query", THXPStream_query, METH_NOARGS, nullptr},
133 {"synchronize", THXPStream_synchronize, METH_NOARGS, nullptr},
134 {"priority_range",
135 THXPStream_priority_range,
136 METH_STATIC | METH_NOARGS,
137 nullptr},
138 {"__eq__", THXPStream_eq, METH_O, nullptr},
139 {nullptr}};
140
141 PyTypeObject THXPStreamType = {
142 PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._XpuStreamBase", /* tp_name */
143 sizeof(THXPStream), /* tp_basicsize */
144 0, /* tp_itemsize */
145 (destructor)THXPStream_dealloc, /* tp_dealloc */
146 0, /* tp_vectorcall_offset */
147 nullptr, /* tp_getattr */
148 nullptr, /* tp_setattr */
149 nullptr, /* tp_reserved */
150 nullptr, /* tp_repr */
151 nullptr, /* tp_as_number */
152 nullptr, /* tp_as_sequence */
153 nullptr, /* tp_as_mapping */
154 nullptr, /* tp_hash */
155 nullptr, /* tp_call */
156 nullptr, /* tp_str */
157 nullptr, /* tp_getattro */
158 nullptr, /* tp_setattro */
159 nullptr, /* tp_as_buffer */
160 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
161 nullptr, /* tp_doc */
162 nullptr, /* tp_traverse */
163 nullptr, /* tp_clear */
164 nullptr, /* tp_richcompare */
165 0, /* tp_weaklistoffset */
166 nullptr, /* tp_iter */
167 nullptr, /* tp_iternext */
168 THXPStream_methods, /* tp_methods */
169 THXPStream_members, /* tp_members */
170 THXPStream_properties, /* tp_getset */
171 nullptr, /* tp_base */
172 nullptr, /* tp_dict */
173 nullptr, /* tp_descr_get */
174 nullptr, /* tp_descr_set */
175 0, /* tp_dictoffset */
176 nullptr, /* tp_init */
177 nullptr, /* tp_alloc */
178 THXPStream_pynew, /* tp_new */
179 };
180
THXPStream_init(PyObject * module)181 void THXPStream_init(PyObject* module) {
182 Py_INCREF(THPStreamClass);
183 THXPStreamType.tp_base = THPStreamClass;
184 THXPStreamClass = (PyObject*)&THXPStreamType;
185 if (PyType_Ready(&THXPStreamType) < 0) {
186 throw python_error();
187 }
188 Py_INCREF(&THXPStreamType);
189 if (PyModule_AddObject(module, "_XpuStreamBase", (PyObject*)&THXPStreamType) <
190 0) {
191 throw python_error();
192 }
193 }
194