#include <ATen/xpu/XPUContext.h>
#include <c10/util/CallOnce.h>
#include <c10/util/Exception.h>

#include <algorithm>
#include <deque>
#include <iterator>
#include <vector>
7
8 namespace at::xpu {
9 namespace {
10
/*
 * Currently, there is one device properties pool containing information about
 * the capabilities of each compute device.
 *
 * The properties of a device are lazily initialized the first time they are
 * requested for that device.
 */
18 DeviceIndex num_gpus = -1;
19 c10::once_flag init_flag;
20 std::deque<c10::once_flag> device_prop_flags;
21 std::vector<DeviceProp> device_properties;
22
23 std::deque<c10::once_flag> device_global_idx_flags;
24 std::vector<int32_t> device_global_idxs;
25
initXPUContextVectors()26 void initXPUContextVectors() {
27 num_gpus = c10::xpu::device_count();
28 device_prop_flags.resize(num_gpus);
29 device_properties.resize(num_gpus);
30 device_global_idx_flags.resize(num_gpus);
31 device_global_idxs.resize(num_gpus);
32 }
33
initDeviceProperty(DeviceIndex device)34 void initDeviceProperty(DeviceIndex device) {
35 c10::xpu::get_device_properties(&device_properties[device], device);
36 }
37
initDeviceGlobalIdx(DeviceIndex device)38 void initDeviceGlobalIdx(DeviceIndex device) {
39 sycl::device& raw_device = c10::xpu::get_raw_device(device);
40 // Get all SYCL devices associated with the SYCL platform.
41 auto devices = sycl::device::get_devices();
42 auto match_device = [raw_device](const auto& dev) -> bool {
43 return raw_device == dev;
44 };
45 auto it = std::find_if(devices.begin(), devices.end(), match_device);
46 TORCH_CHECK(
47 it != devices.end(), "Can't find the global index of XPU device.");
48 device_global_idxs[device] =
49 static_cast<int32_t>(std::distance(devices.begin(), it));
50 }
51
check_device(DeviceIndex device)52 inline void check_device(DeviceIndex device) {
53 TORCH_CHECK(
54 device >= 0 && device < num_gpus,
55 "device is out of range, device is ",
56 static_cast<int>(device),
57 ", total number of device is ",
58 static_cast<int>(num_gpus),
59 ".");
60 }
61
62 } // anonymous namespace
63
getCurrentDeviceProperties()64 DeviceProp* getCurrentDeviceProperties() {
65 auto device = c10::xpu::current_device();
66 return getDeviceProperties(device);
67 }
68
getDeviceProperties(DeviceIndex device)69 DeviceProp* getDeviceProperties(DeviceIndex device) {
70 c10::call_once(init_flag, initXPUContextVectors);
71 if (device == -1)
72 device = c10::xpu::current_device();
73 check_device(device);
74 c10::call_once(device_prop_flags[device], initDeviceProperty, device);
75 return &device_properties[device];
76 }
77
78 // Return the global index enumerated by sycl::device::get_devices based on the
79 // index of a XPU device in the framework.
getGlobalIdxFromDevice(DeviceIndex device)80 int32_t getGlobalIdxFromDevice(DeviceIndex device) {
81 c10::call_once(init_flag, initXPUContextVectors);
82 check_device(device);
83 c10::call_once(device_global_idx_flags[device], initDeviceGlobalIdx, device);
84 return device_global_idxs[device];
85 }
86
87 } // namespace at::xpu
88