#include <ATen/ATen.h>
#include <ATen/xpu/XPUContext.h>
#include <ATen/xpu/XPUGeneratorImpl.h>
#include <c10/util/CallOnce.h>
#include <c10/xpu/XPUCachingAllocator.h>
#include <c10/xpu/XPUFunctions.h>
#include <torch/csrc/Module.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>

#ifndef WIN32
#include <pthread.h>
#endif

using namespace torch;

static bool in_bad_fork = false; // True for children forked after xpu init

#ifndef WIN32
// Called in the forked child if xpu has already been initialized
static void forked_child() {
  in_bad_fork = true;
  torch::utils::set_requires_device_init(at::kXPU, true);
}
#endif

// Should be called before the first xpu call. It is mainly called in lazy_init.
// Note: This is distinct from initExtension because a stub xpu implementation
// has some working functions (e.g. device_count) but cannot fully initialize.
static void poison_fork() {
#ifndef WIN32
  static c10::once_flag flag;
  c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); });
#endif
}

// XPU management methods

static PyObject* THXPModule_isInBadFork_wrap(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return PyBool_FromLong(in_bad_fork);
  END_HANDLE_TH_ERRORS
}

PyObject* THXPModule_setDevice_wrap(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to set_device");

  auto device_index = THPUtils_unpackDeviceIndex(arg);
  c10::xpu::set_device(device_index);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

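// Swaps the active XPU device to the given index and returns the previous
// index. A negative index is a no-op that returns -1 without initializing XPU.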
PyObject* THXPModule_exchangeDevice_wrap(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchange_device");

  auto device_index = THPUtils_unpackDeviceIndex(arg);
  if (device_index < 0) {
    return THPUtils_packInt32(-1);
  }

  torch::utils::device_lazy_init(at::kXPU);
  auto current_device = c10::xpu::exchange_device(device_index);

  return THPUtils_packDeviceIndex(current_device);
  END_HANDLE_TH_ERRORS
}

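// Same contract as exchange_device above, except the underlying call only
// switches the device when the requested index differs from the current one.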
PyObject* THXPModule_maybeExchangeDevice_wrap(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg), "invalid argument to maybe_exchange_device");

  auto device_index = THPUtils_unpackDeviceIndex(arg);
  if (device_index < 0) {
    return THPUtils_packInt32(-1);
  }

  torch::utils::device_lazy_init(at::kXPU);
  auto current_device = c10::xpu::maybe_exchange_device(device_index);

  return THPUtils_packDeviceIndex(current_device);
  END_HANDLE_TH_ERRORS
}

PyObject* THXPModule_getDevice_wrap(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS

  auto device_index = c10::xpu::current_device();

  return THPUtils_packDeviceIndex(device_index);
  END_HANDLE_TH_ERRORS
}

PyObject* THXPModule_getDeviceCount_wrap(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  poison_fork();
  return THPUtils_packUInt64(at::xpu::device_count());
  END_HANDLE_TH_ERRORS
}

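// Returns the current stream on the given device as a (stream_id,
// device_index, device_type) tuple, i.e. the packed form that the Python-side
// stream wrappers expect.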
PyObject* THXPModule_getCurrentStream_wrap(
    PyObject* self,
    PyObject* device_index) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(device_index), "invalid argument to current_stream");
  auto c10_device_index = THPUtils_unpackDeviceIndex(device_index);
  auto stream = at::xpu::getCurrentXPUStream(c10_device_index);
  PyObject* output_tuple = PyTuple_New(3);
  PyTuple_SetItem(
      output_tuple, 0, THPUtils_packInt64(static_cast<int64_t>(stream.id())));
  PyTuple_SetItem(
      output_tuple, 1, THPUtils_packDeviceIndex(stream.device_index()));
  PyTuple_SetItem(
      output_tuple,
      2,
      THPUtils_packInt64(static_cast<int64_t>(stream.device_type())));
  return output_tuple;
  END_HANDLE_TH_ERRORS
}

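// Returns the address of the current stream's underlying sycl::queue as an
// integer, for callers that need the native queue handle.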
PyObject* THXPModule_getCurrentStream_raw(
    PyObject* self,
    PyObject* device_index) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(device_index),
      "invalid argument to getCurrentRawStream");
  auto c10_device_index = THPUtils_unpackDeviceIndex(device_index);
  return PyLong_FromVoidPtr(
      &at::xpu::getCurrentXPUStream(c10_device_index).queue());
  END_HANDLE_TH_ERRORS
}

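// Makes the stream described by the packed (stream_id, device_index,
// device_type) triple current, switching the active device first if the
// stream lives on a different device.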
PyObject* THXPModule_setStream_wrap(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  int64_t stream_id = 0;
  int64_t device_index = 0;
  int64_t device_type = 0;

  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  constexpr const char* kwlist[] = {
      "stream_id", "device_index", "device_type", nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "|LLL",
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
          const_cast<char**>(kwlist),
          &stream_id,
          &device_index,
          &device_type)) {
    // Propagate the parse error instead of silently continuing with defaults.
    return nullptr;
  }

  auto stream = at::xpu::XPUStream::unpack3(
      stream_id,
      static_cast<c10::DeviceIndex>(device_index),
      static_cast<c10::DeviceType>(device_type));

  auto device = c10::xpu::current_device();
  if (device != stream.device_index()) {
    c10::xpu::set_device(stream.device_index());
  }
  at::xpu::setCurrentXPUStream(stream);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

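// Blocks until all reserved SYCL queues on the given device have completed.
// The GIL is released while waiting so other Python threads can make progress.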
PyObject* THXPModule_xpuSynchronize(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to synchronize");
  auto device_index = THPUtils_unpackDeviceIndex(arg);
  {
    pybind11::gil_scoped_release no_gil;
    // Only the SYCL queues we have reserved will be synchronized, see Note
    // [Synchronize Streams on Device].
    c10::xpu::syncStreamsOnDevice(device_index);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THXPModule_emptyCache(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  c10::xpu::XPUCachingAllocator::emptyCache();
  END_HANDLE_TH_ERRORS
  Py_RETURN_NONE;
}

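// Returns the caching allocator statistics for a device as a nested dict:
// each of "allocated_bytes", "reserved_bytes", "active_bytes", and
// "requested_bytes" maps the pool names "all", "small_pool", and "large_pool"
// to {"current", "peak", "allocated", "freed"} counters.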
PyObject* THXPModule_memoryStats(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to memory_stats");
  const auto device_index = THPUtils_unpackDeviceIndex(arg);

  using c10::CachingDeviceAllocator::DeviceStats;
  using c10::CachingDeviceAllocator::Stat;
  using c10::CachingDeviceAllocator::StatArray;
  using c10::CachingDeviceAllocator::StatType;

  const auto statToDict = [](const Stat& stat) {
    py::dict dict;

    dict["current"] = stat.current;
    dict["peak"] = stat.peak;
    dict["allocated"] = stat.allocated;
    dict["freed"] = stat.freed;
    return dict;
  };

  const auto statArrayToDict = [=](const StatArray& statArray) {
    const std::array<const char*, static_cast<size_t>(StatType::NUM_TYPES)>
        statTypeNames = {"all", "small_pool", "large_pool"};
    py::dict dict;
    for (const auto i : c10::irange(statTypeNames.size())) {
      dict[statTypeNames[i]] = statToDict(statArray[i]);
    }
    return dict;
  };

  const DeviceStats stats =
      c10::xpu::XPUCachingAllocator::getDeviceStats(device_index);

  py::dict result;
  result["allocated_bytes"] = statArrayToDict(stats.allocated_bytes);
  result["reserved_bytes"] = statArrayToDict(stats.reserved_bytes);
  result["active_bytes"] = statArrayToDict(stats.active_bytes);
  result["requested_bytes"] = statArrayToDict(stats.requested_bytes);

  return result.release().ptr();
  END_HANDLE_TH_ERRORS
}

PyObject* THXPModule_resetPeakMemoryStats(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg), "invalid argument to reset_peak_memory_stats");
  const auto device_index = THPUtils_unpackDeviceIndex(arg);
  c10::xpu::XPUCachingAllocator::resetPeakStats(device_index);
  END_HANDLE_TH_ERRORS
  Py_RETURN_NONE;
}

PyObject* THXPModule_resetAccumulatedMemoryStats(
    PyObject* self,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg),
      "invalid argument to reset_accumulated_memory_stats");
  const auto device_index = THPUtils_unpackDeviceIndex(arg);
  c10::xpu::XPUCachingAllocator::resetAccumulatedStats(device_index);
  END_HANDLE_TH_ERRORS
  Py_RETURN_NONE;
}

// XPU module initialization

static void registerXpuDeviceProperties(PyObject* module) {
  // Add the _XpuDeviceProperties class to torch._C
  using namespace c10::xpu;
  auto get_device_type = [](const DeviceProp& prop) {
    std::ostringstream stream;
    using namespace sycl::info;
    switch (prop.device_type) {
      case device_type::cpu:
        stream << "cpu";
        break;
      case device_type::gpu:
        stream << "gpu";
        break;
      case device_type::accelerator:
        stream << "accelerator";
        break;
      case device_type::host:
        stream << "host";
        break;
      default:
        stream << "unknown device type:"
               << static_cast<typename std::underlying_type<device_type>::type>(
                      prop.device_type);
        break;
    }
    return stream.str();
  };
  auto gpu_subslice_count = [](const DeviceProp& prop) {
    return (prop.gpu_eu_count / prop.gpu_eu_count_per_subslice);
  };
  auto m = py::handle(module).cast<py::module>();

#define DEFINE_READONLY_MEMBER(member) \
  def_readonly(#member, &DeviceProp::member)

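// X-macro listing every DeviceProp field exposed to Python; each entry below
// expands to a def_readonly(...) call via DEFINE_READONLY_MEMBER.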
#define THXP_FORALL_DEVICE_PROPERTIES(_)                         \
  py::class_<DeviceProp>(m, "_XpuDeviceProperties")              \
      ._(name)                                                   \
      ._(platform_name)                                          \
      ._(vendor)                                                 \
      ._(driver_version)                                         \
      ._(version)                                                \
      ._(max_compute_units)                                      \
      ._(gpu_eu_count)                                           \
      ._(max_work_group_size)                                    \
      ._(max_num_sub_groups)                                     \
      ._(sub_group_sizes)                                        \
      ._(has_fp16)                                               \
      ._(has_fp64)                                               \
      ._(has_atomic64)                                           \
      ._(has_bfloat16_conversions)                               \
      ._(has_subgroup_matrix_multiply_accumulate)                \
      ._(has_subgroup_matrix_multiply_accumulate_tensor_float32) \
      ._(has_subgroup_2d_block_io)

  THXP_FORALL_DEVICE_PROPERTIES(DEFINE_READONLY_MEMBER)
      .def_readonly("total_memory", &DeviceProp::global_mem_size)
      .def_property_readonly("gpu_subslice_count", gpu_subslice_count)
      .def_property_readonly("type", get_device_type)
      .def(
          "__repr__",
          [&get_device_type, &gpu_subslice_count](const DeviceProp& prop) {
            std::ostringstream stream;
            stream << "_XpuDeviceProperties(name='" << prop.name
                   << "', platform_name='" << prop.platform_name << "', type='"
                   << get_device_type(prop) << "', driver_version='"
                   << prop.driver_version << "', total_memory="
                   << prop.global_mem_size / (1024ull * 1024)
                   << "MB, max_compute_units=" << prop.max_compute_units
                   << ", gpu_eu_count=" << prop.gpu_eu_count
                   << ", gpu_subslice_count=" << gpu_subslice_count(prop)
                   << ", max_work_group_size=" << prop.max_work_group_size
                   << ", max_num_sub_groups=" << prop.max_num_sub_groups
                   << ", sub_group_sizes=[" << prop.sub_group_sizes
                   << "], has_fp16=" << prop.has_fp16
                   << ", has_fp64=" << prop.has_fp64
                   << ", has_atomic64=" << prop.has_atomic64 << ")";
            return stream.str();
          });
}

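// The returned DeviceProp pointer refers to properties cached on the C++
// side, hence py::return_value_policy::reference below so Python never takes
// ownership of it.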
static void bindGetDeviceProperties(PyObject* module) {
  // Add method to torch.xpu
  auto m = py::handle(module).cast<py::module>();
  m.def(
      "_get_device_properties",
      [](c10::DeviceIndex device) -> c10::xpu::DeviceProp* {
        return at::xpu::getDeviceProperties(device);
      },
      py::return_value_policy::reference);
}

// Callback for the Python side. Used for additional initialization of Python
// classes.
static PyObject* THXPModule_initExtension(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
  poison_fork();
  at::globalContext().lazyInitXPU();

  auto m = THPObjectPtr(PyImport_ImportModule("torch.xpu"));
  if (!m)
    throw python_error();

  auto set_module_attr = [&](const char* name, PyObject* v) {
    if (PyObject_SetAttrString(m, name, v) < 0) {
      throw python_error();
    }
  };

  auto num_gpus = c10::xpu::device_count();
  THPObjectPtr default_xpu_generators(
      PyTuple_New(static_cast<Py_ssize_t>(num_gpus)));
  for (const auto i : c10::irange(num_gpus)) {
    const auto& gen = at::xpu::detail::getDefaultXPUGenerator(i);
    auto* cast_gen = THPGenerator_initDefaultGenerator(gen);
    PyTuple_SetItem(default_xpu_generators.get(), i, cast_gen);
  }
  set_module_attr("default_generators", default_xpu_generators.get());
  bindGetDeviceProperties(m);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

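// These entries are exposed as private _xpu_* functions and are consumed by
// the Python wrappers in torch.xpu.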
// NOLINTNEXTLINE(*-c-arrays*, *-global-variables)
static struct PyMethodDef _THXPModule_methods[] = {
    {"_xpu_init", THXPModule_initExtension, METH_NOARGS, nullptr},
    {"_xpu_setDevice", THXPModule_setDevice_wrap, METH_O, nullptr},
    {"_xpu_exchangeDevice", THXPModule_exchangeDevice_wrap, METH_O, nullptr},
    {"_xpu_maybeExchangeDevice",
     THXPModule_maybeExchangeDevice_wrap,
     METH_O,
     nullptr},
    {"_xpu_getDevice", THXPModule_getDevice_wrap, METH_NOARGS, nullptr},
    {"_xpu_getDeviceCount",
     THXPModule_getDeviceCount_wrap,
     METH_NOARGS,
     nullptr},
    {"_xpu_isInBadFork", THXPModule_isInBadFork_wrap, METH_NOARGS, nullptr},
    {"_xpu_getCurrentStream",
     THXPModule_getCurrentStream_wrap,
     METH_O,
     nullptr},
    {"_xpu_getCurrentRawStream",
     THXPModule_getCurrentStream_raw,
     METH_O,
     nullptr},
    {"_xpu_setStream",
     castPyCFunctionWithKeywords(THXPModule_setStream_wrap),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_xpu_synchronize", THXPModule_xpuSynchronize, METH_O, nullptr},
    {"_xpu_emptyCache", THXPModule_emptyCache, METH_NOARGS, nullptr},
    {"_xpu_memoryStats", THXPModule_memoryStats, METH_O, nullptr},
    {"_xpu_resetAccumulatedMemoryStats",
     THXPModule_resetAccumulatedMemoryStats,
     METH_O,
     nullptr},
    {"_xpu_resetPeakMemoryStats",
     THXPModule_resetPeakMemoryStats,
     METH_O,
     nullptr},
    {nullptr}};

PyMethodDef* THXPModule_methods() {
  return _THXPModule_methods;
}

namespace torch::xpu {

void initModule(PyObject* module) {
  registerXpuDeviceProperties(module);
}

} // namespace torch::xpu