// aten/src/ATen/native/quantized/cpu/qsoftmax.cpp
#include <ATen/ATen.h>
#include <torch/library.h>

#ifdef USE_PYTORCH_QNNPACK
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <pytorch_qnnpack.h>

#include <utility>
#endif // USE_PYTORCH_QNNPACK

namespace at {
namespace native {

namespace {

#ifdef USE_PYTORCH_QNNPACK

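// QNNPACK's 8-bit softargmax kernel uses a fixed output quantization:
// scale 2^-8 (1/256) and zero point 0, which maps quint8 values onto the
// softmax output range [0, 1).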
const static float qnnpack_softmax_output_scale = 0x1.0p-8f;
const static int qnnpack_softmax_output_zero_point = 0;

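// The QNNPACK path handles only non-scalar, per-tensor quantized quint8
// inputs whose requested output quantization matches the fixed parameters
// above.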
bool is_qnnpack_compatible(
    const Tensor& qx,
    const double output_scale,
    const int64_t output_zero_point) {
  return (
      (qx.qscheme() == kPerTensorAffine ||
       qx.qscheme() == kPerTensorSymmetric) &&
      qx.scalar_type() == c10::kQUInt8 && qx.ndimension() > 0 &&
      output_scale == qnnpack_softmax_output_scale &&
      output_zero_point == qnnpack_softmax_output_zero_point);
}

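// Softmax over `dim` via QNNPACK's softargmax operator. The kernel expects
// the softmax dimension to be contiguous, so the input may first need to be
// made contiguous and/or permuted (see the cases below).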
Tensor qsoftmax_qnnpack(const Tensor& qx, const int64_t dim) {
  /*
    Cases for contiguity/dimensionality
    1) stride along target dim is 1
        requires no change to qx
    2) dim is the last dimension (but qx is not contiguous)
        requires using qx.contiguous()
    3) other
        requires permuting qx.contiguous()
   */

  const int64_t last_dim = qx.dim() - 1;
  std::optional<std::vector<int64_t>> permuted_dims = std::nullopt;
  std::optional<at::Tensor> qx_contig = std::nullopt;
  const at::Tensor* qx_contig_ptr = nullptr;

  if (qx.stride(dim) == 1) {
    qx_contig_ptr = &qx;
  } else if (dim == last_dim) {
    qx_contig = qx.contiguous();
    qx_contig_ptr = &qx_contig.value();
  } else {
    permuted_dims = std::vector<int64_t>(qx.dim());
    std::iota(permuted_dims->begin(), permuted_dims->end(), 0);
    permuted_dims->at(last_dim) = dim;
    permuted_dims->at(dim) = last_dim;
    qx_contig = qx.permute(permuted_dims.value()).contiguous();
    qx_contig_ptr = &qx_contig.value();
  }

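  // Allocate the quantized output with the fixed QNNPACK output quantization
  // parameters, using the memory format of the (possibly permuted) input.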
  at::Tensor qy = at::_empty_affine_quantized(
      qx_contig_ptr->sizes(),
      at::device(kCPU)
          .dtype(qx.scalar_type())
          .memory_format(qx_contig_ptr->suggest_memory_format()),
      qnnpack_softmax_output_scale,
      qnnpack_softmax_output_zero_point,
      std::nullopt);

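  // QNNPACK views the input as a batch of rows: `channels` elements along the
  // softmax dimension, with all remaining dimensions folded into `batch_size`.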
  const size_t channels = qx.size(dim);
  const float input_scale = static_cast<float>(qx.q_scale());
  const uint32_t flags = 0;
  const size_t batch_size = qx.numel() / channels;
  const uint8_t* input =
      reinterpret_cast<const uint8_t*>(qx_contig_ptr->data_ptr<c10::quint8>());
  const size_t input_stride = channels;
  uint8_t* output = reinterpret_cast<uint8_t*>(qy.data_ptr<c10::quint8>());
  const size_t output_stride = channels;

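  // Create, set up, and run the QNNPACK softargmax operator; the unique_ptr
  // with QnnpackOperatorDeleter releases the operator on every exit path.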
  initQNNPACK();
  pytorch_qnnp_operator_t softargmax = nullptr;

  pytorch_qnnp_status status = pytorch_qnnp_create_softargmax_nc_q8(
      channels,
      input_scale,
      qnnpack_softmax_output_zero_point,
      qnnpack_softmax_output_scale,
      flags,
      &softargmax);
  TORCH_CHECK(
      status == pytorch_qnnp_status_success,
      "failed to create QNNPACK Softmax operator");
  TORCH_CHECK_NOTNULL(softargmax);

  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter> softmax_op(
    softargmax);

  status = pytorch_qnnp_setup_softargmax_nc_q8(
      softargmax, batch_size, input, input_stride, output, output_stride);
  TORCH_CHECK(
      status == pytorch_qnnp_status_success,
      "failed to setup QNNPACK Softmax operator");

  pthreadpool_t threadpool = caffe2::pthreadpool_();
  status = pytorch_qnnp_run_operator(softargmax, threadpool);
  TORCH_CHECK(
      status == pytorch_qnnp_status_success,
      "failed to run QNNPACK Softmax operator");

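  // If the input was permuted, permute the result back so the output matches
  // the original dimension order (this produces a view, not a copy).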
  return permuted_dims.has_value() ? qy.permute(permuted_dims.value()) : std::move(qy);
}

#endif // USE_PYTORCH_QNNPACK

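// Reference fallback: dequantize, run the floating-point softmax, and
// requantize with the requested output scale and zero point.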
Tensor qsoftmax_naive(
    const Tensor& qx,
    const int64_t dim,
    const double output_scale,
    const int64_t output_zero_point) {
  Tensor rx = at::dequantize(qx);
  Tensor ry = at::softmax(rx, dim);
  return at::quantize_per_tensor(
      ry, output_scale, output_zero_point, qx.scalar_type());
}

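// Use the QNNPACK kernel when the quantization engine is QNNPACK and the
// input/output parameters are compatible; otherwise fall back to the naive
// dequantize-softmax-requantize path.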
Tensor qsoftmax(
    const Tensor& qx,
    const int64_t dim,
    const double output_scale,
    const int64_t output_zero_point) {
#ifdef USE_PYTORCH_QNNPACK
  if (at::globalContext().qEngine() == at::QEngine::QNNPACK &&
      is_qnnpack_compatible(qx, output_scale, output_zero_point)) {
    return qsoftmax_qnnpack(qx, dim);
  }
#endif // USE_PYTORCH_QNNPACK
  return qsoftmax_naive(qx, dim, output_scale, output_zero_point);
}

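// Register this kernel for quantized::softmax on the QuantizedCPU dispatch key.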
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::softmax"), TORCH_FN(qsoftmax));
}

} // namespace

} // namespace native
} // namespace at