xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Scalar.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch_v2.h>
4 
5 #ifndef AT_PER_OPERATOR_HEADERS
6 #include <ATen/Functions.h>
7 #include <ATen/NativeFunctions.h>
8 #else
9 #include <ATen/ops/_local_scalar_dense.h>
10 #include <ATen/ops/_local_scalar_dense_native.h>
11 #include <ATen/ops/item_native.h>
12 #endif
13 
14 namespace at::native {
15 
item(const Tensor & self)16 Scalar item(const Tensor& self) {
17   auto numel = self.sym_numel();
18   TORCH_CHECK(numel == 1, "a Tensor with ", numel, " elements cannot be converted to Scalar");
19   if (self.is_sparse()) {
20     if (self._nnz() == 0) return Scalar(0);
21     if (self.is_coalesced()) return at::_local_scalar_dense(self._values());
22     return at::_local_scalar_dense(self._values().sum());
23   } else if (self.is_quantized()) {
24     return self.dequantize().item();
25   } else {
26     return _local_scalar_dense(self);
27   }
28 }
29 
30 #define AT_SD_BASE_TYPES AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
31 #if !defined(C10_MOBILE)
32 #define AT_SD_TYPES AT_EXPAND(AT_SD_BASE_TYPES), AT_EXPAND(AT_FLOAT8_TYPES)
33 #else
34 #define AT_SD_TYPES AT_EXPAND(AT_SD_BASE_TYPES)
35 #endif
36 
_local_scalar_dense_cpu(const Tensor & self)37 Scalar _local_scalar_dense_cpu(const Tensor& self) {
38   Scalar r;
39   AT_DISPATCH_V2(
40     self.scalar_type(),
41     "_local_scalar_dense_cpu",
42     AT_WRAP([&] {
43       scalar_t value = *self.const_data_ptr<scalar_t>();
44       r = Scalar(value);
45     }),
46     AT_EXPAND(AT_SD_TYPES)
47   );
48   return r;
49 }
50 
51 } // at::native
52