#include #include #include #include #include namespace at::xpu { namespace { /* * Currently, there is one device properties pool containing the information and * capability about each compute-device. * * Device properties are lazily initialized when the first time properties are * requested for a device. */ DeviceIndex num_gpus = -1; c10::once_flag init_flag; std::deque device_prop_flags; std::vector device_properties; std::deque device_global_idx_flags; std::vector device_global_idxs; void initXPUContextVectors() { num_gpus = c10::xpu::device_count(); device_prop_flags.resize(num_gpus); device_properties.resize(num_gpus); device_global_idx_flags.resize(num_gpus); device_global_idxs.resize(num_gpus); } void initDeviceProperty(DeviceIndex device) { c10::xpu::get_device_properties(&device_properties[device], device); } void initDeviceGlobalIdx(DeviceIndex device) { sycl::device& raw_device = c10::xpu::get_raw_device(device); // Get all SYCL devices associated with the SYCL platform. auto devices = sycl::device::get_devices(); auto match_device = [raw_device](const auto& dev) -> bool { return raw_device == dev; }; auto it = std::find_if(devices.begin(), devices.end(), match_device); TORCH_CHECK( it != devices.end(), "Can't find the global index of XPU device."); device_global_idxs[device] = static_cast(std::distance(devices.begin(), it)); } inline void check_device(DeviceIndex device) { TORCH_CHECK( device >= 0 && device < num_gpus, "device is out of range, device is ", static_cast(device), ", total number of device is ", static_cast(num_gpus), "."); } } // anonymous namespace DeviceProp* getCurrentDeviceProperties() { auto device = c10::xpu::current_device(); return getDeviceProperties(device); } DeviceProp* getDeviceProperties(DeviceIndex device) { c10::call_once(init_flag, initXPUContextVectors); if (device == -1) device = c10::xpu::current_device(); check_device(device); c10::call_once(device_prop_flags[device], initDeviceProperty, device); return &device_properties[device]; } // Return the global index enumerated by sycl::device::get_devices based on the // index of a XPU device in the framework. int32_t getGlobalIdxFromDevice(DeviceIndex device) { c10::call_once(init_flag, initXPUContextVectors); check_device(device); c10::call_once(device_global_idx_flags[device], initDeviceGlobalIdx, device); return device_global_idxs[device]; } } // namespace at::xpu