xref: /aosp_15_r20/external/pytorch/aten/src/ATen/DeviceGuard.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/IListRef.h>
4*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/Tensor.h>
5*da0073e9SAndroid Build Coastguard Worker #include <c10/core/DeviceGuard.h>
6*da0073e9SAndroid Build Coastguard Worker #include <c10/core/ScalarType.h> // TensorList whyyyyy
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker namespace at {
9*da0073e9SAndroid Build Coastguard Worker 
10*da0073e9SAndroid Build Coastguard Worker // Are you here because you're wondering why DeviceGuard(tensor) no
11*da0073e9SAndroid Build Coastguard Worker // longer works?  For code organization reasons, we have temporarily(?)
12*da0073e9SAndroid Build Coastguard Worker // removed this constructor from DeviceGuard.  The new way to
13*da0073e9SAndroid Build Coastguard Worker // spell it is:
14*da0073e9SAndroid Build Coastguard Worker //
15*da0073e9SAndroid Build Coastguard Worker //    OptionalDeviceGuard guard(device_of(tensor));
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker /// Return the Device of a Tensor, if the Tensor is defined.
device_of(const Tensor & t)18*da0073e9SAndroid Build Coastguard Worker inline std::optional<Device> device_of(const Tensor& t) {
19*da0073e9SAndroid Build Coastguard Worker   if (t.defined()) {
20*da0073e9SAndroid Build Coastguard Worker     return std::make_optional(t.device());
21*da0073e9SAndroid Build Coastguard Worker   } else {
22*da0073e9SAndroid Build Coastguard Worker     return std::nullopt;
23*da0073e9SAndroid Build Coastguard Worker   }
24*da0073e9SAndroid Build Coastguard Worker }
25*da0073e9SAndroid Build Coastguard Worker 
device_of(const std::optional<Tensor> & t)26*da0073e9SAndroid Build Coastguard Worker inline std::optional<Device> device_of(const std::optional<Tensor>& t) {
27*da0073e9SAndroid Build Coastguard Worker   return t.has_value() ? device_of(t.value()) : std::nullopt;
28*da0073e9SAndroid Build Coastguard Worker }
29*da0073e9SAndroid Build Coastguard Worker 
30*da0073e9SAndroid Build Coastguard Worker /// Return the Device of a TensorList, if the list is non-empty and
31*da0073e9SAndroid Build Coastguard Worker /// the first Tensor is defined.  (This function implicitly assumes
32*da0073e9SAndroid Build Coastguard Worker /// that all tensors in the list have the same device.)
device_of(ITensorListRef t)33*da0073e9SAndroid Build Coastguard Worker inline std::optional<Device> device_of(ITensorListRef t) {
34*da0073e9SAndroid Build Coastguard Worker   if (!t.empty()) {
35*da0073e9SAndroid Build Coastguard Worker     return device_of(t.front());
36*da0073e9SAndroid Build Coastguard Worker   } else {
37*da0073e9SAndroid Build Coastguard Worker     return std::nullopt;
38*da0073e9SAndroid Build Coastguard Worker   }
39*da0073e9SAndroid Build Coastguard Worker }
40*da0073e9SAndroid Build Coastguard Worker 
41*da0073e9SAndroid Build Coastguard Worker } // namespace at
42