xref: /aosp_15_r20/external/pytorch/aten/src/ATen/templates/TensorMethods.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/Scalar.h>
2 #include <ATen/core/TensorBody.h>
3 
4 #include <c10/util/string_view.h>
5 
6 namespace at {
7 
8 namespace {
9 
10 // Verifies the requested type is the same as the Tensor's type.
check_type(const TensorBase & tensor,ScalarType type,c10::string_view type_name)11 void check_type(const TensorBase& tensor, ScalarType type, c10::string_view type_name) {
12   TORCH_CHECK(
13       tensor.scalar_type() == type
14       || (isQIntType(tensor.scalar_type())
15           && toUnderlying(tensor.scalar_type()) == type),
16       "expected scalar type ", type_name, " but found ", tensor.scalar_type());
17 }
18 
19 } // namespace
20 
21 #define DEFINE_CAST(T, name)                                         \
22    template <>                                                       \
23    TORCH_API const T* TensorBase::const_data_ptr() const {           \
24      check_type(*this, ScalarType::name, #name);                     \
25      return this->unsafeGetTensorImpl()->data_ptr_impl<T>();         \
26    }                                                                 \
27                                                                      \
28    template <>                                                       \
29    TORCH_API const T* TensorBase::const_data_ptr<const T>() const {  \
30      check_type(*this, ScalarType::name, #name);                     \
31      return this->unsafeGetTensorImpl()->data_ptr_impl<std::remove_const_t<T>>(); \
32    }                                                                 \
33                                                                      \
34    template <>                                                       \
35    TORCH_API T* TensorBase::mutable_data_ptr() const {               \
36      check_type(*this, ScalarType::name, #name);                     \
37      return this->unsafeGetTensorImpl()->mutable_data_ptr_impl<T>(); \
38    }                                                                 \
39                                                                      \
40    template <>                                                       \
41    TORCH_API T* TensorBase::data_ptr() const {                       \
42      return mutable_data_ptr<T>();                                   \
43    }                                                                 \
44 
45  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CAST)
46  AT_FORALL_QINT_TYPES(DEFINE_CAST)
47  DEFINE_CAST(uint16_t, UInt16)
48  DEFINE_CAST(uint32_t, UInt32)
49  DEFINE_CAST(uint64_t, UInt64)
50  #undef DEFINE_CAST
51 
52  #define DEFINE_ITEM(T, name)      \
53    template <>                     \
54    TORCH_API T Tensor::item() const { \
55      return item().to##name();     \
56    }
57 
58  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ITEM)
59  #undef DEFINE_ITEM
60 
61  } //namespace at
62