xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/IntReprQuant.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/ceil_div.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/native/TensorIterator.h>
6 #include <ATen/native/cpu/Loops.h>
7 #include <ATen/native/DispatchStub.h>
8 #include <c10/util/irange.h>
9 
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #include <ATen/NativeFunctions.h>
13 #else
14 #include <ATen/ops/empty.h>
15 #include <ATen/ops/int_repr_native.h>
16 #endif
17 
18 namespace at {
19 namespace native {
20 
21 // When input Tensor is non-dense, i.e. the allocated memory
22 // is larger than the memory used by all the elements, we'll
23 // convert it to dense tensor, otherwise we'll keep the memory
24 // format of the output the same as input
int_repr_quantized_cpu(const Tensor & self)25 Tensor int_repr_quantized_cpu(const Tensor& self) {
26   Tensor dst;
27   // NOLINTNEXTLINE(clang-diagnostic-unused-variable)
28   AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(self.scalar_type(), "int_repr", [&]() {
29     if (bit_width == 4 || bit_width == 2) {
30       int64_t out_size = at::ceil_div(self.numel() * bit_width, (int64_t)8);
31       dst = at::empty(
32           {out_size},
33           self.options().dtype(UNDERLYING_TYPE),
34           self.suggest_memory_format());
35       const underlying_t* qdata = reinterpret_cast<const underlying_t*>(self.const_data_ptr<scalar_t>());
36       for (const auto i : c10::irange(dst.numel())) {
37         dst[i] = static_cast<underlying_t>(qdata[i]);
38       }
39     } else {
40       dst = at::empty(
41           self.sizes(),
42           self.options().dtype(UNDERLYING_TYPE),
43           self.suggest_memory_format());
44       auto iter = TensorIteratorConfig()
45         .check_all_same_dtype(false)
46         .add_output(dst)
47         .add_input(self)
48         .build();
49       cpu_kernel(iter, [](scalar_t value) -> underlying_t { return value.val_; });
50       }
51   });
52   return dst;
53 }
54 
55 } // namespace native
56 } // namespace at
57