1 #pragma once 2 3 #include <ATen/core/IListRef.h> 4 #include <ATen/core/Tensor.h> 5 #include <c10/core/DeviceGuard.h> 6 #include <c10/core/ScalarType.h> // TensorList whyyyyy 7 8 namespace at { 9 10 // Are you here because you're wondering why DeviceGuard(tensor) no 11 // longer works? For code organization reasons, we have temporarily(?) 12 // removed this constructor from DeviceGuard. The new way to 13 // spell it is: 14 // 15 // OptionalDeviceGuard guard(device_of(tensor)); 16 17 /// Return the Device of a Tensor, if the Tensor is defined. device_of(const Tensor & t)18inline std::optional<Device> device_of(const Tensor& t) { 19 if (t.defined()) { 20 return std::make_optional(t.device()); 21 } else { 22 return std::nullopt; 23 } 24 } 25 device_of(const std::optional<Tensor> & t)26inline std::optional<Device> device_of(const std::optional<Tensor>& t) { 27 return t.has_value() ? device_of(t.value()) : std::nullopt; 28 } 29 30 /// Return the Device of a TensorList, if the list is non-empty and 31 /// the first Tensor is defined. (This function implicitly assumes 32 /// that all tensors in the list have the same device.) device_of(ITensorListRef t)33inline std::optional<Device> device_of(ITensorListRef t) { 34 if (!t.empty()) { 35 return device_of(t.front()); 36 } else { 37 return std::nullopt; 38 } 39 } 40 41 } // namespace at 42