#include <ATen/ATen.h>
#include <torch/library.h>

#ifdef USE_PYTORCH_QNNPACK
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <pytorch_qnnpack.h>

#include <numeric>
#include <optional>
#include <utility>
#endif // USE_PYTORCH_QNNPACK

namespace at {
namespace native {

namespace {

#ifdef USE_PYTORCH_QNNPACK
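// The QNNPACK softmax path always produces output with this fixed
// quantization: softmax values lie in [0, 1], so a scale of 2^-8 = 1/256 with
// zero point 0 covers the full quint8 range.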
const static float qnnpack_softmax_output_scale = 0x1.0p-8f;
const static int qnnpack_softmax_output_zero_point = 0;
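// The QNNPACK path handles per-tensor quantized quint8 inputs with at least
// one dimension, and only when the requested output quantization matches the
// fixed parameters above.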
bool is_qnnpack_compatible(
    const Tensor& qx,
    const double output_scale,
    const int64_t output_zero_point) {
  return (
      (qx.qscheme() == kPerTensorAffine ||
       qx.qscheme() == kPerTensorSymmetric) &&
      qx.scalar_type() == c10::kQUInt8 && qx.ndimension() > 0 &&
      output_scale == qnnpack_softmax_output_scale &&
      output_zero_point == qnnpack_softmax_output_zero_point);
}

Tensor qsoftmax_qnnpack(const Tensor& qx, const int64_t dim) {
  /*
    Cases for contiguity/dimensionality
    1) stride along target dim is 1
       requires no change to qx
    2) dim is the last dimension (but qx is not contiguous)
       requires using qx.contiguous()
    3) other
       requires permuting qx.contiguous()
  */

  const int64_t last_dim = qx.dim() - 1;
  std::optional<std::vector<int64_t>> permuted_dims = std::nullopt;
  std::optional<at::Tensor> qx_contig = std::nullopt;
  const at::Tensor* qx_contig_ptr = nullptr;

  if (qx.stride(dim) == 1) {
    qx_contig_ptr = &qx;
  } else if (dim == last_dim) {
    qx_contig = qx.contiguous();
    qx_contig_ptr = &qx_contig.value();
  } else {
    permuted_dims = std::vector<int64_t>(qx.dim());
    std::iota(permuted_dims->begin(), permuted_dims->end(), 0);
    permuted_dims->at(last_dim) = dim;
    permuted_dims->at(dim) = last_dim;
    qx_contig = qx.permute(permuted_dims.value()).contiguous();
    qx_contig_ptr = &qx_contig.value();
  }
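  // Allocate the output with the fixed QNNPACK output quantization parameters,
  // mirroring the (possibly permuted) contiguous input layout.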
  at::Tensor qy = at::_empty_affine_quantized(
      qx_contig_ptr->sizes(),
      at::device(kCPU)
          .dtype(qx.scalar_type())
          .memory_format(qx_contig_ptr->suggest_memory_format()),
      qnnpack_softmax_output_scale,
      qnnpack_softmax_output_zero_point,
      std::nullopt);
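  // QNNPACK views the input as a [batch_size, channels] matrix and runs
  // softmax over the innermost `channels` elements, so all remaining
  // dimensions are folded into the batch.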
  const size_t channels = qx.size(dim);
  const float input_scale = static_cast<float>(qx.q_scale());
  const uint32_t flags = 0;
  const size_t batch_size = qx.numel() / channels;
  const uint8_t* input =
      reinterpret_cast<const uint8_t*>(qx_contig_ptr->data_ptr<c10::quint8>());
  const size_t input_stride = channels;
  uint8_t* output = reinterpret_cast<uint8_t*>(qy.data_ptr<c10::quint8>());
  const size_t output_stride = channels;
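  // Create the QNNPACK softargmax operator for this channel count and input
  // scale; the unique_ptr below releases it when the function returns.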
  initQNNPACK();
  pytorch_qnnp_operator_t softargmax = nullptr;

  pytorch_qnnp_status status = pytorch_qnnp_create_softargmax_nc_q8(
      channels,
      input_scale,
      qnnpack_softmax_output_zero_point,
      qnnpack_softmax_output_scale,
      flags,
      &softargmax);
  TORCH_CHECK(
      status == pytorch_qnnp_status_success,
      "failed to create QNNPACK Softmax operator");
  TORCH_CHECK_NOTNULL(softargmax);

  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter> softmax_op(
      softargmax);
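  // Bind the batch size and the input/output buffers to the operator.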
  status = pytorch_qnnp_setup_softargmax_nc_q8(
      softargmax, batch_size, input, input_stride, output, output_stride);
  TORCH_CHECK(
      status == pytorch_qnnp_status_success,
      "failed to setup QNNPACK Softmax operator");
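  // Run the operator on the shared caffe2 thread pool.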
  pthreadpool_t threadpool = caffe2::pthreadpool_();
  status = pytorch_qnnp_run_operator(softargmax, threadpool);
  TORCH_CHECK(
      status == pytorch_qnnp_status_success,
      "failed to run QNNPACK Softmax operator");
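  // If the input was permuted to make `dim` innermost, apply the same
  // permutation to the result to restore the original dimension order
  // (swapping `dim` with the last dim is its own inverse).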
  return permuted_dims.has_value() ? qy.permute(permuted_dims.value()) : std::move(qy);
}

#endif // USE_PYTORCH_QNNPACK
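// Reference path: dequantize, run the floating-point softmax, and requantize
// with the requested output scale and zero point.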
Tensor qsoftmax_naive(
    const Tensor& qx,
    const int64_t dim,
    const double output_scale,
    const int64_t output_zero_point) {
  Tensor rx = at::dequantize(qx);
  Tensor ry = at::softmax(rx, dim);
  return at::quantize_per_tensor(
      ry, output_scale, output_zero_point, qx.scalar_type());
}
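// quantized::softmax entry point: use the QNNPACK kernel when the QNNPACK
// engine is selected and the input/output quantization is compatible,
// otherwise fall back to the naive dequantize-softmax-requantize path.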
Tensor qsoftmax(
    const Tensor& qx,
    const int64_t dim,
    const double output_scale,
    const int64_t output_zero_point) {
#ifdef USE_PYTORCH_QNNPACK
  if (at::globalContext().qEngine() == at::QEngine::QNNPACK &&
      is_qnnpack_compatible(qx, output_scale, output_zero_point)) {
    return qsoftmax_qnnpack(qx, dim);
  }
#endif // USE_PYTORCH_QNNPACK
  return qsoftmax_naive(qx, dim, output_scale, output_zero_point);
}
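// Register the kernel for the QuantizedCPU dispatch key.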
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::softmax"), TORCH_FN(qsoftmax));
}

} // namespace

} // namespace native
} // namespace at