1 #pragma once 2 3 #include <c10/core/DeviceType.h> 4 #include <c10/macros/Macros.h> 5 6 #include <ATen/detail/MTIAHooksInterface.h> 7 #include <optional> 8 9 // This file defines the top level Accelerator concept for PyTorch. 10 // A device is an accelerator per the definition here if: 11 // - It is mutually exclusive with all other accelerators 12 // - It performs asynchronous compute via a Stream/Event system 13 // - It provides a set of common APIs as defined by AcceleratorHooksInterface 14 // 15 // As of today, accelerator devices are (in no particular order): 16 // CUDA, MTIA, XPU, HIP, MPS, PrivateUse1 17 18 namespace at { 19 20 // Ensures that only one accelerator is available (at 21 // compile time if possible) and return it. 22 // When checked is true, the returned optional always has a value. 23 TORCH_API std::optional<c10::DeviceType> getAccelerator(bool checked = false); 24 25 TORCH_API bool isAccelerator(c10::DeviceType d); 26 27 } // namespace at 28