1*da0073e9SAndroid Build Coastguard Worker #pragma once 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/IListRef.h> 4*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/Tensor.h> 5*da0073e9SAndroid Build Coastguard Worker #include <c10/core/DeviceGuard.h> 6*da0073e9SAndroid Build Coastguard Worker #include <c10/core/ScalarType.h> // TensorList whyyyyy 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Worker namespace at { 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker // Are you here because you're wondering why DeviceGuard(tensor) no 11*da0073e9SAndroid Build Coastguard Worker // longer works? For code organization reasons, we have temporarily(?) 12*da0073e9SAndroid Build Coastguard Worker // removed this constructor from DeviceGuard. The new way to 13*da0073e9SAndroid Build Coastguard Worker // spell it is: 14*da0073e9SAndroid Build Coastguard Worker // 15*da0073e9SAndroid Build Coastguard Worker // OptionalDeviceGuard guard(device_of(tensor)); 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Worker /// Return the Device of a Tensor, if the Tensor is defined. device_of(const Tensor & t)18*da0073e9SAndroid Build Coastguard Workerinline std::optional<Device> device_of(const Tensor& t) { 19*da0073e9SAndroid Build Coastguard Worker if (t.defined()) { 20*da0073e9SAndroid Build Coastguard Worker return std::make_optional(t.device()); 21*da0073e9SAndroid Build Coastguard Worker } else { 22*da0073e9SAndroid Build Coastguard Worker return std::nullopt; 23*da0073e9SAndroid Build Coastguard Worker } 24*da0073e9SAndroid Build Coastguard Worker } 25*da0073e9SAndroid Build Coastguard Worker device_of(const std::optional<Tensor> & t)26*da0073e9SAndroid Build Coastguard Workerinline std::optional<Device> device_of(const std::optional<Tensor>& t) { 27*da0073e9SAndroid Build Coastguard Worker return t.has_value() ? device_of(t.value()) : std::nullopt; 28*da0073e9SAndroid Build Coastguard Worker } 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker /// Return the Device of a TensorList, if the list is non-empty and 31*da0073e9SAndroid Build Coastguard Worker /// the first Tensor is defined. (This function implicitly assumes 32*da0073e9SAndroid Build Coastguard Worker /// that all tensors in the list have the same device.) device_of(ITensorListRef t)33*da0073e9SAndroid Build Coastguard Workerinline std::optional<Device> device_of(ITensorListRef t) { 34*da0073e9SAndroid Build Coastguard Worker if (!t.empty()) { 35*da0073e9SAndroid Build Coastguard Worker return device_of(t.front()); 36*da0073e9SAndroid Build Coastguard Worker } else { 37*da0073e9SAndroid Build Coastguard Worker return std::nullopt; 38*da0073e9SAndroid Build Coastguard Worker } 39*da0073e9SAndroid Build Coastguard Worker } 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker } // namespace at 42