xref: /aosp_15_r20/external/pytorch/aten/src/ATen/xpu/XPUContext.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/xpu/XPUContext.h>
#include <c10/util/CallOnce.h>
#include <c10/util/Exception.h>

#include <algorithm>
#include <deque>
#include <iterator>
#include <vector>
7 
8 namespace at::xpu {
9 namespace {
10 
11 /*
12  * Currently, there is one device properties pool containing the information and
13  * capability about each compute-device.
14  *
15  * Device properties are lazily initialized when the first time properties are
16  * requested for a device.
17  */
18 DeviceIndex num_gpus = -1;
19 c10::once_flag init_flag;
20 std::deque<c10::once_flag> device_prop_flags;
21 std::vector<DeviceProp> device_properties;
22 
23 std::deque<c10::once_flag> device_global_idx_flags;
24 std::vector<int32_t> device_global_idxs;
25 
initXPUContextVectors()26 void initXPUContextVectors() {
27   num_gpus = c10::xpu::device_count();
28   device_prop_flags.resize(num_gpus);
29   device_properties.resize(num_gpus);
30   device_global_idx_flags.resize(num_gpus);
31   device_global_idxs.resize(num_gpus);
32 }
33 
initDeviceProperty(DeviceIndex device)34 void initDeviceProperty(DeviceIndex device) {
35   c10::xpu::get_device_properties(&device_properties[device], device);
36 }
37 
initDeviceGlobalIdx(DeviceIndex device)38 void initDeviceGlobalIdx(DeviceIndex device) {
39   sycl::device& raw_device = c10::xpu::get_raw_device(device);
40   // Get all SYCL devices associated with the SYCL platform.
41   auto devices = sycl::device::get_devices();
42   auto match_device = [raw_device](const auto& dev) -> bool {
43     return raw_device == dev;
44   };
45   auto it = std::find_if(devices.begin(), devices.end(), match_device);
46   TORCH_CHECK(
47       it != devices.end(), "Can't find the global index of XPU device.");
48   device_global_idxs[device] =
49       static_cast<int32_t>(std::distance(devices.begin(), it));
50 }
51 
check_device(DeviceIndex device)52 inline void check_device(DeviceIndex device) {
53   TORCH_CHECK(
54       device >= 0 && device < num_gpus,
55       "device is out of range, device is ",
56       static_cast<int>(device),
57       ", total number of device is ",
58       static_cast<int>(num_gpus),
59       ".");
60 }
61 
62 } // anonymous namespace
63 
getCurrentDeviceProperties()64 DeviceProp* getCurrentDeviceProperties() {
65   auto device = c10::xpu::current_device();
66   return getDeviceProperties(device);
67 }
68 
getDeviceProperties(DeviceIndex device)69 DeviceProp* getDeviceProperties(DeviceIndex device) {
70   c10::call_once(init_flag, initXPUContextVectors);
71   if (device == -1)
72     device = c10::xpu::current_device();
73   check_device(device);
74   c10::call_once(device_prop_flags[device], initDeviceProperty, device);
75   return &device_properties[device];
76 }
77 
78 // Return the global index enumerated by sycl::device::get_devices based on the
79 // index of a XPU device in the framework.
getGlobalIdxFromDevice(DeviceIndex device)80 int32_t getGlobalIdxFromDevice(DeviceIndex device) {
81   c10::call_once(init_flag, initXPUContextVectors);
82   check_device(device);
83   c10::call_once(device_global_idx_flags[device], initDeviceGlobalIdx, device);
84   return device_global_idxs[device];
85 }
86 
87 } // namespace at::xpu
88