xref: /aosp_15_r20/external/pytorch/aten/src/ATen/DeviceAccelerator.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/DeviceType.h>
4 #include <c10/macros/Macros.h>
5 
6 #include <ATen/detail/MTIAHooksInterface.h>
7 #include <optional>
8 
9 // This file defines the top level Accelerator concept for PyTorch.
10 // A device is an accelerator per the definition here if:
11 // - It is mutually exclusive with all other accelerators
12 // - It performs asynchronous compute via a Stream/Event system
13 // - It provides a set of common APIs as defined by AcceleratorHooksInterface
14 //
15 // As of today, accelerator devices are (in no particular order):
16 // CUDA, MTIA, XPU, HIP, MPS, PrivateUse1
17 
18 namespace at {
19 
20 // Ensures that only one accelerator is available (at
21 // compile time if possible) and return it.
22 // When checked is true, the returned optional always has a value.
23 TORCH_API std::optional<c10::DeviceType> getAccelerator(bool checked = false);
24 
25 TORCH_API bool isAccelerator(c10::DeviceType d);
26 
27 } // namespace at
28