xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/device_lazy_init.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/TensorOptions.h>
4 
5 // device_lazy_init() is always compiled, even for CPU-only builds.
6 
7 namespace torch::utils {
8 
9 /**
10  * This mechanism of lazy initialization is designed for each device backend.
11  * Currently, CUDA and XPU follow this design. This function `device_lazy_init`
12  * MUST be called before you attempt to access any Type(CUDA or XPU) object
13  * from ATen, in any way. It guarantees that the device runtime status is lazily
14  * initialized when the first runtime API is requested.
15  *
16  * Here are some common ways that a device object may be retrieved:
17  *   - You call getNonVariableType or getNonVariableTypeOpt
18  *   - You call toBackend() on a Type
19  *
20  * It's important to do this correctly, because if you forget to add it you'll
21  * get an oblique error message seems like "Cannot initialize CUDA without
22  * ATen_cuda library" or "Cannot initialize XPU without ATen_xpu library" if you
23  * try to use CUDA or XPU functionality from a CPU-only build, which is not good
24  * UX.
25  */
26 void device_lazy_init(at::DeviceType device_type);
27 void set_requires_device_init(at::DeviceType device_type, bool value);
28 
maybe_initialize_device(at::Device & device)29 inline void maybe_initialize_device(at::Device& device) {
30   // Add more devices here to enable lazy initialization.
31   if (device.is_cuda() || device.is_xpu() || device.is_privateuseone()) {
32     device_lazy_init(device.type());
33   }
34 }
35 
maybe_initialize_device(std::optional<at::Device> & device)36 inline void maybe_initialize_device(std::optional<at::Device>& device) {
37   if (!device.has_value()) {
38     return;
39   }
40   maybe_initialize_device(device.value());
41 }
42 
maybe_initialize_device(const at::TensorOptions & options)43 inline void maybe_initialize_device(const at::TensorOptions& options) {
44   auto device = options.device();
45   maybe_initialize_device(device);
46 }
47 
48 bool is_device_initialized(at::DeviceType device_type);
49 
50 } // namespace torch::utils
51