1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch_v2.h>
4
5 #ifndef AT_PER_OPERATOR_HEADERS
6 #include <ATen/Functions.h>
7 #include <ATen/NativeFunctions.h>
8 #else
9 #include <ATen/ops/_local_scalar_dense.h>
10 #include <ATen/ops/_local_scalar_dense_native.h>
11 #include <ATen/ops/item_native.h>
12 #endif
13
14 namespace at::native {
15
item(const Tensor & self)16 Scalar item(const Tensor& self) {
17 auto numel = self.sym_numel();
18 TORCH_CHECK(numel == 1, "a Tensor with ", numel, " elements cannot be converted to Scalar");
19 if (self.is_sparse()) {
20 if (self._nnz() == 0) return Scalar(0);
21 if (self.is_coalesced()) return at::_local_scalar_dense(self._values());
22 return at::_local_scalar_dense(self._values().sum());
23 } else if (self.is_quantized()) {
24 return self.dequantize().item();
25 } else {
26 return _local_scalar_dense(self);
27 }
28 }
29
// Dtype list for the AT_DISPATCH_V2 call in _local_scalar_dense_cpu, common to
// all builds: the AT_ALL_TYPES and AT_COMPLEX_TYPES groups, ComplexHalf, Half,
// Bool, BFloat16, and the AT_BAREBONES_UNSIGNED_TYPES group.
#define AT_SD_BASE_TYPES AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
// Float8 dtypes are dispatched only when C10_MOBILE is not defined
// (presumably to keep mobile binaries small — confirm).
#if !defined(C10_MOBILE)
#define AT_SD_TYPES AT_EXPAND(AT_SD_BASE_TYPES), AT_EXPAND(AT_FLOAT8_TYPES)
#else
#define AT_SD_TYPES AT_EXPAND(AT_SD_BASE_TYPES)
#endif
36
_local_scalar_dense_cpu(const Tensor & self)37 Scalar _local_scalar_dense_cpu(const Tensor& self) {
38 Scalar r;
39 AT_DISPATCH_V2(
40 self.scalar_type(),
41 "_local_scalar_dense_cpu",
42 AT_WRAP([&] {
43 scalar_t value = *self.const_data_ptr<scalar_t>();
44 r = Scalar(value);
45 }),
46 AT_EXPAND(AT_SD_TYPES)
47 );
48 return r;
49 }
50
51 } // at::native
52