1 #include <c10/core/Scalar.h>
2 #include <ATen/core/TensorBody.h>
3
4 #include <c10/util/string_view.h>
5
6 namespace at {
7
8 namespace {
9
10 // Verifies the requested type is the same as the Tensor's type.
check_type(const TensorBase & tensor,ScalarType type,c10::string_view type_name)11 void check_type(const TensorBase& tensor, ScalarType type, c10::string_view type_name) {
12 TORCH_CHECK(
13 tensor.scalar_type() == type
14 || (isQIntType(tensor.scalar_type())
15 && toUnderlying(tensor.scalar_type()) == type),
16 "expected scalar type ", type_name, " but found ", tensor.scalar_type());
17 }
18
19 } // namespace
20
// DEFINE_CAST(T, name) emits four explicit specializations of TensorBase's
// typed data-pointer accessors for element type T. Its dtype is ScalarType::name,
// and #name is the printable form used in error messages.
//   * const_data_ptr<T>()       - runs check_type(), then returns the
//                                 TensorImpl's data_ptr_impl<T>().
//   * const_data_ptr<const T>() - same check and same data_ptr_impl call,
//                                 for a const-qualified template argument.
//                                 T is always a non-const type in the
//                                 instantiations below, so
//                                 std::remove_const_t<T> is just T.
//   * mutable_data_ptr<T>()     - runs check_type(), then returns the
//                                 TensorImpl's mutable_data_ptr_impl<T>().
//   * data_ptr<T>()             - forwards to mutable_data_ptr<T>(), so it
//                                 gets that specialization's check and pointer.
// check_type() is the helper defined in the anonymous namespace of this file.
// On a mismatch it fails with "expected scalar type <name> but found <dtype>".
// Keep mutable_data_ptr specialized before data_ptr. The data_ptr body uses
// it, and an explicit specialization must be declared before its first use.
// (No // comments inside the macro body: a trailing backslash would continue
// the comment onto the next line.)
#define DEFINE_CAST(T, name)                                                     \
  template <>                                                                    \
  TORCH_API const T* TensorBase::const_data_ptr() const {                        \
    check_type(*this, ScalarType::name, #name);                                  \
    return this->unsafeGetTensorImpl()->data_ptr_impl<T>();                      \
  }                                                                              \
                                                                                 \
  template <>                                                                    \
  TORCH_API const T* TensorBase::const_data_ptr<const T>() const {               \
    check_type(*this, ScalarType::name, #name);                                  \
    return this->unsafeGetTensorImpl()->data_ptr_impl<std::remove_const_t<T>>(); \
  }                                                                              \
                                                                                 \
  template <>                                                                    \
  TORCH_API T* TensorBase::mutable_data_ptr() const {                            \
    check_type(*this, ScalarType::name, #name);                                  \
    return this->unsafeGetTensorImpl()->mutable_data_ptr_impl<T>();              \
  }                                                                              \
                                                                                 \
  template <>                                                                    \
  TORCH_API T* TensorBase::data_ptr() const {                                    \
    return mutable_data_ptr<T>();                                                \
  }                                                                              \

// Instantiate the accessors for every (T, name) pair in the standard
// scalar-type list (including complex types) and in the quantized-int list.
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CAST)
AT_FORALL_QINT_TYPES(DEFINE_CAST)
// The unsigned 16/32/64-bit types are instantiated explicitly, presumably
// because the AT_FORALL_* lists above do not include them (duplicates would
// be redefinition errors). NOTE(review): confirm when those lists change.
DEFINE_CAST(uint16_t, UInt16)
DEFINE_CAST(uint32_t, UInt32)
DEFINE_CAST(uint64_t, UInt64)
#undef DEFINE_CAST
51
52 #define DEFINE_ITEM(T, name) \
53 template <> \
54 TORCH_API T Tensor::item() const { \
55 return item().to##name(); \
56 }
57
58 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ITEM)
59 #undef DEFINE_ITEM
60
61 } //namespace at
62