1 #include <ATen/ATen.h>
2 #include <ATen/xpu/XPUContext.h>
3 #include <ATen/xpu/XPUGeneratorImpl.h>
4 #include <c10/util/CallOnce.h>
5 #include <c10/xpu/XPUCachingAllocator.h>
6 #include <c10/xpu/XPUFunctions.h>
7 #include <torch/csrc/Module.h>
8 #include <torch/csrc/THP.h>
9 #include <torch/csrc/utils/device_lazy_init.h>
10 #include <torch/csrc/utils/pycfunction_helpers.h>
11 #include <torch/csrc/utils/python_numbers.h>
12 #include <torch/csrc/utils/python_strings.h>
13
14 #ifndef WIN32
15 #include <pthread.h>
16 #endif
17
18 using namespace torch;
19
20 static bool in_bad_fork = false; // True for children forked after xpu init
21
22 #ifndef WIN32
23 // Called in the forked child if xpu has already been initialized
forked_child()24 static void forked_child() {
25 in_bad_fork = true;
26 torch::utils::set_requires_device_init(at::kXPU, true);
27 }
28 #endif
29
30 // Should be called before the first xpu call. It is mainly called in lazy_init.
31 // Note: This is distinct from initExtension because a stub xpu implementation
32 // has some working functions (e.g. device_count) but cannot fully initialize.
poison_fork()33 static void poison_fork() {
34 #ifndef WIN32
35 static c10::once_flag flag;
36 c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); });
37 #endif
38 }
39
40 // XPU management methods
41
THXPModule_isInBadFork_wrap(PyObject * self,PyObject * noargs)42 static PyObject* THXPModule_isInBadFork_wrap(PyObject* self, PyObject* noargs) {
43 HANDLE_TH_ERRORS
44 return PyBool_FromLong(in_bad_fork);
45 END_HANDLE_TH_ERRORS
46 }
47
THXPModule_setDevice_wrap(PyObject * self,PyObject * arg)48 PyObject* THXPModule_setDevice_wrap(PyObject* self, PyObject* arg) {
49 HANDLE_TH_ERRORS
50 TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to set_device");
51
52 auto device_index = THPUtils_unpackDeviceIndex(arg);
53 c10::xpu::set_device(device_index);
54
55 Py_RETURN_NONE;
56 END_HANDLE_TH_ERRORS
57 }
58
THXPModule_exchangeDevice_wrap(PyObject * self,PyObject * arg)59 PyObject* THXPModule_exchangeDevice_wrap(PyObject* self, PyObject* arg) {
60 HANDLE_TH_ERRORS
61 TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchange_device");
62
63 auto device_index = THPUtils_unpackDeviceIndex(arg);
64 if (device_index < 0) {
65 return THPUtils_packInt32(-1);
66 }
67
68 torch::utils::device_lazy_init(at::kXPU);
69 auto current_device = c10::xpu::exchange_device(device_index);
70
71 return THPUtils_packDeviceIndex(current_device);
72 END_HANDLE_TH_ERRORS
73 }
74
THXPModule_maybeExchangeDevice_wrap(PyObject * self,PyObject * arg)75 PyObject* THXPModule_maybeExchangeDevice_wrap(PyObject* self, PyObject* arg) {
76 HANDLE_TH_ERRORS
77 TORCH_CHECK(
78 THPUtils_checkLong(arg), "invalid argument to maybe_exchange_device");
79
80 auto device_index = THPUtils_unpackDeviceIndex(arg);
81 if (device_index < 0) {
82 return THPUtils_packInt32(-1);
83 }
84
85 torch::utils::device_lazy_init(at::kXPU);
86 auto current_device = c10::xpu::maybe_exchange_device(device_index);
87
88 return THPUtils_packDeviceIndex(current_device);
89 END_HANDLE_TH_ERRORS
90 }
91
THXPModule_getDevice_wrap(PyObject * self,PyObject * noargs)92 PyObject* THXPModule_getDevice_wrap(PyObject* self, PyObject* noargs) {
93 HANDLE_TH_ERRORS
94
95 auto device_index = c10::xpu::current_device();
96
97 return THPUtils_packDeviceIndex(device_index);
98 END_HANDLE_TH_ERRORS
99 }
100
THXPModule_getDeviceCount_wrap(PyObject * self,PyObject * noargs)101 PyObject* THXPModule_getDeviceCount_wrap(PyObject* self, PyObject* noargs) {
102 HANDLE_TH_ERRORS
103 poison_fork();
104 return THPUtils_packUInt64(at::xpu::device_count());
105 END_HANDLE_TH_ERRORS
106 }
107
THXPModule_getCurrentStream_wrap(PyObject * self,PyObject * device_index)108 PyObject* THXPModule_getCurrentStream_wrap(
109 PyObject* self,
110 PyObject* device_index) {
111 HANDLE_TH_ERRORS
112 TORCH_CHECK(
113 THPUtils_checkLong(device_index), "invalid argument to current_stream");
114 auto c10_device_index = THPUtils_unpackDeviceIndex(device_index);
115 auto stream = at::xpu::getCurrentXPUStream(c10_device_index);
116 PyObject* output_tuple = PyTuple_New(3);
117 PyTuple_SetItem(
118 output_tuple, 0, THPUtils_packInt64(static_cast<int64_t>(stream.id())));
119 PyTuple_SetItem(
120 output_tuple, 1, THPUtils_packDeviceIndex(stream.device_index()));
121 PyTuple_SetItem(
122 output_tuple,
123 2,
124 THPUtils_packInt64(static_cast<int64_t>(stream.device_type())));
125 return output_tuple;
126 END_HANDLE_TH_ERRORS
127 }
128
THXPModule_getCurrentStream_raw(PyObject * self,PyObject * device_index)129 PyObject* THXPModule_getCurrentStream_raw(
130 PyObject* self,
131 PyObject* device_index) {
132 HANDLE_TH_ERRORS
133 TORCH_CHECK(
134 THPUtils_checkLong(device_index),
135 "invalid argument to getCurrentRawStream");
136 auto c10_device_index = THPUtils_unpackDeviceIndex(device_index);
137 return PyLong_FromVoidPtr(
138 &at::xpu::getCurrentXPUStream(c10_device_index).queue());
139 END_HANDLE_TH_ERRORS
140 }
141
THXPModule_setStream_wrap(PyObject * self,PyObject * args,PyObject * kwargs)142 PyObject* THXPModule_setStream_wrap(
143 PyObject* self,
144 PyObject* args,
145 PyObject* kwargs) {
146 HANDLE_TH_ERRORS
147 int64_t stream_id = 0;
148 int64_t device_index = 0;
149 int64_t device_type = 0;
150
151 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
152 constexpr const char* kwlist[] = {
153 "stream_id", "device_index", "device_type", nullptr};
154 if (!PyArg_ParseTupleAndKeywords(
155 args,
156 kwargs,
157 "|LLL",
158 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
159 const_cast<char**>(kwlist),
160 &stream_id,
161 &device_index,
162 &device_type)) {
163 }
164
165 auto stream = at::xpu::XPUStream::unpack3(
166 stream_id,
167 static_cast<c10::DeviceIndex>(device_index),
168 static_cast<c10::DeviceType>(device_type));
169
170 auto device = c10::xpu::current_device();
171 if (device != stream.device_index()) {
172 c10::xpu::set_device(stream.device_index());
173 }
174 at::xpu::setCurrentXPUStream(stream);
175 Py_RETURN_NONE;
176 END_HANDLE_TH_ERRORS
177 }
178
THXPModule_xpuSynchronize(PyObject * self,PyObject * arg)179 PyObject* THXPModule_xpuSynchronize(PyObject* self, PyObject* arg) {
180 HANDLE_TH_ERRORS
181 TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to synchronize");
182 auto device_index = THPUtils_unpackDeviceIndex(arg);
183 {
184 pybind11::gil_scoped_release no_gil;
185 // Only the SYCL queues we have reserved will be synchronized, see Note
186 // [Synchronize Streams on Device].
187 c10::xpu::syncStreamsOnDevice(device_index);
188 }
189 Py_RETURN_NONE;
190 END_HANDLE_TH_ERRORS
191 }
192
THXPModule_emptyCache(PyObject * self,PyObject * noargs)193 PyObject* THXPModule_emptyCache(PyObject* self, PyObject* noargs) {
194 HANDLE_TH_ERRORS
195 c10::xpu::XPUCachingAllocator::emptyCache();
196 END_HANDLE_TH_ERRORS
197 Py_RETURN_NONE;
198 }
199
THXPModule_memoryStats(PyObject * self,PyObject * arg)200 PyObject* THXPModule_memoryStats(PyObject* self, PyObject* arg) {
201 HANDLE_TH_ERRORS
202 TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to memory_stats");
203 const auto device_index = THPUtils_unpackDeviceIndex(arg);
204
205 using c10::CachingDeviceAllocator::DeviceStats;
206 using c10::CachingDeviceAllocator::Stat;
207 using c10::CachingDeviceAllocator::StatArray;
208 using c10::CachingDeviceAllocator::StatType;
209
210 const auto statToDict = [](const Stat& stat) {
211 py::dict dict;
212
213 dict["current"] = stat.current;
214 dict["peak"] = stat.peak;
215 dict["allocated"] = stat.allocated;
216 dict["freed"] = stat.freed;
217 return dict;
218 };
219
220 const auto statArrayToDict = [=](const StatArray& statArray) {
221 const std::array<const char*, static_cast<size_t>(StatType::NUM_TYPES)>
222 statTypeNames = {"all", "small_pool", "large_pool"};
223 py::dict dict;
224 for (const auto i : c10::irange(statTypeNames.size())) {
225 dict[statTypeNames[i]] = statToDict(statArray[i]);
226 }
227 return dict;
228 };
229
230 const DeviceStats stats =
231 c10::xpu::XPUCachingAllocator::getDeviceStats(device_index);
232
233 py::dict result;
234 result["allocated_bytes"] = statArrayToDict(stats.allocated_bytes);
235 result["reserved_bytes"] = statArrayToDict(stats.reserved_bytes);
236 result["active_bytes"] = statArrayToDict(stats.active_bytes);
237 result["requested_bytes"] = statArrayToDict(stats.requested_bytes);
238
239 return result.release().ptr();
240 END_HANDLE_TH_ERRORS
241 }
242
THXPModule_resetPeakMemoryStats(PyObject * self,PyObject * arg)243 PyObject* THXPModule_resetPeakMemoryStats(PyObject* self, PyObject* arg) {
244 HANDLE_TH_ERRORS
245 TORCH_CHECK(
246 THPUtils_checkLong(arg), "invalid argument to reset_peak_memory_stats");
247 const auto device_index = THPUtils_unpackDeviceIndex(arg);
248 c10::xpu::XPUCachingAllocator::resetPeakStats(device_index);
249 END_HANDLE_TH_ERRORS
250 Py_RETURN_NONE;
251 }
252
THXPModule_resetAccumulatedMemoryStats(PyObject * self,PyObject * arg)253 PyObject* THXPModule_resetAccumulatedMemoryStats(
254 PyObject* self,
255 PyObject* arg) {
256 HANDLE_TH_ERRORS
257 TORCH_CHECK(
258 THPUtils_checkLong(arg),
259 "invalid argument to reset_accumulated_memory_stats");
260 const auto device_index = THPUtils_unpackDeviceIndex(arg);
261 c10::xpu::XPUCachingAllocator::resetAccumulatedStats(device_index);
262 END_HANDLE_TH_ERRORS
263 Py_RETURN_NONE;
264 }
265
266 // XPU module initialization
267
registerXpuDeviceProperties(PyObject * module)268 static void registerXpuDeviceProperties(PyObject* module) {
269 // Add _xpuDevicePropertires class to torch._C
270 using namespace c10::xpu;
271 auto get_device_type = [](const DeviceProp& prop) {
272 std::ostringstream stream;
273 using namespace sycl::info;
274 switch (prop.device_type) {
275 case device_type::cpu:
276 stream << "cpu";
277 break;
278 case device_type::gpu:
279 stream << "gpu";
280 break;
281 case device_type::accelerator:
282 stream << "accelerator";
283 break;
284 case device_type::host:
285 stream << "host";
286 break;
287 default:
288 stream << "unknown device type:"
289 << static_cast<typename std::underlying_type<device_type>::type>(
290 prop.device_type);
291 break;
292 }
293 return stream.str();
294 };
295 auto gpu_subslice_count = [](const DeviceProp& prop) {
296 return (prop.gpu_eu_count / prop.gpu_eu_count_per_subslice);
297 };
298 auto m = py::handle(module).cast<py::module>();
299
300 #define DEFINE_READONLY_MEMBER(member) \
301 def_readonly(#member, &DeviceProp::member)
302
303 #define THXP_FORALL_DEVICE_PROPERTIES(_) \
304 py::class_<DeviceProp>(m, "_XpuDeviceProperties") \
305 ._(name) \
306 ._(platform_name) \
307 ._(vendor) \
308 ._(driver_version) \
309 ._(version) \
310 ._(max_compute_units) \
311 ._(gpu_eu_count) \
312 ._(max_work_group_size) \
313 ._(max_num_sub_groups) \
314 ._(sub_group_sizes) \
315 ._(has_fp16) \
316 ._(has_fp64) \
317 ._(has_atomic64) \
318 ._(has_bfloat16_conversions) \
319 ._(has_subgroup_matrix_multiply_accumulate) \
320 ._(has_subgroup_matrix_multiply_accumulate_tensor_float32) \
321 ._(has_subgroup_2d_block_io)
322
323 THXP_FORALL_DEVICE_PROPERTIES(DEFINE_READONLY_MEMBER)
324 .def_readonly("total_memory", &DeviceProp::global_mem_size)
325 .def_property_readonly("gpu_subslice_count", gpu_subslice_count)
326 .def_property_readonly("type", get_device_type)
327 .def(
328 "__repr__",
329 [&get_device_type, &gpu_subslice_count](const DeviceProp& prop) {
330 std::ostringstream stream;
331 stream << "_XpuDeviceProperties(name='" << prop.name
332 << "', platform_name='" << prop.platform_name << "', type='"
333 << get_device_type(prop) << "', driver_version='"
334 << prop.driver_version << "', total_memory="
335 << prop.global_mem_size / (1024ull * 1024)
336 << "MB, max_compute_units=" << prop.max_compute_units
337 << ", gpu_eu_count=" << prop.gpu_eu_count
338 << ", gpu_subslice_count=" << gpu_subslice_count(prop)
339 << ", max_work_group_size=" << prop.max_work_group_size
340 << ", max_num_sub_groups=" << prop.max_num_sub_groups
341 << ", sub_group_sizes=[" << prop.sub_group_sizes
342 << "], has_fp16=" << prop.has_fp16
343 << ", has_fp64=" << prop.has_fp64
344 << ", has_atomic64=" << prop.has_atomic64 << ")";
345 return stream.str();
346 });
347 }
348
bindGetDeviceProperties(PyObject * module)349 static void bindGetDeviceProperties(PyObject* module) {
350 // Add method to torch.xpu
351 auto m = py::handle(module).cast<py::module>();
352 m.def(
353 "_get_device_properties",
354 [](c10::DeviceIndex device) -> c10::xpu::DeviceProp* {
355 return at::xpu::getDeviceProperties(device);
356 },
357 py::return_value_policy::reference);
358 }
359
360 // Callback for python part. Used for additional initialization of python
361 // classes
static PyObject* THXPModule_initExtension(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
  poison_fork();
  at::globalContext().lazyInitXPU();

  // All attributes below are set on the torch.xpu Python module.
  auto m = THPObjectPtr(PyImport_ImportModule("torch.xpu"));
  if (!m)
    throw python_error();

  // Helper: set an attribute on torch.xpu, converting failure into a C++
  // exception that HANDLE_TH_ERRORS translates back to Python.
  auto set_module_attr = [&](const char* name, PyObject* v) {
    if (PyObject_SetAttrString(m, name, v) < 0) {
      throw python_error();
    }
  };

  // Build a tuple with one default generator per visible device and publish
  // it as torch.xpu.default_generators. PyTuple_SetItem steals the reference
  // returned by THPGenerator_initDefaultGenerator, so no decref is needed.
  auto num_gpus = c10::xpu::device_count();
  THPObjectPtr default_xpu_generators(
      PyTuple_New(static_cast<Py_ssize_t>(num_gpus)));
  for (const auto i : c10::irange(num_gpus)) {
    const auto& gen = at::xpu::detail::getDefaultXPUGenerator(i);
    auto* cast_gen = THPGenerator_initDefaultGenerator(gen);
    PyTuple_SetItem(default_xpu_generators.get(), i, cast_gen);
  }
  set_module_attr("default_generators", default_xpu_generators.get());
  bindGetDeviceProperties(m);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
392
// Method table for the private _xpu_* bindings; registered on torch._C via
// THXPModule_methods() below.
// NOLINTNEXTLINE(*-c-arrays*, *-global-variables)
static struct PyMethodDef _THXPModule_methods[] = {
    {"_xpu_init", THXPModule_initExtension, METH_NOARGS, nullptr},
    {"_xpu_setDevice", THXPModule_setDevice_wrap, METH_O, nullptr},
    {"_xpu_exchangeDevice", THXPModule_exchangeDevice_wrap, METH_O, nullptr},
    {"_xpu_maybeExchangeDevice",
     THXPModule_maybeExchangeDevice_wrap,
     METH_O,
     nullptr},
    {"_xpu_getDevice", THXPModule_getDevice_wrap, METH_NOARGS, nullptr},
    {"_xpu_getDeviceCount",
     THXPModule_getDeviceCount_wrap,
     METH_NOARGS,
     nullptr},
    {"_xpu_isInBadFork", THXPModule_isInBadFork_wrap, METH_NOARGS, nullptr},
    {"_xpu_getCurrentStream",
     THXPModule_getCurrentStream_wrap,
     METH_O,
     nullptr},
    {"_xpu_getCurrentRawStream",
     THXPModule_getCurrentStream_raw,
     METH_O,
     nullptr},
    {"_xpu_setStream",
     castPyCFunctionWithKeywords(THXPModule_setStream_wrap),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_xpu_synchronize", THXPModule_xpuSynchronize, METH_O, nullptr},
    {"_xpu_emptyCache", THXPModule_emptyCache, METH_NOARGS, nullptr},
    {"_xpu_memoryStats", THXPModule_memoryStats, METH_O, nullptr},
    {"_xpu_resetAccumulatedMemoryStats",
     THXPModule_resetAccumulatedMemoryStats,
     METH_O,
     nullptr},
    {"_xpu_resetPeakMemoryStats",
     THXPModule_resetPeakMemoryStats,
     METH_O,
     nullptr},
    {nullptr}}; // sentinel terminating the table
432
// Exposes the method table above so the main torch._C module init can
// register these bindings.
PyMethodDef* THXPModule_methods() {
  return _THXPModule_methods;
}
436
namespace torch::xpu {

// Module-level entry point: registers XPU-specific Python classes
// (currently only _XpuDeviceProperties) on the given module.
void initModule(PyObject* module) {
  registerXpuDeviceProperties(module);
}

} // namespace torch::xpu
444