xref: /aosp_15_r20/external/pytorch/aten/src/ATen/DeviceGuard.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/IListRef.h>
4 #include <ATen/core/Tensor.h>
5 #include <c10/core/DeviceGuard.h>
6 #include <c10/core/ScalarType.h> // TensorList whyyyyy
7 
8 namespace at {
9 
10 // Are you here because you're wondering why DeviceGuard(tensor) no
11 // longer works?  For code organization reasons, we have temporarily(?)
12 // removed this constructor from DeviceGuard.  The new way to
13 // spell it is:
14 //
15 //    OptionalDeviceGuard guard(device_of(tensor));
16 
17 /// Return the Device of a Tensor, if the Tensor is defined.
device_of(const Tensor & t)18 inline std::optional<Device> device_of(const Tensor& t) {
19   if (t.defined()) {
20     return std::make_optional(t.device());
21   } else {
22     return std::nullopt;
23   }
24 }
25 
device_of(const std::optional<Tensor> & t)26 inline std::optional<Device> device_of(const std::optional<Tensor>& t) {
27   return t.has_value() ? device_of(t.value()) : std::nullopt;
28 }
29 
30 /// Return the Device of a TensorList, if the list is non-empty and
31 /// the first Tensor is defined.  (This function implicitly assumes
32 /// that all tensors in the list have the same device.)
device_of(ITensorListRef t)33 inline std::optional<Device> device_of(ITensorListRef t) {
34   if (!t.empty()) {
35     return device_of(t.front());
36   } else {
37     return std::nullopt;
38   }
39 }
40 
41 } // namespace at
42