1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/core/List.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/Parallel.h>
6 #include <ATen/native/Activation.h>
7 #include <ATen/native/TopKImpl.h>
8 #include <ATen/native/TensorIterator.h>
9 #include <ATen/native/UpSample.h>
10 #include <ATen/native/cpu/IndexKernelUtils.h>
11 #include <ATen/native/cpu/Loops.h>
12 #include <ATen/native/quantized/AffineQuantizer.h>
13 #include <ATen/native/quantized/FakeQuantAffine.h>
14 #include <ATen/native/quantized/IndexKernel.h>
15 #include <ATen/native/quantized/cpu/QuantizedOps.h>
16 #include <ATen/native/cpu/utils.h>
17 #include <c10/util/irange.h>
19
20 #ifndef AT_PER_OPERATOR_HEADERS
21 #include <ATen/Functions.h>
22 #else
23 #include <ATen/ops/_empty_affine_quantized.h>
24 #include <ATen/ops/empty.h>
25 #endif
26
27 #include <cmath>
28 #ifdef USE_FBGEMM
29 #include <fbgemm/QuantUtils.h>
30 #endif
31 #ifdef _OPENMP
32 #include <omp.h>
33 #endif
34 #if defined(__ARM_NEON__) || defined(__aarch64__)
35 #include <ATen/quantized/Quantizer.h>
36 #include <arm_neon.h>
37 #endif
38
39
40 // NOLINTBEGIN(*-c-arrays)
41 namespace at::native {
42 namespace {
43
44 void check_tensor_memory_format(const Tensor& ref, const Tensor& other) {
45 TORCH_CHECK(
46 ref.is_contiguous(ref.suggest_memory_format()),
47 "Quantized tensor should be contiguous");
48 TORCH_CHECK(
49 other.is_contiguous(ref.suggest_memory_format()),
50 "Float tensor should be contiguous "
51 "in same memory format as quantized tensor");
52 }
53
54 // ****************** HEY YOU! YES YOU! Read this! ********************
55 //
56 // Please read the README.md in this directory before editing this file
57
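// Concatenates channels-last (NHWC) quantized tensors along the channel
// dimension. Each input is dequantized with its own scale / zero_point and
// requantized to the output's quantization parameters; inputs whose
// parameters already match the output take a straight memcpy fast path
// (unless ReLU is fused).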
58 template <bool ReLUFused = false>
59 Tensor qcat_nhwc_kernel(
60 const MaterializedITensorListRef& qxs,
61 int64_t dim,
62 double scale,
63 int64_t zero_point) {
64 const at::Tensor& qx0 = qxs[0];
65 int64_t C_out = 0;
66 std::vector<int64_t> Cs_in;
67 // Prefix sum of input channels for fast indexing
68 std::vector<int64_t> Cs_sum;
69 std::vector<double> scales;
70 std::vector<int64_t> zero_pts;
71 std::vector<void*> data_ptrs;
72 std::vector<bool> is_fast_path;
73
74 for (const at::Tensor& qx : qxs) {
75 TORCH_CHECK(
76 qx.dim() == qx0.dim(),
77 "Tensors must have the same number of dimensions: got ",
78 qx.dim(),
79 " and ",
80 qx0.dim());
81 #define CHECK_DIM(d) \
82 TORCH_CHECK( \
83 qx.size(d) == qx0.size(d), \
84       "Sizes of tensors must match except in dimension 1. Got ", \
85 qx.size(d), \
86 " and ", \
87 qx0.size(d));
88 CHECK_DIM(0);
89 CHECK_DIM(2);
90 CHECK_DIM(3);
91 TORCH_CHECK(
92 qx.scalar_type() == qx0.scalar_type(),
93 "Expected object of scalar type ",
94 toString(qx0.scalar_type()),
95 " but got scalar type ",
96 toString(qx.scalar_type()));
97 Cs_in.push_back(qx.size(1));
98 Cs_sum.push_back(C_out);
99 C_out += qx.size(1);
100 scales.push_back(qx.q_scale());
101 zero_pts.push_back(qx.q_zero_point());
102 data_ptrs.push_back(qx.data_ptr());
103 is_fast_path.push_back(
104 qx.q_scale() == scale &&
105 qx.q_zero_point() == zero_point);
106 }
107
108 const int64_t N = qx0.size(0);
109 const int64_t H = qx0.size(2);
110 const int64_t W = qx0.size(3);
111 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
112 float inv_scale = 1.0 / scale;
113
114 auto output = at::_empty_affine_quantized(
115 {N, C_out, H, W},
116 qx0.options().memory_format(MemoryFormat::ChannelsLast),
117 scale,
118 zero_point,
119 std::nullopt);
120
121 // N, H, and W are explicitly captured here because there's a bug in GCC5
122 // and clang5 which causes an internal compiler error if they're not
123 AT_DISPATCH_QINT_TYPES(output.scalar_type(), "qcat_nhwc", [&, N, H, W]() {
124 using Vec = Vectorized<scalar_t>;
125 at::parallel_for(0, N * H * W, 0, [&](int64_t begin, int64_t end) {
126 for (const auto i : c10::irange(begin, end)) {
127 // loop over input tensors
128 for (const auto tidx : c10::irange(Cs_in.size())) {
129 scalar_t::underlying* optr =
130 reinterpret_cast<scalar_t::underlying*>(output.data_ptr()) +
131 i * C_out + Cs_sum[tidx];
132
133 auto curr_C = Cs_in[tidx];
134 float curr_scale = scales[tidx];
135 int64_t curr_zero_pt = zero_pts[tidx];
136
137 scalar_t::underlying* iptr =
138 reinterpret_cast<scalar_t::underlying*>(data_ptrs[tidx]) +
139 i * curr_C;
140
141 if (is_fast_path[tidx] && !ReLUFused) {
142 std::memcpy(optr, iptr, curr_C * sizeof(typename scalar_t::underlying));
143 continue;
144 }
145
146 constexpr auto VLEN = Vec::size();
147 int64_t c = 0;
148
149 // Vectorized loop
150 if (c + VLEN <= curr_C) {
151 auto curr_scale_vec = Vectorized<float>(curr_scale);
152 auto curr_zero_pt_vec = Vectorized<float>((float)curr_zero_pt);
153 auto scale_neg_zp_premul = curr_scale_vec * curr_zero_pt_vec.neg();
154 for (; c + VLEN <= curr_C; c += VLEN) {
155 auto inp_vec = Vec::loadu(iptr + c);
156 auto float_values = inp_vec.dequantize(
157 curr_scale_vec, curr_zero_pt_vec, scale_neg_zp_premul);
158 Vec::float_vec_return_type retvals;
159 for (int i = 0; i < Vec::float_num_vecs(); ++i) {
160 if constexpr (ReLUFused) {
161 retvals[i] =
162 vec::maximum(float_values[i], Vectorized<float>(0.0f));
163 } else {
164 retvals[i] = float_values[i];
165 }
166 }
167 auto quantized =
168 Vec::quantize(retvals, scale, zero_point, inv_scale);
169 quantized.store(optr + c);
170 }
171 }
172
173 // Vectorized loop for channel between 8 and 32 (avx2)
174 constexpr auto kVLEN = Vectorized<float>::size();
175 int64_t elem_size = curr_C - c;
176 if ((VLEN == 4 * kVLEN) && elem_size >= kVLEN) {
177 auto curr_scale_vec = Vectorized<float>(curr_scale);
178 auto curr_zero_pt_vec = Vectorized<float>((float)curr_zero_pt);
179 auto scale_neg_zp_premul = curr_scale_vec * curr_zero_pt_vec.neg();
180 int64_t vec_num = elem_size / kVLEN;
181 std::array<typename scalar_t::underlying, VLEN> buf_in{};
182 memcpy(buf_in.data(), iptr + c, vec_num * kVLEN);
183 auto inp_vec = Vec::loadu(buf_in.data());
184 auto float_values = inp_vec.dequantize(
185 curr_scale_vec, curr_zero_pt_vec, scale_neg_zp_premul);
186 Vec::float_vec_return_type retvals;
187 for (int i = 0; i < vec_num; ++i) {
188 if constexpr (ReLUFused) {
189 retvals[i] =
190 vec::maximum(float_values[i], Vectorized<float>(0.0f));
191 } else {
192 retvals[i] = float_values[i];
193 }
194 }
195 auto quantized =
196 Vec::quantize(retvals, scale, zero_point, inv_scale);
197 quantized.store(optr + c, vec_num * kVLEN);
198 c += vec_num * kVLEN;
199 }
200
201 // Scalar loop
202 for (; c < curr_C; ++c) {
203 auto float_val = at::native::dequantize_val(
204 curr_scale,
205 curr_zero_pt,
206 reinterpret_cast<scalar_t*>(iptr)[c]);
207 if constexpr (ReLUFused) {
208 float_val = std::max(0.0f, float_val);
209 }
210 optr[c] = at::native::quantize_val<scalar_t>(
211 scale, zero_point, float_val)
212 .val_;
213 } // for c
214 } // for tidx
215 } // for i
216 });
217 });
218
219 return output;
220 }
221
222 // horizontal sum over a range of uint8_t
223 int64_t hsum(const uint8_t* A, int len) {
224 int64_t row_sum = 0;
225 int i = 0;
226
227 #ifdef CPU_CAPABILITY_AVX2
228 __m256i sum_v = _mm256_setzero_si256();
229 __m256i one_epi16_v = _mm256_set1_epi16(1);
230 __m256i one_epi8_v = _mm256_set1_epi8(1);
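  // _mm256_maddubs_epi16(src, 1) sums adjacent unsigned bytes into int16
  // lanes; _mm256_madd_epi16(..., 1) then widens and sums adjacent int16
  // pairs into the int32 accumulator lanes.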
231 // vectorized
232 for (; i < len / 32 * 32; i += 32) {
233 __m256i src_v = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
234 sum_v = _mm256_add_epi32(
235 sum_v,
236 _mm256_madd_epi16(
237 // first argument is unsigned, second is signed
238 _mm256_maddubs_epi16(src_v, one_epi8_v),
239 one_epi16_v)
240 );
241 }
242
243 alignas(64) int32_t temp[8];
244 _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v);
245 for (const auto k : c10::irange(8)) {
246 row_sum += temp[k];
247 }
248 #elif defined(CPU_CAPABILITY_AVX512)
249 __m512i sum_v = _mm512_setzero_si512();
250 __m512i one_epi16_v = _mm512_set1_epi16(1);
251 __m512i one_epi8_v = _mm512_set1_epi8(1);
252 // vectorized
253 for (; i < len / 64 * 64; i += 64) {
254 __m512i src_v = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i));
255 sum_v = _mm512_add_epi32(
256 sum_v,
257 _mm512_madd_epi16(
258 // first argument is unsigned, second is signed
259 _mm512_maddubs_epi16(src_v, one_epi8_v),
260 one_epi16_v)
261 );
262 }
263
264 alignas(64) int32_t temp[16];
265 _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v);
266 for (const auto k : c10::irange(16)) {
267 row_sum += temp[k];
268 }
269 #endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
270
271 // scalar
272 for (; i < len; ++i) {
273 row_sum += A[i];
274 }
275
276 return row_sum;
277 }
278
279 // horizontal sum over a range of int8_t
280 int64_t hsum(const int8_t* A, int len) {
281 int64_t row_sum = 0;
282 int i = 0;
283
284 #ifdef CPU_CAPABILITY_AVX2
285 __m256i sum_v = _mm256_setzero_si256();
286 __m256i one_epi16_v = _mm256_set1_epi16(1);
287 __m256i one_epi8_v = _mm256_set1_epi8(1);
288 // vectorized
289 for (; i < len / 32 * 32; i += 32) {
290 __m256i src_v = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
291 sum_v = _mm256_add_epi32(
292 sum_v,
293 _mm256_madd_epi16(
294 // first argument is unsigned, second is signed
295 _mm256_maddubs_epi16(one_epi8_v, src_v),
296 one_epi16_v)
297 );
298 }
299
300 alignas(64) int32_t temp[8];
301 _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v);
302 for (const auto k : c10::irange(8)) {
303 row_sum += temp[k];
304 }
305 #elif defined(CPU_CAPABILITY_AVX512)
306 __m512i sum_v = _mm512_setzero_si512();
307 __m512i one_epi16_v = _mm512_set1_epi16(1);
308 __m512i one_epi8_v = _mm512_set1_epi8(1);
309 // vectorized
310 for (; i < len / 64 * 64; i += 64) {
311 __m512i src_v = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i));
312 sum_v = _mm512_add_epi32(
313 sum_v,
314 _mm512_madd_epi16(
315 // first argument is unsigned, second is signed
316 _mm512_maddubs_epi16(one_epi8_v, src_v),
317 one_epi16_v)
318 );
319 }
320
321 alignas(64) int32_t temp[16];
322 _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v);
323 for (const auto k : c10::irange(16)) {
324 row_sum += temp[k];
325 }
326 #endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
327
328 // scalar
329 for (; i < len; ++i) {
330 row_sum += A[i];
331 }
332
333 return row_sum;
334 }
335
336 // horizontal sum over a range of int32_t
337 int64_t hsum(const int32_t* A, int len) {
338 int64_t row_sum = 0;
339 int i = 0;
340
341 #ifdef CPU_CAPABILITY_AVX2
342 __m256i sum_epi64 = _mm256_setzero_si256();
343 // vectorized
344 for (; i < len / 8 * 8; i += 8) {
345 __m256i src_epi32 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
346 // widen
347 __m128i src_lo_epi32 = _mm256_castsi256_si128(src_epi32);
348 __m128i src_hi_epi32 = _mm256_extracti128_si256(src_epi32, 1);
349 __m256i src_lo_epi64 = _mm256_cvtepi32_epi64(src_lo_epi32);
350 __m256i src_hi_epi64 = _mm256_cvtepi32_epi64(src_hi_epi32);
351 // add
352 sum_epi64 = _mm256_add_epi64(sum_epi64, src_lo_epi64);
353 sum_epi64 = _mm256_add_epi64(sum_epi64, src_hi_epi64);
354 }
355
356 alignas(64) int64_t temp[4];
357 _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_epi64);
358 for (const auto k : c10::irange(4)) {
359 row_sum += temp[k];
360 }
361 #elif defined(CPU_CAPABILITY_AVX512)
362 __m512i sum_epi64 = _mm512_setzero_si512();
363 // vectorized
364 for (; i < len / 16 * 16; i += 16) {
365 __m512i src_epi32 = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i));
366 // widen
367 __m256i src_lo_epi32 = _mm512_castsi512_si256(src_epi32);
368 __m256i src_hi_epi32 = _mm512_extracti32x8_epi32(src_epi32, 1);
369 __m512i src_lo_epi64 = _mm512_cvtepi32_epi64(src_lo_epi32);
370 __m512i src_hi_epi64 = _mm512_cvtepi32_epi64(src_hi_epi32);
371 // add
372 sum_epi64 = _mm512_add_epi64(sum_epi64, src_lo_epi64);
373 sum_epi64 = _mm512_add_epi64(sum_epi64, src_hi_epi64);
374 }
375
376 alignas(64) int64_t temp[8];
377 _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_epi64);
378 for (const auto k : c10::irange(8)) {
379 row_sum += temp[k];
380 }
381 #endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
382
383 // scalar
384 for (; i < len; ++i) {
385 row_sum += A[i];
386 }
387
388 return row_sum;
389 }
390
391 // horizontal sum of squares over a range of uint8_t
392 int64_t hsum_sq(const uint8_t* A, int len) {
393 int64_t row_sum = 0;
394 int i = 0;
395
396 #ifdef CPU_CAPABILITY_AVX2
397 // vectorized
398 __m256i sum_v_epu32 = _mm256_setzero_si256();
399 alignas(64) int32_t temp[8];
400 int overflow_threshold = 262144; // 2147483647(max of int32)/(256*256)*8 = 262144
401 int loop = len / overflow_threshold + 1;
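  // Accumulate squares in 32-bit lanes, but spill the partial sums into the
  // 64-bit row_sum every overflow_threshold elements so the epu32
  // accumulator cannot overflow.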
402 for(int j=0; j<=loop; j++){
403 for (; ((i < overflow_threshold * j) && (i < len / 16 * 16)); i += 16) {
404 // (i15, ..., i0)
405 __m128i src_epu8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(A + i));
406 __m256i src_epu16 = _mm256_cvtepu8_epi16(src_epu8);
407 // (i15 ^ 2, ..., i0 ^ 2)
408 __m256i sq_epu16 = _mm256_mullo_epi16(src_epu16, src_epu16);
409 // (i7 ^ 2, ..., i0 ^ 2)
410 __m128i sq_lo_epu16 = _mm256_castsi256_si128(sq_epu16);
411 // (i15 ^ 2, ..., i8 ^ 2)
412 __m128i sq_hi_epu16 = _mm256_extractf128_si256(sq_epu16, 1);
413 // widen to epu32
414 __m256i sq_lo_epu32 = _mm256_cvtepu16_epi32(sq_lo_epu16);
415 __m256i sq_hi_epu32 = _mm256_cvtepu16_epi32(sq_hi_epu16);
416 // add to running sum
417 sum_v_epu32 = _mm256_add_epi32(sum_v_epu32, sq_lo_epu32);
418 sum_v_epu32 = _mm256_add_epi32(sum_v_epu32, sq_hi_epu32);
419 }
420 _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v_epu32);
421 for (const auto k : c10::irange(8)) {
422 row_sum += temp[k];
423 }
424 sum_v_epu32 = _mm256_setzero_si256();
425 }
426 #elif defined(CPU_CAPABILITY_AVX512)
427 __m512i sum_v_epu32 = _mm512_setzero_si512();
428 alignas(64) int32_t temp[16];
429   int overflow_threshold = 262144; // 2147483647(max of int32)/(256*256)*8 = 262144
430 int loop = len / overflow_threshold + 1;
431 for(int j=0; j<=loop; j++){
432 for (; ((i < overflow_threshold * j) && (i < len / 32 * 32)); i += 32) {
433 // (i31, ..., i0)
434 __m256i src_epu8 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
435 __m512i src_epu16 = _mm512_cvtepu8_epi16(src_epu8);
436 // (i31 ^ 2, ..., i0 ^ 2)
437 __m512i sq_epu16 = _mm512_mullo_epi16(src_epu16, src_epu16);
438 // (i15 ^ 2, ..., i0 ^ 2)
439 __m256i sq_lo_epu16 = _mm512_castsi512_si256(sq_epu16);
440 // (i31 ^ 2, ..., i16 ^ 2)
441 __m256i sq_hi_epu16 = _mm512_extracti32x8_epi32(sq_epu16, 1);
442 // widen to epu32
443 __m512i sq_lo_epu32 = _mm512_cvtepu16_epi32(sq_lo_epu16);
444 __m512i sq_hi_epu32 = _mm512_cvtepu16_epi32(sq_hi_epu16);
445 // add to running sum
446 sum_v_epu32 = _mm512_add_epi32(sum_v_epu32, sq_lo_epu32);
447 sum_v_epu32 = _mm512_add_epi32(sum_v_epu32, sq_hi_epu32);
448 }
449 _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v_epu32);
450 for (const auto k : c10::irange(16)) {
451 row_sum += temp[k];
452 }
453 sum_v_epu32 = _mm512_setzero_si512();
454 }
455 #endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
456
457 // scalar
458 for (; i < len; ++i) {
459 row_sum += A[i] * A[i];
460 }
461
462 return row_sum;
463 }
464
465 // horizontal sum of squares over a range of int8_t
466 int64_t hsum_sq(const int8_t* A, int len) {
467 int64_t row_sum = 0;
468 int i = 0;
469
470 #ifdef CPU_CAPABILITY_AVX2
471 // vectorized
472 __m256i sum_v_epi32 = _mm256_setzero_si256();
473 alignas(64) int32_t temp[8];
474
475 int overflow_threshold = 1048576; //2147483647/(128*128)*8 = 1048576
476 int loop = len / overflow_threshold + 1;
477
478 for(int j=0; j<=loop; j++){
479 for (; ((i < overflow_threshold * j) && (i < len / 16 * 16)); i += 16) {
480 // (i15, ..., i0)
481 __m128i src_epi8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(A + i));
482 __m256i src_epi16 = _mm256_cvtepi8_epi16(src_epi8);
483 // (i15 ^ 2, ..., i0 ^ 2)
484 __m256i sq_epi16 = _mm256_mullo_epi16(src_epi16, src_epi16);
485 // (i7 ^ 2, ..., i0 ^ 2)
486 __m128i sq_lo_epi16 = _mm256_castsi256_si128(sq_epi16);
487 // (i15 ^ 2, ..., i8 ^ 2)
488 __m128i sq_hi_epi16 = _mm256_extractf128_si256(sq_epi16, 1);
489 // widen to epi32
490 __m256i sq_lo_epi32 = _mm256_cvtepi16_epi32(sq_lo_epi16);
491 __m256i sq_hi_epi32 = _mm256_cvtepi16_epi32(sq_hi_epi16);
492 // add to running sum
493 sum_v_epi32 = _mm256_add_epi32(sum_v_epi32, sq_lo_epi32);
494 sum_v_epi32 = _mm256_add_epi32(sum_v_epi32, sq_hi_epi32);
495 }
496 _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v_epi32);
497
498 for (const auto k : c10::irange(8)) {
499 row_sum += temp[k];
500 }
501 sum_v_epi32 = _mm256_setzero_si256();
502 }
503 #elif defined(CPU_CAPABILITY_AVX512)
504 // vectorized
505 __m512i sum_v_epi32 = _mm512_setzero_si512();
506 alignas(64) int32_t temp[16];
507
508   int overflow_threshold = 1048576; //2147483647/(128*128)*8 = 1048576
509 int loop = len / overflow_threshold + 1;
510
511 for(int j=0; j<=loop; j++){
512 for (; ((i < overflow_threshold * j) && (i < len / 32 * 32)); i += 32) {
513 // (i31, ..., i0)
514 __m256i src_epi8 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
515 __m512i src_epi16 = _mm512_cvtepi8_epi16(src_epi8);
516 // (i31 ^ 2, ..., i0 ^ 2)
517 __m512i sq_epi16 = _mm512_mullo_epi16(src_epi16, src_epi16);
518 // (i15 ^ 2, ..., i0 ^ 2)
519 __m256i sq_lo_epi16 = _mm512_castsi512_si256(sq_epi16);
520 // (i31 ^ 2, ..., i16 ^ 2)
521 __m256i sq_hi_epi16 = _mm512_extracti32x8_epi32(sq_epi16, 1);
522 // widen to epi32
523 __m512i sq_lo_epi32 = _mm512_cvtepi16_epi32(sq_lo_epi16);
524 __m512i sq_hi_epi32 = _mm512_cvtepi16_epi32(sq_hi_epi16);
525 // add to running sum
526 sum_v_epi32 = _mm512_add_epi32(sum_v_epi32, sq_lo_epi32);
527 sum_v_epi32 = _mm512_add_epi32(sum_v_epi32, sq_hi_epi32);
528 }
529 _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v_epi32);
530
531 for (const auto k : c10::irange(16)) {
532 row_sum += temp[k];
533 }
534 sum_v_epi32 = _mm512_setzero_si512();
535 }
536 #endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
537
538 // scalar
539 for (; i < len; ++i) {
540 row_sum += A[i] * A[i];
541 }
542
543 return row_sum;
544 }
545
546 // horizontal sum of squares over a range of int32_t
547 // floats throughout are necessary to prevent overflow
548 float hsum_sq(const int32_t* A, int len) {
549 float row_sum = 0;
550 int i = 0;
551
552 #ifdef CPU_CAPABILITY_AVX2
553 __m256 sum_ps = _mm256_setzero_ps();
554 // vectorized
555 for (; i < len / 8 * 8; i += 8) {
556 __m256i src_epi32 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
557 __m256 src_ps = _mm256_cvtepi32_ps(src_epi32);
558 sum_ps = _mm256_add_ps(sum_ps, _mm256_mul_ps(src_ps, src_ps));
559 }
560
561 alignas(64) float temp[8];
562 _mm256_store_ps(temp, sum_ps);
563 for (const auto k : c10::irange(8)) {
564 row_sum += static_cast<float>(temp[k]);
565 }
566 #elif defined(CPU_CAPABILITY_AVX512)
567 __m512 sum_ps = _mm512_setzero_ps();
568 // vectorized
569 for (; i < len / 16 * 16; i += 16) {
570 __m512i src_epi32 = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i));
571 __m512 src_ps = _mm512_cvtepi32_ps(src_epi32);
572 sum_ps = _mm512_add_ps(sum_ps, _mm512_mul_ps(src_ps, src_ps));
573 }
574
575 alignas(64) float temp[16];
576 _mm512_store_ps(temp, sum_ps);
577 for (const auto k : c10::irange(16)) {
578 row_sum += static_cast<float>(temp[k]);
579 }
580 #endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
581
582 // scalar
583 for (; i < len; ++i) {
584 int64_t cur = static_cast<int64_t>(A[i]);
585 row_sum += (float)cur * (float)cur;
586 }
587
588 return row_sum;
589 }
590
591 void qrelu_kernel(const Tensor& qx, Tensor& qy) {
592 const auto zero_point = qx.q_zero_point();
593 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qrelu", [&]() {
594 qy = at::_empty_affine_quantized(
595 qx.sizes(),
596 at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
597 qx.q_scale(),
598 qx.q_zero_point(),
599 std::nullopt);
600 using Vec = Vectorized<scalar_t>;
601 auto zero_point_vec = Vec(scalar_t(zero_point));
602 auto iter = TensorIterator::unary_op(qy, qx);
603 cpu_kernel_vec(
604 iter,
605 [&](scalar_t value) -> scalar_t {
606 return scalar_t(std::max<underlying_t>(value.val_, zero_point));
607 },
608 [&](Vec value) -> Vec { return value.relu(zero_point_vec); });
609 });
610 }
611
612 static void leaky_qrelu_out_kernel(Tensor& out, const Tensor& qx,
613 const Scalar& negval_) {
614 int64_t i_zp = qx.q_zero_point();
615 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
616 float i_scale = qx.q_scale();
617
618 int64_t o_zp = out.q_zero_point();
619 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
620 float o_scale = out.q_scale();
621 float o_inv_scale = 1.0f / o_scale;
622
623 float negval = negval_.to<float>();
624
625 AT_DISPATCH_QINT_TYPES(out.scalar_type(), "leaky_qrelu", [&] {
626 using Vec = Vectorized<float>; // Naive implementation uses dequant/quant loop.
627 using qVec = Vectorized<scalar_t>;
628 Vec zero_vec = Vec(0.0f);
629 Vec one_vec = Vec(1.0f);
630
631 Vec i_scale_vec = Vec((float)i_scale);
632 Vec i_zp_vec = Vec((float)i_zp);
633 Vec i_scale_zp_neg_premul_vec = i_scale_vec * i_zp_vec.neg();
634
635 Vec negval_vec = Vec(negval);
636
637 auto iter = TensorIterator::unary_op(out, qx);
638
639 cpu_kernel_vec(
640 iter,
641 [&](scalar_t value_qx) -> scalar_t {
642 auto value_dx = at::native::dequantize_val(i_scale, i_zp, value_qx);
643 auto value_dy = value_dx > 0 ? value_dx : value_dx * negval;
644 return at::native::quantize_val<scalar_t>(o_scale, o_zp, value_dy);
645 },
646 [&](qVec qx_vec) -> qVec {
647 /* Vectorized implementation creates a multiplicand vector, which has
648 * "alpha" for all negative dx values and ones-vector for all
649 * positive values of dx. The multiplicand then is multiplied by the
650 * input.
651 */
652 auto dx_vec_vec = qx_vec.dequantize(i_scale_vec, i_zp_vec,
653 i_scale_zp_neg_premul_vec);
654 for (auto & dx_vec : dx_vec_vec) {
655 const auto multiplicand = Vec::blendv(negval_vec, one_vec,
656 dx_vec > zero_vec);
657 dx_vec *= multiplicand;
658 }
659 return qVec::quantize(dx_vec_vec, o_scale, o_zp, o_inv_scale);
660 });
661 });
662 }
663
664 static void qprelu_out_kernel(Tensor& out,
665 const Tensor& qx,
666 const Tensor& qw) {
667 int32_t i_zp = static_cast<int32_t>(qx.q_zero_point());
668 float i_scale = static_cast<float>(qx.q_scale());
669
670 int32_t w_zp = static_cast<int32_t>(qw.q_zero_point());
671 float w_scale = static_cast<float>(qw.q_scale());
672
673 int32_t o_zp = static_cast<int32_t>(out.q_zero_point());
674 float o_scale = static_cast<float>(out.q_scale());
675 float o_inv_scale = 1.0f / o_scale;
676
677 float multiplier = i_scale * w_scale * o_inv_scale;
678
679 int64_t input_ndim = qx.dim();
680 TORCH_CHECK(input_ndim > 0, "qprelu: zero-dim input tensor is not allowed.");
681
682 // This logic is present in at::prelu and repeated here, as this path can be
683 // hit via quantized::prelu, which is registered under quantized/cpu/qprelu.cpp
684 auto qw_nd = qw;
685 if (input_ndim != qw_nd.dim()) {
686 DimVector dim_w(input_ndim, 1);
687 if (input_ndim > 1) {
688 dim_w[1] = qw.numel();
689 }
690 // This will always be a view in CPU/CUDA, but some backends
691 // like MKLDNN do not support views
692 qw_nd = qw_nd.reshape(dim_w);
693 }
694
695 auto iter = TensorIteratorConfig()
696 .add_output(out)
697 .add_input(qx)
698 .add_input(qw_nd)
699 .build();
700
701 AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qprelu", [&] {
702 using qVec = Vectorized<scalar_t>;
703 qVec i_zp_vec = qVec(static_cast<scalar_t>(i_zp));
704 qVec w_zp_vec = qVec(static_cast<scalar_t>(w_zp));
705
706 // Quantized one as weight
707 auto qw_one = at::native::quantize_val<scalar_t>(w_scale, w_zp, 1.0f);
708 qVec vec_qw_one = qVec(qw_one);
709 auto vec_qw_one_sub_zp = vec_qw_one.widening_subtract(w_zp_vec)[0];
710 int32_t qw_one_sub_zp = qw_one.val_ - w_zp;
711
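    // PReLU(x) = max(x, 0) * 1 + min(x, 0) * w. With x = i_scale * (qx - i_zp)
    // and w = w_scale * (qw - w_zp), the integer result is
    //   (qx_pos - i_zp) * (q1 - w_zp) + (qx_neg - i_zp) * (qw - w_zp),
    // requantized with multiplier = i_scale * w_scale / o_scale.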
712 cpu_kernel_vec(
713 iter,
714 [=](scalar_t val_qx, scalar_t val_qw) -> scalar_t {
715 int32_t qx_pos = std::max(static_cast<int32_t>(val_qx.val_), i_zp);
716 int32_t qx_neg = std::min(static_cast<int32_t>(val_qx.val_), i_zp);
717 int32_t qx_pos_sub_zp = qx_pos - i_zp;
718 int32_t qx_neg_sub_zp = qx_neg - i_zp;
719 int32_t qw_sub_zp = val_qw.val_ - w_zp;
720 auto qy_sub_zp = qx_pos_sub_zp * qw_one_sub_zp + qx_neg_sub_zp * qw_sub_zp;
721 return at::native::requantize_from_int<scalar_t>(
722 multiplier, o_zp, qy_sub_zp);
723 },
724 [=](qVec vec_qx, qVec vec_qw) -> qVec {
725 auto vec_qx_pos = vec_qx.maximum(i_zp_vec);
726 auto vec_qx_neg = vec_qx.minimum(i_zp_vec);
727 qVec::int_vec_return_type qx_pos_sub_zp = vec_qx_pos.widening_subtract(i_zp_vec);
728 qVec::int_vec_return_type qx_neg_sub_zp = vec_qx_neg.widening_subtract(i_zp_vec);
729 qVec::int_vec_return_type qw_sub_zp = vec_qw.widening_subtract(w_zp_vec);
730 qVec::int_vec_return_type qy_sub_zp;
731 for (const auto i : c10::irange(qVec::int_num_vecs())) {
732 qy_sub_zp[i] = qx_pos_sub_zp[i] * vec_qw_one_sub_zp + qx_neg_sub_zp[i] * qw_sub_zp[i];
733 }
734 return qVec::requantize_from_int(qy_sub_zp, multiplier, o_zp);
735 });
736 });
737
738 }
739
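// Quantized GELU: dequantize, apply either the tanh approximation or the
// exact erf formulation depending on `approximate`, then requantize using
// the input's scale / zero_point as the output quantization parameters.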
740 void qgelu_kernel(const Tensor& qx, Tensor& qy, GeluType approximate) {
741 int64_t zero_point = qx.q_zero_point();
742 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
743 float scale = qx.q_scale();
744 auto scale_vec = Vectorized<float>(scale);
745 auto zero_point_vec = Vectorized<float>((float)zero_point);
746 auto scale_neg_zp_premul_vec = scale_vec * zero_point_vec.neg();
747 int64_t output_zero_point = zero_point;
748 float output_scale = scale;
749 float inv_output_scale = 1.0 / output_scale;
750 const auto kAlphaVec = Vectorized<float>(M_SQRT1_2);
751 const auto kBetaVec = Vectorized<float>(M_SQRT2 * M_2_SQRTPI * 0.5);
752 const auto kKappaVec = Vectorized<float>(0.044715);
753 const auto kOneVec = Vectorized<float>(1);
754 const auto kPointFiveVec = Vectorized<float>(0.5);
755
756 if (approximate == GeluType::Tanh) {
757 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
758 qy = at::_empty_affine_quantized(
759 qx.sizes(),
760 at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
761 output_scale,
762 output_zero_point,
763 std::nullopt);
764 auto iter = TensorIterator::unary_op(qy, qx);
765
766 using Vec = Vectorized<scalar_t>;
767 cpu_kernel_vec(
768 iter,
769 [&](scalar_t value_qx) -> scalar_t {
770 const auto value_dx =
771 at::native::dequantize_val(scale, zero_point, value_qx);
772
773 const auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
774 const auto kKappa = 0.044715;
775 const auto x_cube = value_dx * value_dx * value_dx;
776 const auto inner = kBeta * (value_dx + kKappa * x_cube);
777 const auto value_dy = 0.5 * value_dx * (1.0 + std::tanh(inner));
778
779 return at::native::quantize_val<scalar_t>(
780 output_scale, output_zero_point, value_dy);
781 },
782 [&](Vec value_qx) -> Vec {
783 auto value_dx = value_qx.dequantize(
784 scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
785 for (auto & value : value_dx) {
786 auto value_cube = value * value * value;
787 auto inner = kBetaVec * (value + kKappaVec * value_cube);
788 value = kPointFiveVec * value * (kOneVec + inner.tanh());
789 }
790 return Vec::quantize(
791 value_dx, output_scale, output_zero_point, inv_output_scale);
792 });
793 });
794 } else {
795 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
796 qy = at::_empty_affine_quantized(
797 qx.sizes(),
798 at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
799 output_scale,
800 output_zero_point,
801 std::nullopt);
802 auto iter = TensorIterator::unary_op(qy, qx);
803
804 using Vec = Vectorized<scalar_t>;
805 cpu_kernel_vec(
806 iter,
807 [&](scalar_t value_qx) -> scalar_t {
808 const auto value_dx =
809 at::native::dequantize_val(scale, zero_point, value_qx);
810 const auto value_dy =
811 value_dx * 0.5 * (1 + std::erf(value_dx * M_SQRT1_2));
812 return at::native::quantize_val<scalar_t>(
813 output_scale, output_zero_point, value_dy);
814 },
815 [&](Vec value_qx) -> Vec {
816 auto value_dx = value_qx.dequantize(
817 scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
818 for (auto & value : value_dx) {
819 value = value * kPointFiveVec * (kOneVec + (value * kAlphaVec).erf());
820 }
821 return Vec::quantize(
822 value_dx, output_scale, output_zero_point, inv_output_scale);
823 });
824 });
825 }
826 }
827
828
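// Quantized sigmoid: dequantize, compute 1 / (1 + exp(-x)), and requantize
// to the caller-provided output_scale / output_zero_point.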
829 void qsigmoid_kernel(
830 const Tensor& qx, Tensor& qy, double output_scale, int64_t output_zero_point ) {
831 int64_t zero_point = qx.q_zero_point();
832 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
833 float scale = qx.q_scale();
834 auto scale_vec = Vectorized<float>(scale);
835 auto zero_point_vec = Vectorized<float>((float)zero_point);
836
837 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qsigmoid", [&]() {
838 float inv_output_scale = 1.0 / output_scale;
839
840 qy = at::_empty_affine_quantized(
841 qx.sizes(),
842 at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
843 output_scale,
844 output_zero_point,
845 std::nullopt);
846 auto iter = TensorIterator::unary_op(qy, qx);
847
848 using Vec = Vectorized<scalar_t>;
849 cpu_kernel_vec(
850 iter,
851 [&](scalar_t value_qx) -> scalar_t {
852 const auto value_dx =
853 at::native::dequantize_val(scale, zero_point, value_qx);
854 const auto value_dy = 1.0f / (1.0 + std::exp((-value_dx)));
855 return at::native::quantize_val<scalar_t>(
856 output_scale, output_zero_point, value_dy);
857 },
858 [&](Vec value_qx) -> Vec {
859 auto value_dx = value_qx.dequantize(scale_vec, zero_point_vec);
860 for (auto & value : value_dx) {
861 value = value.neg();
862 value = value.exp();
863 value = Vectorized<float>(1.0f) + value;
864 value = value.reciprocal();
865 }
866 return Vec::quantize(
867 value_dx, output_scale, output_zero_point, inv_output_scale);
868 });
869 });
870 }
871
872 void qhardsigmoid_kernel(const Tensor& qx, Tensor& qy) {
873 int64_t zero_point = qx.q_zero_point();
874 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
875 float scale = qx.q_scale();
876 auto scale_vec = Vectorized<float>(scale);
877 auto zero_point_vec = Vectorized<float>((float)zero_point);
878 auto scale_neg_zp_premul_vec = scale_vec * zero_point_vec.neg();
879
880 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qhardsigmoid", [&]() {
881
882 // - Output scale is set to 1.0 / 2^(BIT_NUM)
883 float output_scale = 0.00390625; // 1.0 / 2^8
884 if (SCALAR_TYPE == at::kQInt32) {
885 output_scale = 2.3283064365386963e-10; // 1.0 / 2^32
886 }
887 float inv_output_scale = 1.0 / output_scale;
888
889 // The default zero-point is zero. As a one-off optimization for
890 // kQInt8, we set the zero-point to -128 to maximize precision in the
891 // [0, 1] output range. kQInt32 can be handled in a future PR if needed.
892 int64_t output_zero_point = 0;
893 if (SCALAR_TYPE == at::kQInt8) {
894 output_zero_point = -128;
895 }
896
897 qy = at::_empty_affine_quantized(
898 qx.sizes(),
899 at::device(kCPU).dtype(SCALAR_TYPE),
900 output_scale,
901 output_zero_point,
902 qx.suggest_memory_format());
903 auto iter = TensorIterator::unary_op(qy, qx);
904
905 using qVec = Vectorized<scalar_t>;
906 using fVec = Vectorized<float>;
907 fVec kZeroVec(0.0f);
908 fVec kThreeVec(3.0f);
909 fVec kSixVec(6.0f);
910
911 // Naive implementation: uses dequantize/execute/quantize routine
912 cpu_kernel_vec(
913 iter,
914 [&](scalar_t qx) -> scalar_t {
915 auto x = at::native::dequantize_val(scale, zero_point, qx);
916 const auto y = std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
917 return at::native::quantize_val<scalar_t>(
918 output_scale, output_zero_point, y);
919 },
920 [&](qVec value_qx) -> qVec {
921 auto value_dx = value_qx.dequantize(
922 scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
923 for (auto & value : value_dx) {
924 value =
925 vec::minimum(
926 vec::maximum(value + kThreeVec, kZeroVec),
927 kSixVec) /
928 kSixVec;
929 }
930 return qVec::quantize(
931 value_dx, output_scale, output_zero_point, inv_output_scale);
932 });
933 });
934 }
935
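// Quantized clamp (and the clamp_min / clamp_max variants below): the min/max
// bounds are quantized with the input's quantization parameters so the clamp
// can be applied directly to the underlying integer values.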
936 void qclamp_kernel(
937 const Tensor& qx,
938 const Scalar& min_scalar,
939 const Scalar& max_scalar,
940 Tensor& qy) {
941 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qclamp", [&]() {
942 qy = at::_empty_affine_quantized(
943 qx.sizes(),
944 at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
945 qx.q_scale(),
946 qx.q_zero_point(),
947 std::nullopt);
948 using Vec = Vectorized<scalar_t>;
949 auto iter = TensorIterator::unary_op(qy, qx);
950 auto min = min_scalar.to<float>();
951 auto max = max_scalar.to<float>();
952 scalar_t min_q = at::native::quantize_val<scalar_t>(
953 qx.q_scale(), qx.q_zero_point(), min);
954 scalar_t max_q = at::native::quantize_val<scalar_t>(
955 qx.q_scale(), qx.q_zero_point(), max);
956 auto min_vec = Vec(min_q);
957 auto max_vec = Vec(max_q);
958 cpu_kernel_vec(
959 iter,
960 [&](scalar_t value) -> scalar_t {
961 underlying_t min_clamped =
962 std::max<underlying_t>(value.val_, min_q.val_);
963 return scalar_t(std::min<underlying_t>(min_clamped, max_q.val_));
964 },
965 [&](Vec val) -> Vec {
966 auto min_clamped = val.maximum(min_vec);
967 return min_clamped.minimum(max_vec);
968 });
969 });
970 }
971
972 void qclamp_min_kernel(const Tensor& qx, const Scalar& min_scalar, Tensor& qy) {
973 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qclamp", [&]() {
974 qy = at::_empty_affine_quantized(
975 qx.sizes(),
976 at::device(kCPU)
977 .dtype(SCALAR_TYPE)
978 .memory_format(qx.suggest_memory_format()),
979 qx.q_scale(),
980 qx.q_zero_point(),
981 std::nullopt);
982 using Vec = Vectorized<scalar_t>;
983 auto iter = TensorIterator::unary_op(qy, qx);
984 auto min = min_scalar.to<float>();
985 scalar_t min_q = at::native::quantize_val<scalar_t>(
986 qx.q_scale(), qx.q_zero_point(), min);
987 auto min_vec = Vec(min_q);
988 cpu_kernel_vec(
989 iter,
990 [&](scalar_t value) -> scalar_t {
991 return scalar_t(std::max<underlying_t>(value.val_, min_q.val_));
992 },
993 [&](Vec val) -> Vec { return val.maximum(min_vec); });
994 });
995 }
996
997 void qclamp_max_kernel(const Tensor& qx, const Scalar& max_scalar, Tensor& qy) {
998 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qclamp", [&]() {
999 qy = at::_empty_affine_quantized(
1000 qx.sizes(),
1001 at::device(kCPU)
1002 .dtype(SCALAR_TYPE)
1003 .memory_format(qx.suggest_memory_format()),
1004 qx.q_scale(),
1005 qx.q_zero_point(),
1006 std::nullopt);
1007 using Vec = Vectorized<scalar_t>;
1008 auto iter = TensorIterator::unary_op(qy, qx);
1009 auto max = max_scalar.to<float>();
1010 scalar_t max_q = at::native::quantize_val<scalar_t>(
1011 qx.q_scale(), qx.q_zero_point(), max);
1012 auto max_vec = Vec(max_q);
1013 cpu_kernel_vec(
1014 iter,
1015 [&](scalar_t value) -> scalar_t {
1016 return scalar_t(std::min<underlying_t>(value.val_, max_q.val_));
1017 },
1018 [&](Vec val) -> Vec { return val.minimum(max_vec); });
1019 });
1020 }
1021
1022 void qthreshold_kernel(
1023 // TODO: For future tasks, since output quantization parameters are set equal to
1024 // the input ones, it might make sense to implement this completely in the
1025 // quantized domain.
1026 const Tensor& qx,
1027 const Scalar& threshold_scalar,
1028 const Scalar& value_scalar,
1029 Tensor& qy) {
1030
1031 // defines input and output scales and zero_points
1032 int64_t input_zero_point = qx.q_zero_point();
1033 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1034 float input_scale = qx.q_scale();
1035 int64_t output_zero_point = qy.q_zero_point();
1036 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1037 float output_scale = qy.q_scale();
1038 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1039 float inv_output_scale = 1.0 / output_scale;
1040
1041 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qthreshold", [&]() {
1042 qy = at::_empty_affine_quantized(
1043 qx.sizes(),
1044 at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
1045 qx.q_scale(),
1046 qx.q_zero_point(),
1047 std::nullopt);
1048
1049 // vectorized
1050 using Vec = Vectorized<float>;
1051 using qVec = Vectorized<scalar_t>;
1052 // defines the iterator
1053 auto iter = TensorIterator::unary_op(qy, qx);
1054 // defines the vectorized versions
1055 Vec input_scale_vec = Vec(input_scale);
1056 Vec input_zero_point_vec = Vec(input_zero_point);
1057 Vec input_scale_neg_zp_premul_vec = input_scale_vec * input_zero_point_vec.neg();
1058 // defines the floating-point versions of threshold and value
1059 float threshold_float = threshold_scalar.to<float>();
1060 float value_float = value_scalar.to<float>();
1061 Vec threshold_vec = Vec(threshold_float);
1062 Vec value_vec = Vec(value_float);
1063
1064 // Naive implementation: uses dequantize/execute/quantize routine
1065 cpu_kernel_vec(
1066 iter,
1067 [&](scalar_t value_qx) -> scalar_t {
1068 // dequantize
1069 const auto x = at::native::dequantize_val(input_scale, input_zero_point, value_qx);
1070 // Applies the Threshold operation
1071 const auto y = x > threshold_float ? x : value_float;
1072 // quantize
1073 return at::native::quantize_val<scalar_t>(output_scale, output_zero_point, y);
1074 },
1075 [&](qVec value_qx) -> qVec {
1076 // dequantize
1077 auto dx_vec = value_qx.dequantize(
1078 input_scale_vec, input_zero_point_vec, input_scale_neg_zp_premul_vec);
1079 for (auto & value : dx_vec) {
1080 // check if any elements are below threshold
1081 const auto cmp_to_threshold = value > threshold_vec;
1082 if (cmp_to_threshold.zero_mask()) {
1083 // blend
1084 value = Vec::blendv(value_vec, value, cmp_to_threshold);
1085 }
1086 }
1087 // quantize
1088 return qVec::quantize(dx_vec, output_scale, output_zero_point, inv_output_scale);
1089 });
1090 });
1091 }
1092
1093
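// Quantized hardswish: y = x * clamp(x + 3, 0, 6) / 6, computed by
// dequantizing, applying the formula, and requantizing to qy's
// quantization parameters.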
1094 void qhardswish_kernel(const Tensor& qx, Tensor& qy) {
1095 const auto i_scale = qx.q_scale();
1096 const auto i_zero_point = qx.q_zero_point();
1097
1098 const auto o_scale = qy.q_scale();
1099 const auto o_zero_point = qy.q_zero_point();
1100 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1101 const float o_inv_scale = 1.0 / o_scale;
1102
1103 using fVec = Vectorized<float>;
1104 fVec i_scale_vec(i_scale);
1105 fVec i_zero_point_vec(i_zero_point);
1106 fVec i_scale_neg_zp_premul_vec = i_scale_vec * i_zero_point_vec.neg();
1107 fVec zero_vec(0.0f);
1108 fVec three_vec(3.0f);
1109 fVec six_vec(6.0f);
1110
1111 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qhardswish", [&]() {
1112 using qVec = Vectorized<scalar_t>;
1113 auto iter = TensorIterator::unary_op(qy, qx);
1114 cpu_kernel_vec(
1115 iter,
1116 [&](scalar_t value) -> scalar_t {
1117 const auto x =
1118 at::native::dequantize_val(i_scale, i_zero_point, value);
1119 const auto y = x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
1120 return at::native::quantize_val<scalar_t>(o_scale, o_zero_point, y);
1121 },
1122 [&](qVec value) -> qVec {
1123 auto value_dx = value.dequantize(i_scale_vec, i_zero_point_vec,
1124 i_scale_neg_zp_premul_vec);
1125 for (auto & value : value_dx) {
1126 value = value * vec::minimum(
1127 vec::maximum(value + three_vec, zero_vec),
1128 six_vec
1129 ) / six_vec;
1130 }
1131 return qVec::quantize(value_dx, o_scale, o_zero_point, o_inv_scale);
1132 });
1133 });
1134 }
1135
1136
1137 void qtanh_kernel(const Tensor& qx, Tensor& qy) {
1138 int64_t zero_point = qx.q_zero_point();
1139 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1140 float scale = qx.q_scale();
1141 auto scale_vec = Vectorized<float>(scale);
1142 auto zero_point_vec = Vectorized<float>((float)zero_point);
1143 auto scale_neg_zp_premul_vec = scale_vec * zero_point_vec.neg();
1144
1145 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qtanh", [&]() {
1146 // Naive implementation: uses dequantize/execute/quantize routine
1147 // - Output scale is set to 2.0 / 2^(BIT_NUM)
1148 // - For signed types output zero point is set to 0
1149 // - For unsigned types output zero point is set to (qmax + qmin) / 2.0
1150     float output_scale = 0.0078125; // 2.0 / 256
1151 int64_t output_zero_point = 0;
1152 if (SCALAR_TYPE == at::kQInt32) {
1153 output_scale = 4.656612873077393e-10; // 2.0 / 2^32
1154 } else if (SCALAR_TYPE == at::kQUInt8) {
1155 output_zero_point = 128;
1156 }
1157 float inv_output_scale = 1.0 / output_scale;
1158
1159 qy = at::_empty_affine_quantized(
1160 qx.sizes(),
1161 at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
1162 output_scale,
1163 output_zero_point,
1164 std::nullopt);
1165 auto iter = TensorIterator::unary_op(qy, qx);
1166
1167 using Vec = Vectorized<scalar_t>;
1168 cpu_kernel_vec(
1169 iter,
1170 [&](scalar_t value_qx) -> scalar_t {
1171 const auto value_dx =
1172 at::native::dequantize_val(scale, zero_point, value_qx);
1173 return at::native::quantize_val<scalar_t>(
1174 output_scale, output_zero_point, std::tanh(value_dx));
1175 },
1176 [&](Vec value_qx) -> Vec {
1177 const auto value_dx = value_qx.dequantize(
1178 scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
1179 Vec::float_vec_return_type retvals;
1180 for (const auto idx : c10::irange(Vec::float_num_vecs())) {
1181 retvals[idx] = value_dx[idx].tanh();
1182 }
1183 return Vec::quantize(
1184 retvals, output_scale, output_zero_point, inv_output_scale);
1185 });
1186 });
1187 }
1188
1189 void qelu_kernel(
1190 const Tensor& qx,
1191 const Scalar& alpha,
1192 const Scalar& scale,
1193 const Scalar& input_scale,
1194 Tensor& qy) {
1195 // scale and input_scale arguments refer to a generalized ELU formula
1196 // if x >= 0, ELU(x) = x * scale
1197 // if x <= 0, ELU(x) = (exp(x * input_scale) - 1) * scale
1198 // in the normal ELU formula, both are equal to 1
1199 // they are NOT related to the quantization scale term
1200
1201 int64_t i_zp = qx.q_zero_point();
1202 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1203 float i_scale = qx.q_scale();
1204
1205 // In a future PR, we can improve on output scale and zero_point
1206 // selection.
1207 int64_t o_zp = qy.q_zero_point();
1208 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1209 float o_scale = qy.q_scale();
1210 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1211 float inv_o_scale = 1.0 / o_scale;
1212
1213 float alpha_float = alpha.to<float>();
1214 float scale_coef = scale.to<float>();
1215 float input_scale_coef = input_scale.to<float>();
1216
1217 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qelu_kernel", [&] {
1218
1219 auto iter = TensorIterator::unary_op(qy, qx);
1220
1221 // vectorized
1222 using Vec = Vectorized<float>;
1223 using qVec = Vectorized<scalar_t>;
1224
1225 Vec zero_vec = Vec(0.0f);
1226 Vec one_vec = Vec(1.0f);
1227 Vec alpha_vec = Vec(alpha_float);
1228 Vec scale_coef_vec = Vec(scale_coef);
1229 Vec input_scale_coef_vec = Vec(input_scale_coef);
1230 Vec i_scale_vec = Vec(i_scale);
1231 Vec i_zero_point_vec = Vec((float)i_zp);
1232 Vec i_scale_neg_zp_premul_vec = i_scale_vec * i_zero_point_vec.neg();
1233
1234 cpu_kernel_vec(
1235 iter,
1236 [&](scalar_t value_qx) -> scalar_t {
1237 // dequantize
1238 const auto x = at::native::dequantize_val(i_scale, i_zp, value_qx);
1239 // ELU
1240 const auto y = x >= 0
1241 ? x * scale_coef
1242 : ((std::exp(x * input_scale_coef) - 1) * alpha_float * scale_coef);
1243
1244 // quantize
1245 return at::native::quantize_val<scalar_t>(o_scale, o_zp, y);
1246 },
1247 [&](qVec value_qx) -> qVec {
1248 // dequantize
1249 auto dx_vec_vec = value_qx.dequantize(i_scale_vec, i_zero_point_vec,
1250 i_scale_neg_zp_premul_vec);
1251 for (auto & value : dx_vec_vec) {
1252 // quickly check if any elements are below zero
1253 const auto cmp_to_zero = value > zero_vec;
1254
1255 if (cmp_to_zero.zero_mask()) {
1256
1257 Vec dx_vec_copy_neg_elu = value * one_vec;
1258 // calculate the negative part of ELU on the copy
1259 dx_vec_copy_neg_elu = dx_vec_copy_neg_elu * input_scale_coef_vec;
1260 dx_vec_copy_neg_elu = dx_vec_copy_neg_elu.exp();
1261 dx_vec_copy_neg_elu = dx_vec_copy_neg_elu - one_vec;
1262 dx_vec_copy_neg_elu = dx_vec_copy_neg_elu * alpha_vec;
1263 // blend
1264 value = Vec::blendv(dx_vec_copy_neg_elu, value,
1265 value > zero_vec);
1266 }
1267
1268 value = value * scale_coef_vec;
1269 }
1270 // quantize
1271 return qVec::quantize(dx_vec_vec, o_scale, o_zp, inv_o_scale);
1272 }
1273 );
1274
1275 });
1276 }
1277
1278 // Note: out is assumed to be the same size as self and other.
1279 // Note: Addition is only supported when self and out are of the same dtype.
1280 // Note: other is already assumed to be in int32, i.e., it's
1281 // round(float/self_scale)
1282 template <bool ReLUFused = false>
1283 void qadd_scalar_kernel(Tensor& out, const Tensor& self, const Scalar& other) {
1284 int64_t zero_point = out.q_zero_point();
1285 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1286 float scale = out.q_scale();
1287 float inv_scale = 1.0f / scale;
1288 int64_t self_zero_point = self.q_zero_point();
1289 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1290 float self_scale = self.q_scale();
1291
1292 float multiplier = self_scale * inv_scale;
1293
1294 AT_DISPATCH_QINT_TYPES(self.scalar_type(), "qadd_scalar", [&]() {
1295 using Vec = Vectorized<scalar_t>;
1296 auto iter = TensorIterator::unary_op(out, self);
1297 auto other_val = other.to<int32_t>();
1298 auto other_vec = Vectorized<c10::qint32>(static_cast<c10::qint32>(other_val));
1299 cpu_kernel_vec(
1300 iter,
1301 [&](scalar_t a) -> scalar_t {
1302 int32_t a_sub_z = static_cast<int32_t>(a.val_) -
1303 static_cast<int32_t>(self_zero_point);
1304 int32_t c = a_sub_z + other_val;
1305 scalar_t res = at::native::requantize_from_int<scalar_t>(
1306 multiplier, zero_point, c);
1307 if constexpr (ReLUFused) {
1308 res.val_ = std::max<scalar_t::underlying>(res.val_, zero_point);
1309 }
1310 return res;
1311 },
1312 [&](Vec a) -> Vec {
1313 Vec::int_vec_return_type a_sub_z =
1314 a.widening_subtract(Vec(static_cast<scalar_t>(self_zero_point)));
1315 Vec::int_vec_return_type c;
1316 for (const auto i : c10::irange(Vec::int_num_vecs())) {
1317 c[i] = a_sub_z[i] + other_vec;
1318 }
1319 Vec rv = Vec::requantize_from_int(c, multiplier, zero_point);
1320 if constexpr (ReLUFused) {
1321 rv = rv.maximum(Vec(static_cast<scalar_t>(zero_point)));
1322 }
1323 return rv;
1324 });
1325 });
1326 }
1327 // Note: out is assumed to be the same size as self and other.
1328 // Note: Addition is only supported when self, other, out are of the same dtype.
1329 template <bool ReLUFused = false>
1330 void qadd_kernel(Tensor& out, const Tensor& self, const Tensor& other) {
1331 int64_t zero_point = out.q_zero_point();
1332 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1333 float scale = out.q_scale();
1334 float inv_scale = 1.0f / scale;
1335 int64_t self_zero_point = self.q_zero_point();
1336 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1337 float self_scale = self.q_scale();
1338 int64_t other_zero_point = other.q_zero_point();
1339 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1340 float other_scale = other.q_scale();
1341
1342 // Broadcast out the parameters here to amortize out that cost across
1343 // loop iterations.
1344 // TODO: we can optimize dequantization by doing a premultiplication
1345 // of the zero point by scale and doing FMA on scale*x_q - (scale*zero_point)
1346 auto self_zero_point_vec = Vectorized<float>((float)self_zero_point);
1347 auto self_scale_vec = Vectorized<float>(self_scale);
1348 auto other_zero_point_vec = Vectorized<float>((float)other_zero_point);
1349 auto other_scale_vec = Vectorized<float>(other_scale);
1350
1351 auto self_scale_neg_zp_premul_vec = self_scale_vec * self_zero_point_vec.neg();
1352 auto other_scale_zp_premul_vec = other_scale_vec * other_zero_point_vec.neg();
1353
1354 auto iter = TensorIterator::borrowing_binary_op(out, self, other);
1355
1356 AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qadd", [&]() {
1357 using Vec = Vectorized<scalar_t>;
1358 cpu_kernel_vec(
1359 iter,
1360 [&](scalar_t a, scalar_t b) -> scalar_t {
1361 const auto da =
1362 at::native::dequantize_val(self_scale, self_zero_point, a);
1363 const auto db =
1364 at::native::dequantize_val(other_scale, other_zero_point, b);
1365 float c = da + db;
1366 if (ReLUFused) {
1367 c = std::max<float>(c, 0.0);
1368 }
1369 return at::native::quantize_val<scalar_t>(scale, zero_point, c);
1370 },
1371 [&](Vec a, Vec b) -> Vec {
1372 const auto da = a.dequantize(
1373 self_scale_vec, self_zero_point_vec, self_scale_neg_zp_premul_vec);
1374 const auto db = b.dequantize(
1375 other_scale_vec, other_zero_point_vec, other_scale_zp_premul_vec);
1376 Vec::float_vec_return_type retvals;
1377 for (const auto i : c10::irange(Vec::float_num_vecs())) {
1378 auto c = da[i] + db[i];
1379 if constexpr (ReLUFused) {
1380 c = vec::maximum(c, Vectorized<float>(0.0f));
1381 }
1382 retvals[i] = c;
1383 }
1384 // TODO: fbgemm::Quantize doesn't support taking in the
1385 // pre-broadcasted parameters. We might be able to save some cycles by
1386 // enabling that in the API.
1387 // TODO: specialize fbgemm::Quantize for a single vector and make it
1388 // inlineable. This could help with interleaving as suggested by the
1389 // TensorIterator implementations
1390 auto rv = Vec::quantize(retvals, scale, zero_point, inv_scale);
1391 return rv;
1392 });
1393 });
1394 }
1395
1396 // Note: out is assumed to be the same size as self and other.
1397 // Note: Multiplication is only supported when self, other, out are of the same
1398 // dtype.
1399 template <bool ReLUFused = false>
1400 void qmul_kernel(Tensor& out, const Tensor& self, const Tensor& other) {
1401 int64_t zero_point = out.q_zero_point();
1402 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1403 float scale = out.q_scale();
1404 float inv_scale = 1.0f / scale;
1405 int64_t self_zero_point = self.q_zero_point();
1406 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1407 float self_scale = self.q_scale();
1408 int64_t other_zero_point = other.q_zero_point();
1409 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1410 float other_scale = other.q_scale();
1411
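  // a = self_scale * (a_q - self_zp) and b = other_scale * (b_q - other_zp), so
  // a * b / scale = (self_scale * other_scale / scale) * (a_q - self_zp) * (b_q - other_zp).
  // The integer product can therefore be requantized with a single multiplier.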
1412 float multiplier = self_scale * other_scale * inv_scale;
1413
1414 auto iter = TensorIterator::borrowing_binary_op(out, self, other);
1415
1416 AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qmul", [&]() {
1417 using Vec = Vectorized<scalar_t>;
1418 cpu_kernel_vec(
1419 iter,
1420 [&](scalar_t a, scalar_t b) -> scalar_t {
1421 int32_t a_sub_z = static_cast<int32_t>(a.val_) -
1422 static_cast<int32_t>(self_zero_point);
1423 int32_t b_sub_z = static_cast<int32_t>(b.val_) -
1424 static_cast<int32_t>(other_zero_point);
1425 int32_t c = a_sub_z * b_sub_z;
1426 scalar_t res = at::native::requantize_from_int<scalar_t>(
1427 multiplier, zero_point, c);
1428 if constexpr (ReLUFused) {
1429 res.val_ = std::max<scalar_t::underlying>(res.val_, zero_point);
1430 }
1431 return res;
1432 },
1433 [&](Vec a, Vec b) -> Vec {
1434 Vec::int_vec_return_type a_sub_zp =
1435 a.widening_subtract(Vec(static_cast<scalar_t>(self_zero_point)));
1436 Vec::int_vec_return_type b_sub_zp =
1437 b.widening_subtract(Vec(static_cast<scalar_t>(other_zero_point)));
1438 Vec::int_vec_return_type c;
1439 for (const auto i : c10::irange(Vec::int_num_vecs())) {
1440 c[i] = a_sub_zp[i] * b_sub_zp[i];
1441 }
1442 Vec rv = Vec::requantize_from_int(c, multiplier, zero_point);
1443 if constexpr (ReLUFused) {
1444 rv = rv.maximum(Vec(static_cast<scalar_t>(zero_point)));
1445 }
1446 return rv;
1447 });
1448 });
1449 }
1450
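// Max pool 2D over an NHWC quantized tensor: parallelized over the N * oH * oW
// output positions, with the channel-wise max reduced in a 4x-interleaved
// vector loop, a single-vector loop, and a scalar tail.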
1451 template <typename scalar_t, typename scalar_t_underlying>
1452 void _qmaxpool_2d_nhwc_kernel(
1453 const Tensor& qx,
1454 int64_t iC, // input/output channels
1455 int64_t iH,
1456 int64_t iW, // input sizes
1457 int64_t oH,
1458 int64_t oW, // output sizes
1459 int64_t kH,
1460 int64_t kW, // kernel size
1461 int64_t sH,
1462 int64_t sW, // strides
1463 int64_t pH,
1464 int64_t pW, // padding
1465 int64_t dH,
1466 int64_t dW, // dilation
1467 Tensor& qy) {
1468 scalar_t* idata = static_cast<scalar_t*>(qx.data_ptr());
1469 scalar_t* odata = static_cast<scalar_t*>(qy.data_ptr());
1470
1471 int64_t nBatch = qx.size(0);
1472 at::parallel_for(0, nBatch * oH * oW, 0, [&](int64_t begin, int64_t end) {
1473 int64_t b{0}, row{0}, col{0};
1474 data_index_init(begin, b, nBatch, row, oH, col, oW);
1475
1476 for (const auto i : c10::irange(begin, end)) {
1477 auto* i_p = reinterpret_cast<scalar_t_underlying*>(idata + b * iW * iH * iC);
1478 auto* o_p = reinterpret_cast<scalar_t_underlying*>(odata + i * iC);
1479
1480 // Loop over reduction block
1481 int64_t h_start = row * sH - pH;
1482 int64_t w_start = col * sW - pW;
1483 int64_t h_end = std::min(h_start + (kH - 1) * dH + 1, iH);
1484 int64_t w_end = std::min(w_start + (kW - 1) * dW + 1, iW);
1485 while (h_start < 0)
1486 h_start += dH;
1487 while (w_start < 0)
1488 w_start += dW;
1489
1490 int64_t c = 0;
1491
1492 // Interleaved vector loop 4x
1493 constexpr auto vec_width = Vectorized<scalar_t>::size();
1494 for (; c + 4 * vec_width <= iC; c += 4 * vec_width) {
1495 Vectorized<scalar_t> acc{
1496 scalar_t(std::numeric_limits<scalar_t_underlying>::lowest())};
1497 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1498 Vectorized<scalar_t> accs[4] = {acc, acc, acc, acc};
1499 int64_t tcntr = 0;
1500 int64_t x, y;
1501 for (y = h_start; y < h_end; y += dH) {
1502 for (x = w_start; x < w_end; x += dW) {
1503 for (const auto i : c10::irange(4)) {
1504 tcntr = y * iW + x;
1505 auto vals = Vectorized<scalar_t>::loadu(
1506 i_p + tcntr * iC + c + Vectorized<scalar_t>::size() * i);
1507 accs[i] = vec::maximum(accs[i], vals);
1508 }
1509 } // for x
1510 } // for y
1511 for (const auto i : c10::irange(4)) {
1512 accs[i].store(o_p + c + Vectorized<scalar_t>::size() * i);
1513 }
1514 } // for c
1515
1516 // Vector loop
1517 for (; c + vec_width <= iC; c += vec_width) {
1518 Vectorized<scalar_t> acc{
1519 scalar_t(std::numeric_limits<scalar_t_underlying>::lowest())};
1520 int64_t tcntr = 0;
1521 int64_t x, y;
1522 for (y = h_start; y < h_end; y += dH) {
1523 for (x = w_start; x < w_end; x += dW) {
1524 tcntr = y * iW + x;
1525 auto vals = Vectorized<scalar_t>::loadu(i_p + tcntr * iC + c);
1526 acc = vec::maximum(acc, vals);
1527 } // for x
1528 } // for y
1529 acc.store(o_p + c);
1530 } // for c
1531
1532 for (; c < iC; ++c) {
1533 auto max_val = std::numeric_limits<scalar_t_underlying>::lowest();
1534 int64_t tcntr = 0;
1535 int64_t x, y;
1536 for (y = h_start; y < h_end; y += dH) {
1537 for (x = w_start; x < w_end; x += dW) {
1538 tcntr = y * iW + x;
1539 auto val = *(i_p + tcntr * iC + c);
1540 max_val = std::max(max_val, val);
1541 } // for x
1542 } // for y
1543
1544 o_p[c] = max_val;
1545 } // for c
1546
1547 data_index_step(b, nBatch, row, oH, col, oW);
1548 }
1549 });
1550 }
1551
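// Dispatcher: plain uint8 inputs (ScalarType::Byte) instantiate the kernel with the
// same logical and underlying element type, while quantized dtypes load and store
// through scalar_t::underlying.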
1552 void qmaxpool_2d_nhwc_kernel(
1553 const Tensor& qx,
1554 int64_t iC, // input/output channels
1555 int64_t iH,
1556 int64_t iW, // input sizes
1557 int64_t oH,
1558 int64_t oW, // output sizes
1559 int64_t kH,
1560 int64_t kW, // kernel size
1561 int64_t sH,
1562 int64_t sW, // strides
1563 int64_t pH,
1564 int64_t pW, // padding
1565 int64_t dH,
1566 int64_t dW, // dilation
1567 Tensor& qy) {
1568 if (qx.scalar_type() == ScalarType::Byte) {
1569 AT_DISPATCH_INTEGRAL_TYPES(qx.scalar_type(), "max_pool2d_nhwc", [&]() {
1570 _qmaxpool_2d_nhwc_kernel<scalar_t, scalar_t>(qx, iC, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, qy);
1571 });
1572 } else {
1573 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "max_pool2d_nhwc", [&]() {
1574 _qmaxpool_2d_nhwc_kernel<scalar_t, scalar_t::underlying>(qx, iC, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, qy);
1575 });
1576 }
1577 }
1578
1579 void qmaxpool_3d_nthwc_kernel(
1580 const Tensor& qx,
1581 int64_t iC, // input/output channels
1582 int64_t iT,
1583 int64_t iH,
1584 int64_t iW, // input sizes
1585 int64_t oT,
1586 int64_t oH,
1587 int64_t oW, // output sizes
1588 int64_t kT,
1589 int64_t kH,
1590 int64_t kW, // kernel size
1591 int64_t sT,
1592 int64_t sH,
1593 int64_t sW, // strides
1594 int64_t pT,
1595 int64_t pH,
1596 int64_t pW, // padding
1597 int64_t dT,
1598 int64_t dH,
1599 int64_t dW, // dilation
1600 Tensor& qy) {
1601 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "max_pool3d_nthwc", [&]() {
1602 scalar_t* idata = static_cast<scalar_t*>(qx.data_ptr());
1603 scalar_t* odata = static_cast<scalar_t*>(qy.data_ptr());
1604 int64_t nBatch = qx.size(0);
1605 at::parallel_for(0, nBatch * oT * oH * oW, 0, [&](int64_t begin, int64_t end) {
1606 int64_t b{0}, time{0}, row{0}, col{0};
1607
1608 data_index_init(begin, b, nBatch, time, oT, row, oH, col, oW);
1609
1610 for (const auto i : c10::irange(begin, end)) {
1611 auto* i_p = reinterpret_cast<scalar_t::underlying*>(idata + b * iT * iW * iH * iC);
1612 auto* o_p = reinterpret_cast<scalar_t::underlying*>(odata + i * iC);
1613
1614 // Loop over reduction block
1615 int64_t t_start = time * sT - pT;
1616 int64_t h_start = row * sH - pH;
1617 int64_t w_start = col * sW - pW;
1618 int64_t t_end = std::min(t_start + (kT - 1) * dT + 1, iT);
1619 int64_t h_end = std::min(h_start + (kH - 1) * dH + 1, iH);
1620 int64_t w_end = std::min(w_start + (kW - 1) * dW + 1, iW);
1621 while (t_start < 0)
1622 t_start += dT;
1623 while (h_start < 0)
1624 h_start += dH;
1625 while (w_start < 0)
1626 w_start += dW;
1627
1628 int64_t c = 0;
1629 constexpr auto vec_width = Vectorized<scalar_t>::size();
1630 // Vector loop
1631 for (; c + vec_width <= iC; c += vec_width) {
1632 Vectorized<scalar_t> acc{
1633 scalar_t(std::numeric_limits<scalar_t::underlying>::lowest())};
1634 int64_t tcntr = 0;
1635 int64_t t, x, y;
1636 for (t = t_start; t < t_end; t += dT) {
1637 for (y = h_start; y < h_end; y += dH) {
1638 for (x = w_start; x < w_end; x += dW) {
1639 tcntr = t * iH * iW + y * iW + x;
1640 auto vals = Vectorized<scalar_t>::loadu(i_p + tcntr * iC + c);
1641 acc = vec::maximum(acc, vals);
1642 } // for x
1643 } // for y
1644 } // for t
1645 acc.store(o_p + c);
1646 } // for c
1647
1648 for (; c < iC; ++c) {
1649 auto max_val = std::numeric_limits<scalar_t::underlying>::lowest();
1650 int64_t tcntr = 0;
1651 int64_t t, x, y;
1652 for (t = t_start; t < t_end; t += dT) {
1653 for (y = h_start; y < h_end; y += dH) {
1654 for (x = w_start; x < w_end; x += dW) {
1655 tcntr = t * iH * iW + y * iW + x;
1656 auto val = *(i_p + tcntr * iC + c);
1657 max_val = std::max(max_val, val);
1658 } // for x
1659 } // for y
1660 } // for t
1661 o_p[c] = max_val;
1662 } // for c
1663 data_index_step(b, nBatch, time, oT, row, oH, col, oW);
1664 }
1665
1666 });
1667
1668 });
1669 }
1670
1671 template <typename T>
1672 void do_avg_pool_nhwc_on_AVX_n(
1673 const typename T::underlying* i_p,
1674 typename T::underlying* o_p,
1675 int& c_start,
1676 int input_zero_point_m_size,
1677 int output_zero_point,
1678 float multiplier,
1679 int dstart,
1680 int dend,
1681 int hstart,
1682 int hend,
1683 int wstart,
1684 int wend,
1685 int dsize,
1686 int hsize,
1687 int wsize,
1688 int csize) {
1689 #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
1690   // Buffer for the channel accumulators, used to move the channel loop to the
1691   // innermost position so that memory accesses to the input tensor data are
1692   // contiguous.
1693 #ifdef CPU_CAPABILITY_AVX2
1694 constexpr int cb_size = 16;
1695 #else
1696 constexpr int cb_size = 8;
1697 #endif
1698 constexpr int vec_width = Vectorized<T>::size() / 4;
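  // Vectorized<T> packs 8-bit quantized lanes, so size() / 4 is the number of
  // 32-bit lanes per register: 8 under AVX2, 16 under AVX512. The guarded
  // if-blocks below keep this path disabled for wider underlying types.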
1699 constexpr int cb_step = cb_size * vec_width;
1700 Vectorized<int32_t> acc_buffer[cb_size];
1701 Vectorized<float> acc_buffer_fp[cb_size];
1702
1703 #ifdef CPU_CAPABILITY_AVX2
1704 if (vec_width == 8) {
1705 #else
1706 if (vec_width == 16) {
1707 #endif
1708 for (int c = c_start; c < csize; c += cb_step) {
1709 int cend = std::min(cb_size, (csize - c) / vec_width);
1710 // initialize loop
1711 for (const auto ic : c10::irange(cend)) {
1712 acc_buffer[ic] = Vectorized<int32_t>(input_zero_point_m_size);
1713 }
1714 // compute loop
1715 for (const auto id : c10::irange(dstart, dend)) {
1716 for (const auto ih : c10::irange(hstart, hend)) {
1717 for (const auto iw : c10::irange(wstart, wend)) {
1718 const int i_idx =
1719 (id * wsize * hsize + ih * wsize + iw) *
1720 csize +
1721 c;
1722 for (const auto ic : c10::irange(cend)) {
1723 auto vals = vec::convert_to_int32<typename T::underlying>(
1724 i_p + i_idx + ic * vec_width);
1725 acc_buffer[ic] = acc_buffer[ic] + vals;
1726 }
1727 }
1728 }
1729 }
1730 // convert int32 accumulative to fp32
1731       // convert the int32 accumulators to fp32
1732
1733       // First quantize using AVX2 or AVX512, on 32 lanes at a time, then on 8,
1734       // finally falling back to scalar for the remainder.
1735 #ifdef CPU_CAPABILITY_AVX2
1736 QuantizeAvx2<typename T::underlying>(
1737 (float*)acc_buffer_fp,
1738 o_p + c,
1739 cend * vec_width,
1740 multiplier,
1741 output_zero_point);
1742 #else
1743 QuantizeAvx512<typename T::underlying>(
1744 (float*)acc_buffer_fp,
1745 o_p + c,
1746 cend * vec_width,
1747 multiplier,
1748 output_zero_point);
1749 #endif
1750 }
1751 c_start = csize / vec_width * vec_width;
1752 }
1753 #endif
1754 }
1755
1756 template <typename T>
1757 void do_avg_pool_on_AVX_n(
1758 typename T::underlying* i_p,
1759 typename T::underlying* o_p,
1760 int64_t& c,
1761 int64_t channel_size,
1762 int64_t channel_multiplier,
1763 int32_t input_zero_point_m_size,
1764 int32_t output_zero_point,
1765 float multiplier,
1766 int64_t dstart,
1767 int64_t dend,
1768 int64_t hstart,
1769 int64_t hend,
1770 int64_t wstart,
1771 int64_t wend,
1772 int64_t stride_C,
1773 int64_t stride_D,
1774 int64_t stride_H,
1775 int64_t stride_W) {
1776 #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
1777 constexpr int vec_width = Vectorized<T>::size() / 4;
1778 #ifdef CPU_CAPABILITY_AVX2
1779 if (vec_width == 8) {
1780 #else
1781 if (vec_width == 16) {
1782 #endif
1783 for (; c + vec_width <= channel_size; c += vec_width) {
1784 int64_t tcntr = 0;
1785
1786 Vectorized<int32_t> acc(input_zero_point_m_size);
1787 for (const auto id : c10::irange(dstart, dend)) {
1788 for (const auto ih : c10::irange(hstart, hend)) {
1789 for (const auto iw : c10::irange(wstart, wend)) {
1790 tcntr = id * stride_D + ih * stride_H + iw * stride_W;
1791 auto vals = vec::convert_to_int32<typename T::underlying>(
1792 i_p + tcntr * channel_multiplier + c * stride_C);
1793 acc = acc + vals;
1794 }
1795 }
1796 }
1797 int32_t acc_int[vec_width];
1798 float acc_fp[vec_width];
1799 acc.store(acc_int);
1800 vec::convert(acc_int, acc_fp, vec_width);
1801 at::native::quantize_vec<T>(
1802 1.0f / multiplier,
1803 output_zero_point,
1804 acc_fp,
1805 reinterpret_cast<T*>(o_p + c),
1806 vec_width);
1807 }
1808 }
1809 #endif
1810 }
1811
1812 template <typename T>
1813 void _qadaptive_avg_pool_kernel(
1814 const Tensor& qx,
1815 Tensor& qy,
1816 int64_t nBatch,
1817 int64_t sizeC,
1818 int64_t isizeD, // Set to 1 for 2d
1819 int64_t isizeH,
1820 int64_t isizeW,
1821 int64_t osizeD, // Set to 1 for 2d
1822 int64_t osizeH,
1823 int64_t osizeW,
1824 int64_t istrideB,
1825 int64_t istrideC,
1826 int64_t istrideD, // Set to 1 for 2d
1827 int64_t istrideH,
1828 int64_t istrideW) {
1829
1830 T* idata = static_cast<T*>(qx.data_ptr());
1831 T* odata = static_cast<T*>(qy.data_ptr());
1832
1833 const float input_scale = qx.q_scale();
1834 const float output_scale = qy.q_scale();
1835 const int input_zero_point = qx.q_zero_point();
1836 const int output_zero_point = qy.q_zero_point();
1837
1838 at::parallel_for(0, nBatch, 0, [&](int64_t batch_start, int64_t batch_end) {
1839 for (const auto b : c10::irange(batch_start, batch_end)) {
1840 auto* i_p = reinterpret_cast<typename T::underlying*>(
1841 idata + b * istrideB);
1842
1843 for (const auto od : c10::irange(osizeD)) {
1844 int istartD = (int)std::floor((float)(od * isizeD) / osizeD);
1845 int iendD = (int)std::ceil((float)((od + 1) * isizeD) / osizeD);
1846 int kD = iendD - istartD;
1847 for (const auto oh : c10::irange(osizeH)) {
1848 int istartH = (int)std::floor((float)(oh * isizeH) / osizeH);
1849 int iendH = (int)std::ceil((float)((oh + 1) * isizeH) / osizeH);
1850 int kH = iendH - istartH;
1851 for (const auto ow : c10::irange(osizeW)) {
1852 auto* o_p = reinterpret_cast<typename T::underlying*>(
1853 odata +
1854 b * osizeD * osizeH * osizeW * sizeC +
1855 od * osizeH * osizeW * sizeC +
1856 oh * osizeW * sizeC +
1857 ow * sizeC);
1858 int istartW = (int)std::floor((float)(ow * isizeW) / osizeW);
1859 int iendW = (int)std::ceil((float)((ow + 1) * isizeW) / osizeW);
1860 int kW = iendW - istartW;
1861 int size = kD * kH * kW;
1862 float multiplier = input_scale / output_scale / size;
1863 int input_zero_point_m_size = -input_zero_point * size;
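              // Requantization math: the accumulator is seeded with
              // -input_zero_point * size, so after adding the `size` raw quantized
              // values it holds sum(q - zp_in) = sum(x_float) / input_scale.
              // Multiplying by multiplier = input_scale / (output_scale * size)
              // yields mean(x_float) / output_scale, which quantize_val (called
              // with scale = 1 / multiplier) rounds and offsets by the output
              // zero point.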
1864 int64_t c = 0;
1865                 // For int8 or uint8 quantization, we implicitly use int32 as the
1866                 // accumulation type; otherwise it goes to the slow path.
1867                 // TODO: support 16bit, 32bit, etc.
1868 auto* internal_i_p = i_p +
1869 istartD * istrideD +
1870 istartH * istrideH +
1871 istartW * istrideW;
1872
1873                 // Note: If AVX is not available, `do_avg_pool_on_AVX_n` is a no-op.
1874                 //       In that case, the scalar loop below takes over.
1875                 // TODO: more vectorization with loop interleaving
1876 do_avg_pool_on_AVX_n<T>(
1877 internal_i_p,
1878 o_p,
1879 c,
1880 sizeC,
1881 1,
1882 input_zero_point_m_size,
1883 output_zero_point,
1884 multiplier,
1885 0,
1886 kD,
1887 0,
1888 kH,
1889 0,
1890 kW,
1891 istrideC,
1892 istrideD,
1893 istrideH,
1894 istrideW);
1895 // 1) The following loop handles the remaining channels
1896 // 2) It also handles the Non-AVX2 path
1897 for (; c < sizeC; ++c) {
1898 int32_t acc_int32 = input_zero_point_m_size;
1899 int64_t tcntr = 0;
1900 for (const auto id : c10::irange(kD)) {
1901 for (const auto ih : c10::irange(kH)) {
1902 for (const auto iw : c10::irange(kW)) {
1903 tcntr = id * istrideD +
1904 ih * istrideH +
1905 iw * istrideW;
1906 auto val = *(internal_i_p + tcntr + c * istrideC);
1907 acc_int32 += val;
1908 }
1909 }
1910 }
1911 // clamp
1912 o_p[c] = at::native::quantize_val<T>(1.0f / multiplier,
1913 output_zero_point,
1914 acc_int32).val_;
1915 } // c
1916           } // ow
1917         } // oh
1918 } // od
1919 }
1920 });
1921 }
1922
1923 void qadaptive_avg_pool2d_nhwc_kernel(
1924 const Tensor& qx,
1925 Tensor& qy,
1926 int64_t nBatch,
1927 int64_t sizeC,
1928 int64_t isizeH,
1929 int64_t isizeW,
1930 int64_t osizeH,
1931 int64_t osizeW,
1932 int64_t istrideB,
1933 int64_t istrideC,
1934 int64_t istrideH,
1935 int64_t istrideW) {
1936 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "adaptive_avg_pool2d_nhwc", [&]() {
1937 _qadaptive_avg_pool_kernel<scalar_t>(
1938 qx,
1939 qy,
1940 nBatch,
1941 sizeC,
1942 /*isizeD=*/1,
1943 isizeH,
1944 isizeW,
1945 /*osizeD=*/1,
1946 osizeH,
1947 osizeW,
1948 istrideB,
1949 istrideC,
1950 /*istrideD=*/1,
1951 istrideH,
1952 istrideW);
1953 }
1954 );
1955 }
1956
1957 void qadaptive_avg_pool3d_ndhwc_kernel(
1958 const Tensor& qx,
1959 Tensor& qy,
1960 int64_t nBatch,
1961 int64_t sizeC,
1962 int64_t isizeD,
1963 int64_t isizeH,
1964 int64_t isizeW,
1965 int64_t osizeD,
1966 int64_t osizeH,
1967 int64_t osizeW,
1968 int64_t istrideB,
1969 int64_t istrideC,
1970 int64_t istrideD,
1971 int64_t istrideH,
1972 int64_t istrideW) {
1973 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "adaptive_avg_pool3d_ndhwc", [&]() {
1974 _qadaptive_avg_pool_kernel<scalar_t>(
1975 qx,
1976 qy,
1977 nBatch,
1978 sizeC,
1979 isizeD,
1980 isizeH,
1981 isizeW,
1982 osizeD,
1983 osizeH,
1984 osizeW,
1985 istrideB,
1986 istrideC,
1987 istrideD,
1988 istrideH,
1989 istrideW);
1990 }
1991 );
1992 }
1993
1994 template <typename T>
1995 void _qavg_pool_nhwc_kernel(
1996 const Tensor& qx,
1997 Tensor& qy,
1998 int64_t nBatch,
1999 int64_t nInputPlane,
2000 int64_t inputWidth,
2001 int64_t inputHeight,
2002 int64_t inputDepth,
2003 int64_t outputWidth,
2004 int64_t outputHeight,
2005 int64_t outputDepth,
2006 int kW,
2007 int kH,
2008 int kD,
2009 int dW,
2010 int dH,
2011 int dD,
2012 int padW,
2013 int padH,
2014 int padD,
2015 bool count_include_pad,
2016 std::optional<int64_t> divisor_override) {
2017 T* idata = static_cast<T*>(qx.data_ptr());
2018 T* odata = static_cast<T*>(qy.data_ptr());
2019 int strideC = 1;
2020 int strideW = strideC * nInputPlane;
2021 int istrideH = strideW * inputWidth;
2022 int istrideD = istrideH * inputHeight;
2023 int istrideB = istrideD * inputDepth;
2024
2025 // lift these operations outside the loop to reduce access overheads
2026 float input_scale = qx.q_scale();
2027 float output_scale = qy.q_scale();
2028 int input_zero_point = qx.q_zero_point();
2029 int output_zero_point = qy.q_zero_point();
2030 int64_t divisor_override_factor =
2031 divisor_override.has_value() ? divisor_override.value() : 0;
2032
2033 at::parallel_for(0, nBatch * outputDepth * outputHeight * outputWidth, 0, [&](int64_t begin, int64_t end) {
2034 int64_t b{0}, od{0}, oh{0}, ow{0};
2035 data_index_init(begin, b, nBatch, od, outputDepth, oh, outputHeight, ow, outputWidth);
2036
2037 for (const auto i : c10::irange(begin, end)) {
2038 auto* i_p = reinterpret_cast<typename T::underlying*>(idata + b * istrideB);
2039 auto* o_p = reinterpret_cast<typename T::underlying*>(odata + i * strideW);
2040 int dstart = od * dD - padD;
2041 int hstart = oh * dH - padH;
2042 int wstart = ow * dW - padW;
2043
2044 int dend = std::min(dstart + kD, (int)inputDepth + padD);
2045 int hend = std::min(hstart + kH, (int)inputHeight + padH);
2046 int wend = std::min(wstart + kW, (int)inputWidth + padW);
2047 int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
2048
2049 dstart = std::max(dstart, 0);
2050 hstart = std::max(hstart, 0);
2051 wstart = std::max(wstart, 0);
2052 dend = std::min(dend, (int)inputDepth);
2053 hend = std::min(hend, (int)inputHeight);
2054 wend = std::min(wend, (int)inputWidth);
2055
2056 int size = (dend - dstart) * (hend - hstart) * (wend - wstart);
2057 int divide_size = count_include_pad ? pool_size : size;
2058 int divide_factor =
2059 divisor_override_factor ? divisor_override_factor : divide_size;
2060 float multiplier = input_scale / output_scale / divide_factor;
2061 int input_zero_point_m_size = -input_zero_point * size;
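      // pool_size counts the full kernel window clipped only by the padded bounds,
      // while `size` counts just the in-bounds elements; count_include_pad (or
      // divisor_override) selects the divisor. The zero-point offset always uses
      // `size` because only in-bounds elements are accumulated.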
2062
2063 int c_start = 0;
2064
2065       // For int8 quantization, we implicitly use int32 as the accumulation type;
2066       // otherwise it goes to the slow path.
2067       // TODO: support 16bit, 32bit, etc.
2068 do_avg_pool_nhwc_on_AVX_n<T>(
2069 i_p,
2070 o_p,
2071 c_start,
2072 input_zero_point_m_size,
2073 output_zero_point,
2074 multiplier,
2075 dstart,
2076 dend,
2077 hstart,
2078 hend,
2079 wstart,
2080 wend,
2081 inputDepth,
2082 inputHeight,
2083 inputWidth,
2084 nInputPlane);
2085
2086 // 1) The following loop handles the remaining channels
2087 // 2) It also handles the Non-AVX2 path
2088 for (const auto c: c10::irange(c_start, nInputPlane)) {
2089 int32_t acc_int32 = input_zero_point_m_size;
2090 for (const auto id : c10::irange(dstart, dend)) {
2091 for (const auto ih : c10::irange(hstart, hend)) {
2092 for (const auto iw : c10::irange(wstart, wend)) {
2093 auto val =
2094 *(i_p + id * istrideD + ih * istrideH + iw * strideW +
2095 c * strideC);
2096 acc_int32 += val;
2097 }
2098 }
2099 }
2100 double acc_fp = acc_int32 * 1.0;
2101 // clamp
2102 o_p[c] = at::native::quantize_val<T>(
2103 1.0f / multiplier, output_zero_point, acc_fp)
2104 .val_;
2105 } // c
2106
2107 data_index_step(b, nBatch, od, outputDepth, oh, outputHeight, ow, outputWidth);
2108 }
2109 });
2110 }
2111
2112 void qavg_pool2d_nhwc_kernel(
2113 const Tensor& qx,
2114 Tensor& qy,
2115 int64_t b,
2116 int64_t nInputPlane,
2117 int64_t inputWidth,
2118 int64_t inputHeight,
2119 int64_t outputWidth,
2120 int64_t outputHeight,
2121 int kW,
2122 int kH,
2123 int dW,
2124 int dH,
2125 int padW,
2126 int padH,
2127 bool count_include_pad,
2128 std::optional<int64_t> divisor_override) {
2129 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "avg_pool2d_nhwc", [&]() {
2130 _qavg_pool_nhwc_kernel<scalar_t>(
2131 qx,
2132 qy,
2133 b,
2134 nInputPlane,
2135 inputWidth,
2136 inputHeight,
2137 1,
2138 outputWidth,
2139 outputHeight,
2140 1,
2141 kW,
2142 kH,
2143 1,
2144 dW,
2145 dH,
2146 1,
2147 padW,
2148 padH,
2149 0,
2150 count_include_pad,
2151 divisor_override);
2152 });
2153 }
2154
2155 void qavg_pool3d_nhwc_kernel(
2156 const Tensor& qx,
2157 Tensor& qy,
2158 int64_t b,
2159 int64_t nInputPlane,
2160 int64_t inputWidth,
2161 int64_t inputHeight,
2162 int64_t inputDepth,
2163 int64_t outputWidth,
2164 int64_t outputHeight,
2165 int64_t outputDepth,
2166 int kW,
2167 int kH,
2168 int kD,
2169 int dW,
2170 int dH,
2171 int dD,
2172 int padW,
2173 int padH,
2174 int padD,
2175 bool count_include_pad,
2176 std::optional<int64_t> divisor_override) {
2177 AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "avg_pool3d_nhwc", [&]() {
2178 _qavg_pool_nhwc_kernel<scalar_t>(
2179 qx,
2180 qy,
2181 b,
2182 nInputPlane,
2183 inputWidth,
2184 inputHeight,
2185 inputDepth,
2186 outputWidth,
2187 outputHeight,
2188 outputDepth,
2189 kW,
2190 kH,
2191 kD,
2192 dW,
2193 dH,
2194 dD,
2195 padW,
2196 padH,
2197 padD,
2198 count_include_pad,
2199 divisor_override);
2200 });
2201 }
2202
2203 template <typename T>
2204 int64_t do_quantized_bilinear_on_AVX_n(
2205 const typename T::underlying*& pos1,
2206 typename T::underlying*& pos2,
2207 int64_t input_width,
2208 int64_t output_height,
2209 int64_t output_width,
2210 int64_t channels,
2211 int32_t output_zero_point,
2212 int32_t input_zero_point,
2213 float inverse_scale,
2214 const float h0lambda,
2215 const float h1lambda,
2216 const float w0lambda,
2217 const float w1lambda,
2218 const int64_t h1p,
2219 const int64_t w1p) {
2220 int64_t c = 0;
2221 #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
2222 constexpr auto vec_width = Vectorized<T>::size() / 4;
2223 #ifdef CPU_CAPABILITY_AVX2
2224 if (vec_width == 8) {
2225 #else
2226 if (vec_width == 16) {
2227 #endif
2228 for (; c + vec_width <= channels; c += vec_width) {
2229 Vectorized<float> pos1_fp_v[4];
2230 Vectorized<int32_t> pos1_int_v[4];
2231 pos1_int_v[0] = vec::convert_to_int32<typename T::underlying>(pos1);
2232 pos1_int_v[1] = vec::convert_to_int32<typename T::underlying>(
2233 pos1 + w1p * channels);
2234 pos1_int_v[2] = vec::convert_to_int32<typename T::underlying>(
2235 pos1 + h1p * input_width * channels);
2236 pos1_int_v[3] = vec::convert_to_int32<typename T::underlying>(
2237 pos1 + (h1p * input_width + w1p) * channels);
2238 for (const auto i : c10::irange(4)) {
2239 int32_t pos1_int[vec_width];
2240 float pos1_fp[vec_width];
2241 pos1_int_v[i].store(pos1_int);
2242 vec::convert(pos1_int, pos1_fp, vec_width);
2243 pos1_fp_v[i] = Vectorized<float>::loadu(pos1_fp, 8);
2244 }
2245 Vectorized<float> h0lambda_v(h0lambda);
2246 Vectorized<float> h1lambda_v(h1lambda);
2247 Vectorized<float> w0lambda_v(w0lambda);
2248 Vectorized<float> w1lambda_v(w1lambda);
2249 Vectorized<float> input_zero_point_v(input_zero_point);
2250 Vectorized<float> result =
2251 h0lambda_v * (w0lambda_v * pos1_fp_v[0] + w1lambda_v * pos1_fp_v[1]) +
2252 h1lambda_v * (w0lambda_v * pos1_fp_v[2] + w1lambda_v * pos1_fp_v[3]) -
2253 input_zero_point_v;
2254 float result_fp[vec_width];
2255 result.store(result_fp);
2256 at::native::quantize_vec<T>(
2257 inverse_scale,
2258 output_zero_point,
2259 result_fp,
2260 reinterpret_cast<T*>(pos2),
2261 vec_width);
2262 pos1 += vec_width;
2263 pos2 += vec_width;
2264 }
2265 }
2266 #endif
2267 return c;
2268 }
2269
2270 void qupsample_bilinear2d_nhwc_kernel(
2271 Tensor& output,
2272 const Tensor& input,
2273 int64_t input_height,
2274 int64_t input_width,
2275 int64_t output_height,
2276 int64_t output_width,
2277 int64_t nbatch,
2278 int64_t channels,
2279 bool align_corners,
2280 std::optional<double> scales_h,
2281 std::optional<double> scales_w) {
2282 AT_DISPATCH_QINT_TYPES(input.scalar_type(), "upsample_bilinear2d_nhwc", [&]() {
2283 auto* idata = static_cast<scalar_t*>(input.data_ptr());
2284 auto* odata = static_cast<scalar_t*>(output.data_ptr());
2285 float inverse_scale = output.q_scale() / input.q_scale();
2286 const auto rheight = area_pixel_compute_scale<float>(
2287 input_height, output_height, align_corners, scales_h);
2288 const auto rwidth = area_pixel_compute_scale<float>(
2289 input_width, output_width, align_corners, scales_w);
2290
2291 auto input_q_zero_point = input.q_zero_point();
2292 auto output_q_zero_point = output.q_zero_point();
2293 at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
2294 int64_t b{0}, h2{0}, w2{0};
2295 data_index_init(begin, b, nbatch, h2, output_height, w2, output_width);
2296
2297 for (C10_UNUSED const auto i : c10::irange(begin, end)) {
2298 auto* i_p = reinterpret_cast<typename scalar_t::underlying*>(
2299 idata + b * input_height * input_width * channels);
2300 auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(
2301 odata + b * output_height * output_width * channels);
2302
2303 const auto h1r = area_pixel_compute_source_index<float>(
2304 rheight, h2, align_corners, /*cubic=*/false);
2305
2306 const int64_t h1 = h1r;
2307 const int64_t h1p = (h1 < input_height - 1) ? 1 : 0;
2308 const float h1lambda = h1r - h1;
2309 const float h0lambda = static_cast<float>(1.) - h1lambda;
2310
2311 const auto w1r = area_pixel_compute_source_index<float>(
2312 rwidth, w2, align_corners, /*cubic=*/false);
2313 const int64_t w1 = w1r;
2314 const int64_t w1p = (w1 < input_width - 1) ? 1 : 0;
2315
2316 const float w1lambda = w1r - w1;
2317 const float w0lambda = static_cast<float>(1.) - w1lambda;
2318
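        // h1/w1 index the top-left source pixel and h1p/w1p step to the next
        // row/column (0 at the bottom/right border). The lambdas are bilinear
        // weights that sum to 1 per axis, so interpolating the raw quantized
        // values and subtracting the input zero point once gives
        // interp(x_float) / input_scale; quantizing that with
        // scale = output_scale / input_scale (the `inverse_scale` above) and the
        // output zero point produces the requantized result.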
2319 int64_t c = 0;
2320 // We use float32 to do the computation
2321 const typename scalar_t::underlying* pos1 =
2322 i_p + (h1 * input_width + w1) * channels;
2323 typename scalar_t::underlying* pos2 =
2324 o_p + (h2 * output_width + w2) * channels;
2325         // We have to isolate this function out because MSVC does not expand
2326         // the macro correctly.
2327 c = do_quantized_bilinear_on_AVX_n<scalar_t>(
2328 pos1,
2329 pos2,
2330 input_width,
2331 output_height,
2332 output_width,
2333 channels,
2334 output_q_zero_point,
2335 input_q_zero_point,
2336 inverse_scale,
2337 h0lambda,
2338 h1lambda,
2339 w0lambda,
2340 w1lambda,
2341 h1p,
2342 w1p);
2343 // 1) The following loop handles the remaining channels
2344 // 2) It also handles the Non-AVX2 path
2345 for (; c < channels; ++c) {
2346 float result = h0lambda *
2347 (w0lambda * pos1[0] + w1lambda * pos1[w1p * channels]) +
2348 h1lambda *
2349 (w0lambda * pos1[h1p * input_width * channels] +
2350 w1lambda * pos1[(h1p * input_width + w1p) * channels]);
2351 pos2[0] = at::native::quantize_val<scalar_t>(
2352 inverse_scale,
2353 output_q_zero_point,
2354 result - input_q_zero_point)
2355 .val_;
2356 pos1 += 1;
2357 pos2 += 1;
2358 } // c
2359
2360 data_index_step(b, nbatch, h2, output_height, w2, output_width);
2361 }
2362 });
2363 });
2364 }
2365
2366 void qtopk_kernel(Tensor& values,
2367 Tensor& indices,
2368 const Tensor& self,
2369 int64_t k,
2370 int64_t dim,
2371 bool largest,
2372 bool sorted) {
2373 auto sizes = self.sizes();
2374 auto iter = TensorIteratorConfig()
2375 .check_all_same_dtype(false)
2376 .resize_outputs(false)
2377 .declare_static_shape(sizes, /*squash_dims=*/dim)
2378 .add_output(values)
2379 .add_output(indices)
2380 .add_input(self)
2381 .build();
2382
2383 auto mode_values_stride = values.strides()[dim];
2384 auto mode_indices_stride = indices.strides()[dim];
2385 auto tmp_values_stride = self.strides()[dim];
2386   // If sizes is empty, the tensor is a scalar. This prevents accessing an empty array.
2387 auto dim_size = sizes.empty() ? 1 : sizes[dim];
2388
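  // topk can operate directly on the underlying integers: the affine dequantization
  // map has a positive scale, so it is monotonically increasing and the ordering of
  // the raw integer values matches the ordering of the dequantized values. The
  // static_assert below checks that the quantized wrapper and its underlying type
  // have the same size, so the data can be reinterpreted safely.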
2389 AT_DISPATCH_QINT_TYPES(self.scalar_type(), "qtopk_cpu", [&] {
2390 auto loop = [&](char** data, const int64_t* strides, int64_t n) {
2391 using underlying_t = typename scalar_t::underlying;
2392 static_assert(sizeof(scalar_t) == sizeof(underlying_t), "");
2393 return topk_impl_loop<underlying_t, underlying_t>(
2394 mode_values_stride, mode_indices_stride, tmp_values_stride,
2395 k, dim_size, largest, sorted, data, strides, n);
2396 };
2397
2398 int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, sizes[dim]);
2399 iter.for_each(loop, /*grain_size=*/grain_size);
2400 });
2401 }
2402
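// do_bn_compute applies the folded batch-norm transform to one block of quantized
// values. Dequantizing with a fake scale of 1.0 makes vals_dq hold
// (q - in_zero_point) as floats; after the per-channel fmadd the values are
// re-quantized, again with scale 1.0, so the stored result is
// round(alpha * (q - zp_in) + beta) + out_zero_point, matching the scalar tail in
// q_batch_norm_kernel.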
2403 template <typename T, bool ReluFused>
2404 inline void do_bn_compute(
2405 typename T::underlying* X_ptr,
2406 typename T::underlying* Y_ptr,
2407 Vectorized<float> & fake_scale,
2408 Vectorized<float> & in_zp_vec,
2409 Vectorized<float> & scale_neg_zp_premul,
2410 int64_t out_zero_point,
2411 Vectorized<T> & out_zero_point_v,
2412 float* alpha,
2413 float* beta,
2414 int64_t vec_num,
2415 int64_t kVLen
2416 ) {
2417 using Vec = Vectorized<T>;
2418 auto vals_q = Vec::loadu(X_ptr);
2419 // Fake scale of 1.0 here, should not affect performance (FMA in place of sub)
2420 auto vals_dq = vals_q.dequantize(fake_scale, in_zp_vec, scale_neg_zp_premul);
2421 for (const auto idx : c10::irange(vec_num)) {
2422 auto alpha_v = Vectorized<float>::loadu(alpha + idx * kVLen);
2423 auto beta_v = Vectorized<float>::loadu(beta + idx * kVLen);
2424 vals_dq[idx] = vec::fmadd(alpha_v, vals_dq[idx], beta_v);
2425 }
2426 // NOLINTNEXTLINE(bugprone-argument-comment)
2427 auto outputs_q = Vec::quantize(vals_dq, /*output_scale=*/1.0f, out_zero_point, /*inv_output_scale=*/1.0f);
2428 // Fake scale again
2429 if constexpr (ReluFused) {
2430 outputs_q = outputs_q.maximum(out_zero_point_v);
2431 }
2432 outputs_q.store(Y_ptr, vec_num * kVLen);
2433 }
2434
2435 template <bool ReluFused>
2436 void q_batch_norm_kernel(
2437 int64_t N,
2438 int64_t C,
2439 int64_t HxW,
2440 int64_t in_zero_point,
2441 int64_t out_zero_point,
2442 const Tensor& input,
2443 const Tensor& a,
2444 const Tensor& b,
2445 Tensor& output) {
2446
2447 AT_DISPATCH_QINT_TYPES(input.scalar_type(), "qbatch_norm", [&]() {
2448 float* alpha = a.data_ptr<float>();
2449 float* beta = b.data_ptr<float>();
2450 auto minimum = std::numeric_limits<scalar_t::underlying>::lowest();
2451 auto maximum = std::numeric_limits<scalar_t::underlying>::max();
2452 scalar_t::underlying* X =
2453 reinterpret_cast<scalar_t::underlying*>(input.data_ptr());
2454 scalar_t::underlying* Y = reinterpret_cast<scalar_t::underlying*>(output.data_ptr());
2455
2456 constexpr int kVLen = Vectorized<float>::size();
2457 const int64_t outer_size = N * HxW;
2458 using Vec = Vectorized<scalar_t>;
2459 // Hoisted variables
2460 auto in_zp_vec = Vectorized<float>(static_cast<float>(in_zero_point));
2461 auto fake_scale = Vectorized<float>(1.0f);
2462 auto scale_neg_zp_premul = fake_scale * in_zp_vec.neg();
2463 auto out_zero_point_v = Vec(scalar_t(out_zero_point));
2464 const auto lanes = static_cast<int64_t>(Vec::float_num_vecs() * kVLen);
2465 at::parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
2466 for (const auto i : c10::irange(begin, end)) {
2467 auto* X_ptr = reinterpret_cast<typename scalar_t::underlying*>(X + i * C);
2468 auto* Y_ptr = reinterpret_cast<typename scalar_t::underlying*>(Y + i * C);
2469 int64_t ch = 0;
2470
2471 for(; ch + lanes <= C; ch += lanes) {
2472 do_bn_compute<scalar_t, ReluFused>(
2473 X_ptr + ch,
2474 Y_ptr + ch,
2475 fake_scale,
2476 in_zp_vec,
2477 scale_neg_zp_premul,
2478 out_zero_point,
2479 out_zero_point_v,
2480 alpha + ch,
2481 beta + ch,
2482 Vec::float_num_vecs(),
2483 kVLen
2484 );
2485 }
2486
2487         // For a channel remainder between 8 and 32, still use the 32-lane width for
2488         // performance. Benchmarks show it is faster than doing 8 channels at a time.
2489 int64_t elem_size = C - ch;
2490 if ((lanes == 32) && elem_size >= kVLen) {
2491 int64_t vec_num = elem_size / kVLen;
2492 std::vector<typename scalar_t::underlying> buf_in(lanes);
2493 memcpy(buf_in.data(), X_ptr + ch, vec_num * kVLen); // 3 cycles
2494 do_bn_compute<scalar_t, ReluFused>(
2495 buf_in.data(),
2496 Y_ptr + ch,
2497 fake_scale,
2498 in_zp_vec,
2499 scale_neg_zp_premul,
2500 out_zero_point,
2501 out_zero_point_v,
2502 alpha + ch,
2503 beta + ch,
2504 vec_num,
2505 kVLen
2506 );
2507 ch += vec_num * kVLen;
2508 }
2509         // Scalar path for the remaining tail channels
2510 for (; ch < C; ++ch) {
2511 long quantized_down = out_zero_point +
2512 lrintf(alpha[ch] * (X_ptr[ch] - in_zero_point) +
2513 beta[ch]);
2514 if constexpr (ReluFused) { // static if
2515 quantized_down = std::max<long>(quantized_down, out_zero_point);
2516 }
2517 Y_ptr[ch] = std::min<long>(
2518 std::max<long>(quantized_down, minimum), maximum);
2519 }
2520 }
2521 });
2522 });
2523 }
2524
2525 void _fake_quantize_tensor_helper(
2526 Tensor& output,
2527 Tensor& mask,
2528 const Tensor& input,
2529 int fake_quant_on,
2530 float sc,
2531 int64_t z_point,
2532 int64_t quant_min,
2533 int64_t quant_max) {
2534
2535 float inv_scale = 1.0f / sc;
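  // Fake quantization: q = z_point + round(x / sc) and
  // output = (clamp(q, quant_min, quant_max) - z_point) * sc. The mask records
  // whether q fell inside [quant_min, quant_max] so that a backward pass can
  // cheaply zero the gradient of clamped elements. When fake_quant_on is 0 the
  // input is passed through unchanged (and the mask is all true).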
2536
2537 auto iter_combined = TensorIteratorConfig()
2538 .check_all_same_dtype(false)
2539 .add_output(output)
2540 .add_output(mask)
2541 .add_input(input)
2542 .build();
2543
2544 AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_type_handling", [&] {
2545 iter_combined.for_each([&](char** data, const int64_t* strides, int64_t n) {
2546 for (const auto i : c10::irange(n)) {
2547 scalar_t* output_val = (scalar_t*)(data[0] + i * strides[0]);
2548 bool* mask_val = (bool*)(data[1] + i * strides[1]);
2549 scalar_t* input_val = (scalar_t*)(data[2] + i * strides[2]);
2550
2551 const auto qval = static_cast<int64_t>(z_point + std::nearbyint(*input_val * inv_scale));
2552 if (fake_quant_on) {
2553 *output_val = (std::fmin(std::fmax(qval, quant_min), quant_max) - z_point) * sc;
2554 *mask_val = ((quant_min <= qval) && (qval <= quant_max));
2555 } else {
2556 *output_val = *input_val;
2557 *mask_val = 1;
2558 }
2559 }
2560 });
2561 });
2562 }
2563
2564 void fake_quantize_tensor_cachemask_kernel(
2565 Tensor& output,
2566 Tensor& mask,
2567 const Tensor& input,
2568 float sc,
2569 int64_t z_point,
2570 int64_t quant_min,
2571 int64_t quant_max) {
2572 _fake_quantize_tensor_helper(output, mask, input, 1, sc, z_point, quant_min, quant_max);
2573 }
2574
2575 void fake_quantize_tensor_cachemask_tensor_qparams_kernel(
2576 Tensor& output,
2577 Tensor& mask,
2578 const Tensor& input,
2579 const Tensor& sc,
2580 const Tensor& z_point,
2581 const Tensor& fake_quant_enabled,
2582 int64_t quant_min,
2583 int64_t quant_max) {
2584 _fake_quantize_tensor_helper(output, mask, input, fake_quant_enabled.item().toInt(), sc.item().toFloat(), z_point.item().toInt(), quant_min, quant_max);
2585 }
2586
2587 void fake_quantize_learnable_tensor_grad_kernel_cpu(
2588 TensorIterator& iter,
2589 float scale,
2590 float inv_scale,
2591 int64_t zero_point,
2592 int64_t quant_min,
2593 int64_t quant_max,
2594 float grad_factor) {
2595 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
2596 float dscale_small = quant_min - zero_point;
2597 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
2598 float dscale_big = quant_max - zero_point;
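  // Per-element gradient formulas implemented below (grad_factor scales only the
  // scale and zero-point gradients):
  //   dX = dY * 1[quant_min <= xq <= quant_max]   (straight-through estimator)
  //   out of range: dscale = dY * (quant_min - zero_point) or
  //                 dY * (quant_max - zero_point), and dzero_point = -dY * scale
  //   in range:     dscale = dY * (x_fq - x) / scale, dzero_point = 0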
2599 iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
2600 /* When a for_each call is made on a TensorIterator with multiple inputs and outputs,
2601 the order they are accessed follows the order they are built within the iterator.
2602 For example, if an iterator is built in the following order:
2603 auto iter = TensorIteratorConfig().
2604 .add_output(firstOutput)
2605 .add_output(secondOutput)
2606 .add_input(firstInput)
2607 .add_input(secondInput)
2608 .build()
2609    data will contain 4 pointers, one per operand, in the following order:
2610    firstOutput, secondOutput, firstInput, secondInput.
2611 Proper pointer referencing and dereferencing, along with the usage of strides
2612 (to move onto different elements), can allow accessing of the input and assignment
2613 to the right output.
2614 */
2615 for (const auto i : c10::irange(n)) {
2616 float* dXOutput = (float*)(data[0] + i * strides[0]);
2617 float* dScaleOutput = (float*)(data[1] + i * strides[1]);
2618 float* dZeroPointOutput = (float*)(data[2] + i * strides[2]);
2619 float* XInput = (float*)(data[3] + i * strides[3]);
2620 float* dYInput = (float*)(data[4] + i * strides[4]);
2621 // Calculate gradients for X.
2622 int64_t xqi = std::nearbyint(zero_point + (*XInput) * inv_scale);
2623 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
2624 *dXOutput = (*dYInput) * (xqi >= quant_min && xqi <= quant_max);
2625 // Calculate gradients for scale and zero point.
2626 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
2627 float xfqi = static_cast<float>((std::max(std::min(xqi, quant_max), quant_min) - zero_point) * scale);
2628 // Calculate gradients according to the gradient of the clamp function.
2629 if (xqi < quant_min || xqi > quant_max) {
2630 *dZeroPointOutput = (*dYInput) * (-1) * scale * grad_factor;
2631 *dScaleOutput = ((xqi < quant_min) ? ((*dYInput) * dscale_small) : ((*dYInput) * dscale_big)) * grad_factor;
2632 } else {
2633 *dZeroPointOutput = 0;
2634 *dScaleOutput = (*dYInput) * (xfqi - (*XInput)) * inv_scale * grad_factor;
2635 }
2636 }
2637 });
2638 }
2639
2640 template <typename SelfType>
2641 void _fake_quant_per_channel_cachemask_cpu_helper(
2642 TensorIterator& iter,
2643 TensorIterator& iter_mask,
2644 const int64_t quant_min,
2645 const int64_t quant_max) {
2646
2647 const auto& zero_point_dtype = iter.input_dtype(2);
2648
2649 if(at::isFloatingType(zero_point_dtype)){
2650 // When zero_point is float, quantize mirroring affine quantizer equation
2651 // Xq = Round(Xf * inv_scale + zero_point)
2652 // where zero_point is in float.
2653 AT_DISPATCH_FLOATING_TYPES_AND_HALF(zero_point_dtype, "fake_quantize_channel_cachemask_cpu_zero_point_handling", [&] {
2654 // write mask
2655 cpu_kernel(iter_mask, [=](SelfType self, float scale, scalar_t zero_point) -> bool {
2656 float inv_scale = 1.0f / scale;
2657 const auto qval = std::lrintf(zero_point + (self * inv_scale));
2658 return ((quant_min <= qval) && (qval <= quant_max));
2659 });
2660
2661 // write fake_quant
2662 cpu_kernel(iter, [=](SelfType self, float scale, scalar_t zero_point) -> SelfType {
2663 float inv_scale = 1.0f / scale;
2664 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
2665 return (std::fmin(
2666 std::fmax(
2667 std::lrintf(zero_point + self * inv_scale),
2668 quant_min),
2669 quant_max) -
2670 zero_point) *
2671 scale;
2672 });
2673 });
2674
2675 } else {
2676 // write mask
2677 cpu_kernel(iter_mask, [=](SelfType self, float scale, int32_t zero_point) -> bool {
2678 float inv_scale = 1.0f / scale;
2679 const auto qval = static_cast<int64_t>(zero_point + std::nearbyint(self * inv_scale));
2680 return ((quant_min <= qval) && (qval <= quant_max));
2681 });
2682
2683 // write fake_quant
2684 cpu_kernel(iter, [=](SelfType self, float scale, int32_t zero_point) -> SelfType {
2685 float inv_scale = 1.0f / scale;
2686 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
2687 return (std::fmin(
2688 std::fmax(
2689 static_cast<int64_t>(
2690 zero_point + std::nearbyint(self * inv_scale)),
2691 quant_min),
2692 quant_max) -
2693 zero_point) *
2694 scale;
2695 });
2696 }
2697
2698 }
2699
2700
2701 void fake_quant_per_channel_cachemask_cpu(
2702 TensorIterator& iter,
2703 TensorIterator& iter_mask,
2704 int64_t quant_min,
2705 int64_t quant_max) {
2706 // TODO(future, optional): read once, write twice. Not done at the moment
2707 // for simplicity, as we do not expect this to be a bottleneck.
2708
2709 AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "fake_quantize_channel_cachemask_cpu_type_handling", [&] {
2710 _fake_quant_per_channel_cachemask_cpu_helper<scalar_t>(iter, iter_mask, quant_min, quant_max);
2711 });
2712 }
2713
2714
2715 void fake_quantize_learnable_channel_grad_kernel_cpu(
2716 TensorIterator& iter,
2717 int64_t quant_min,
2718 int64_t quant_max,
2719 float grad_factor) {
2720 iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
2721 /* To see how the input and outputs are referenced and assigned,
2722 please see the implementation of
2723 fake_quantize_learnable_tensor_grad_kernel_cpu.
2724 */
2725 for (const auto i : c10::irange(n)) {
2726 float* dx_output = (float*)(data[0] + i * strides[0]);
2727 float* dscale_output = (float*)(data[1] + i * strides[1]);
2728 float* dzero_point_output = (float*)(data[2] + i * strides[2]);
2729 float* x_input = (float*)(data[3] + i * strides[3]);
2730 float* dy_input = (float*)(data[4] + i * strides[4]);
2731 float* scale_input = (float*)(data[5] + i * strides[5]);
2732 float* zero_point_input = (float*)(data[6] + i * strides[6]);
2733
2734 float inv_scale = 1.0f / (*scale_input);
2735 float dscale_small = quant_min - (*zero_point_input);
2736 float dscale_big = quant_max - (*zero_point_input);
2737 // Calculate gradients for X.
2738 int64_t xqi = std::nearbyint((*zero_point_input) + (*x_input) * inv_scale);
2739 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
2740 *dx_output = (*dy_input) * (xqi >= quant_min && xqi <= quant_max);
2741 // Calculate gradients for scale and zero point.
2742 float xfqi = static_cast<float>((std::max(std::min(xqi, quant_max), quant_min) - (*zero_point_input)) * (*scale_input));
2743 if (xqi < quant_min || xqi > quant_max) {
2744 *dzero_point_output = (*dy_input) * (-1) * (*scale_input) * grad_factor;
2745 *dscale_output = ((xqi < quant_min) ? ((*dy_input) * dscale_small) : ((*dy_input) * dscale_big)) * grad_factor;
2746 } else {
2747 *dzero_point_output = 0;
2748 *dscale_output = (*dy_input) * (xfqi - (*x_input)) * inv_scale * grad_factor;
2749 }
2750 }
2751 });
2752 }
2753
2754 // Assumes X is composed of M groups of N elements. Normalizes each of the
2755 // groups and optionally applies affine scaling. Useful for LayerNorm,
2756 // GroupNorm, InstanceNorm.
2757 void quantized_normalize_kernel(
2758 const Tensor& X, // input tensor
2759 const Tensor& gamma, // weight (optional)
2760 const Tensor& beta, // bias (optional)
2761 bool affine_per_channel, // scaling applied elementwise if false, per channel if true
2762 int num_channels, // only used if affine_per_channel is set
2763 int num_groups, // only used if affine_per_channel is set
2764 int64_t M, // number of groups
2765 int64_t N, // number of elements in each group
2766 double eps,
2767 Tensor* Y) {
2768 AT_DISPATCH_QINT_TYPES(X.scalar_type(), "quantized_layer_norm_kernel_impl_cpu", [&]() {
2769 using qVec = vec::Vectorized<scalar_t>;
2770 using fVec = vec::Vectorized<float>;
2771
2772 TORCH_INTERNAL_ASSERT(X.numel() == M * N, "Unexpected num elements in X");
2773 TORCH_INTERNAL_ASSERT(
2774 !gamma.defined() ||
2775 (!affine_per_channel && gamma.numel() == N) ||
2776 (affine_per_channel && gamma.numel() == num_channels),
2777 "Unexpected size of gamma");
2778 TORCH_INTERNAL_ASSERT(
2779 !beta.defined() ||
2780 (!affine_per_channel && beta.numel() == N) ||
2781 (affine_per_channel && beta.numel() == num_channels),
2782 "Unexpected size of beta");
2783
2784 scalar_t* X_data = X.data_ptr<scalar_t>();
2785 const float* gamma_data = gamma.defined() ? gamma.const_data_ptr<float>() : nullptr;
2786 const float* beta_data = beta.defined() ? beta.const_data_ptr<float>() : nullptr;
2787 scalar_t* Y_data = Y->data_ptr<scalar_t>();
2788 const bool gamma_null = gamma_data == nullptr;
2789 const bool beta_null = beta_data == nullptr;
2790 int64_t x_zp = X.q_zero_point();
2791 float x_scale = X.q_scale();
2792 fVec x_zp_vec((float)x_zp);
2793 fVec one_vec(1.0f);
2794 fVec zero_vec(0.0f);
2795 float x_fake_scale = 1.0f;
2796 fVec x_fake_scale_vec(x_fake_scale);
2797 fVec x_fake_scale_zp_neg_premul_vec = x_fake_scale_vec * x_zp_vec.neg();
2798 int64_t y_zp = Y->q_zero_point();
2799 float y_scale = Y->q_scale();
2800 float y_inv_scale = 1.0f / y_scale;
2801
2802 constexpr int kFloatVLen = fVec::size();
2803 int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs();
2804 int64_t kNumIntVecInLayer = N / kIntVLen;
2805 int64_t kNonVecRemInLayer = N % kIntVLen;
2806 int channels_per_group = num_channels / num_groups;
2807 int64_t NPerChannel = N / channels_per_group;
2808 int64_t kNumIntVecInChannel = NPerChannel / kIntVLen;
2809 int64_t kNonVecRemInChannel = NPerChannel % kIntVLen;
2810
2811 at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
2812 for (const auto i : c10::irange(start, end)) {
2813
2814 scalar_t* X_ptr = X_data + i * N;
2815 scalar_t* Y_ptr = Y_data + i * N;
2816
2817 // First pass: calculate mean and variance.
2818
2819 scalar_t::underlying* X_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(X_ptr);
2820 auto l_sum_shifted = hsum(X_ptr_underlying, N);
2821 auto l_sum_sq_shifted = hsum_sq(X_ptr_underlying, N);
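          // hsum / hsum_sq accumulate the raw (zero-point-shifted) integers. Since
          // x_float = x_scale * (q - x_zp) and variance is shift-invariant, the
          // statistics below are kept divided by x_scale (mean) and x_scale^2
          // (variance); the scale is folded back in exactly once through
          // scale_x_div_layer_std.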
2822 float l_mean_shifted_div_scale_x = static_cast<float>(l_sum_shifted) / N;
2823 // mean(dqX) / scale_x
2824 float layer_mean_div_scale_x = l_mean_shifted_div_scale_x - x_zp;
2825 // var(dqX) / scale_x^2
2826 float layer_var_div_scale_x_sq =
2827 std::max(static_cast<float>(l_sum_sq_shifted) / N -
2828 l_mean_shifted_div_scale_x * l_mean_shifted_div_scale_x, 0.0f);
2829 // scale_x / sqrt(var(dqX) + eps)
2830 float scale_x_div_layer_std = x_scale /
2831 std::sqrt(layer_var_div_scale_x_sq * x_scale * x_scale + eps);
2832 fVec layer_mean_div_scale_xVec(layer_mean_div_scale_x);
2833 fVec scale_x_div_layer_stdVec(scale_x_div_layer_std);
2834
2835 // Second pass: normalize
2836
2837 // TODO replace with TensorIterator implementation once #33166 is fixed.
2838 if (affine_per_channel) {
2839
2840 // if scaling per channel, scaling parameters can be pre-multiplied
2841 // with normalization parameters
2842 for (const auto chIdx : c10::irange(channels_per_group)) {
2843 int scalingIdx = (i * channels_per_group + chIdx) % (num_channels);
2844 float gamma = gamma_null ? 1.0f : gamma_data[scalingIdx];
2845 // scale_x / layer_std * gamma
2846 float gamma_p = scale_x_div_layer_std * gamma;
2847 float beta = beta_null ? 0.0f : beta_data[scalingIdx];
2848 fVec gamma_p_vec(gamma_p);
2849 fVec beta_vec(beta);
2850
2851 int64_t chStartIdx = chIdx * NPerChannel;
2852 int64_t chEndIdx = chStartIdx + NPerChannel;
2853
2854 for (const auto vecIdx : c10::irange(kNumIntVecInChannel)) {
2855 int64_t vecStartIdx = chStartIdx + vecIdx * kIntVLen;
2856 auto qXVec = qVec::loadu(X_ptr + vecStartIdx);
2857 auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec,
2858 x_fake_scale_zp_neg_premul_vec);
2859 for (auto &dq : dqXVec) {
2860 dq =
2861 (dq - layer_mean_div_scale_xVec) *
2862 gamma_p_vec + beta_vec;
2863 }
2864 qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale)
2865 .store(Y_ptr + vecStartIdx);
2866 }
2867
2868 // Remainder
2869 if (kNonVecRemInChannel > 0) {
2870 int64_t remIdx = chEndIdx - kNonVecRemInChannel;
2871 auto qXVec = qVec::loadu(X_ptr + remIdx, kNonVecRemInChannel);
2872 auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec,
2873 x_fake_scale_zp_neg_premul_vec);
2874 int validDqvecLen = (kNonVecRemInChannel - 1) / fVec::size() + 1;
2875 for (int i = 0; i < validDqvecLen; ++i) {
2876 auto &dq = dqXVec[i];
2877 dq =
2878 (dq - layer_mean_div_scale_xVec) *
2879 gamma_p_vec + beta_vec;
2880 }
2881 qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale)
2882 .store(Y_ptr + remIdx, kNonVecRemInChannel);
2883 }
2884 } // chIdx
2885
2886 } else {
2887
2888 for (const auto vecIdx : c10::irange(kNumIntVecInLayer)) {
2889 int64_t vecStartIdx = vecIdx * kIntVLen;
2890 auto qXVec = qVec::loadu(X_ptr + vecStartIdx);
2891 auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec,
2892 x_fake_scale_zp_neg_premul_vec);
2893 for (const auto dqXVecIdx : c10::irange(dqXVec.size())) {
2894 int64_t vecVecStartIdx = vecStartIdx + dqXVecIdx * kFloatVLen;
2895 auto gammaVec = gamma_null
2896 ? one_vec
2897 : fVec::loadu(gamma_data + vecVecStartIdx);
2898 auto betaVec = beta_null
2899 ? zero_vec
2900 : fVec::loadu(beta_data + vecVecStartIdx);
2901 dqXVec[dqXVecIdx] =
2902 (dqXVec[dqXVecIdx] - layer_mean_div_scale_xVec) *
2903 scale_x_div_layer_stdVec * gammaVec + betaVec;
2904 qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale)
2905 .store(Y_ptr + vecStartIdx);
2906 }
2907 }
2908 for (int64_t remIdx = N - kNonVecRemInLayer; remIdx < N; remIdx++) {
2909 const float gamma_v = gamma_null ? 1.0f : gamma_data[remIdx];
2910 const float beta_v = beta_null ? 0.0f : beta_data[remIdx];
2911 auto qXVal = X_ptr[remIdx];
2912 float dqXVal = at::native::dequantize_val(x_fake_scale, x_zp, qXVal);
2913 float dqY =
2914 ((dqXVal - layer_mean_div_scale_x) * scale_x_div_layer_std) * gamma_v + beta_v;
2915 Y_ptr[remIdx] = at::native::quantize_val<scalar_t>(y_scale, y_zp, dqY);
2916 }
2917 }
2918 }
2919 }); // parallel_for
2920
2921 });
2922 }
2923
2924 void qmean_inner_dim_kernel(
2925 const Tensor& self,
2926 OptionalIntArrayRef opt_dim,
2927 bool keepdim,
2928 std::optional<ScalarType> opt_dtype,
2929 Tensor& result) {
2930 // 'opt_dtype' should be none or equal to that of input
2931 ScalarType dtype = self.scalar_type();
2932 auto in_dims = self.sizes().vec();
2933 auto out_dims = in_dims;
2934 bool is_all_reduce = !opt_dim.has_value() || opt_dim.value().empty();
2935 size_t num_dims_to_squeeze = is_all_reduce ? self.dim() : opt_dim.value().size();
2936 int64_t M = 1; // Num of groups
2937 int64_t N = 1; // Num of elements to take average of in each group
2938 for (size_t i = 0; i < in_dims.size() - num_dims_to_squeeze; ++i) {
2939 M *= in_dims[i];
2940 }
2941 for (size_t i = 0; i < num_dims_to_squeeze; ++i) {
2942 auto idx = out_dims.size() - 1 - i;
2943 N *= out_dims[idx];
2944 out_dims[idx] = 1;
2945 }
2946 if (!keepdim) {
2947 out_dims.erase(out_dims.end() - num_dims_to_squeeze, out_dims.end());
2948 }
2949 result = at::_empty_affine_quantized(
2950 out_dims,
2951 at::device(kCPU).dtype(dtype).memory_format(self.suggest_memory_format()),
2952 self.q_scale(),
2953 self.q_zero_point(),
2954 std::nullopt);
2955
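  // The output reuses the input's scale and zero point, so the quantized mean can be
  // computed directly on the raw integers: rounding sum(q) / N stores the same value
  // as re-quantizing mean(x_float) with the shared (scale, zero_point).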
2956 AT_DISPATCH_QINT_TYPES(self.scalar_type(), "quantized_mean_kernel_impl_cpu", [&]() {
2957 scalar_t* X_data = self.data_ptr<scalar_t>();
2958 scalar_t* Y_data = result.data_ptr<scalar_t>();
2959
2960 at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
2961 for (const auto i : c10::irange(start, end)) {
2962 scalar_t* X_ptr = X_data + i * N;
2963 scalar_t* Y_ptr = Y_data + i;
2964 scalar_t::underlying* X_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(X_ptr);
2965 scalar_t::underlying* Y_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(Y_ptr);
2966 auto x_sum = hsum(X_ptr_underlying, N);
2967 float y_float = static_cast<float>(x_sum) / N;
2968 *Y_ptr_underlying = std::nearbyint(y_float);
2969 }
2970 });
2971 });
2972 }
2973
2974 void qstd_inner_dim_kernel(
2975 const Tensor& self,
2976 OptionalIntArrayRef dim,
2977 const std::optional<Scalar>& correction_opt,
2978 bool keepdim,
2979 Tensor& result) {
2980 ScalarType dtype = self.scalar_type();
2981 auto in_dims = self.sizes().vec();
2982 auto out_dims = in_dims;
2983 size_t num_dims_to_squeeze = dim.has_value() && !dim.value().empty() ?
2984 dim.value().size() :
2985 self.dim();
2986 int64_t M = 1; // Num of groups
2987 int64_t N = 1; // Num of elements to take std of in each group
2988 for (size_t i = 0; i < in_dims.size() - num_dims_to_squeeze; ++i) {
2989 M *= in_dims[i];
2990 }
2991 for (size_t i = 0; i < num_dims_to_squeeze; ++i) {
2992 auto idx = out_dims.size() - 1 - i;
2993 N *= out_dims[idx];
2994 out_dims[idx] = 1;
2995 }
2996 if (!keepdim) {
2997 out_dims.erase(out_dims.end() - num_dims_to_squeeze, out_dims.end());
2998 }
2999 const auto correction = correction_opt.value_or(1).toDouble();
3000 double den = std::max(N - correction, 0.0); // Denominator when computing mean and deviation
3001 auto x_scale = self.q_scale();
3002 auto x_zp = self.q_zero_point();
3003 result = at::_empty_affine_quantized(
3004 out_dims,
3005 at::device(kCPU).dtype(dtype).memory_format(self.suggest_memory_format()),
3006 x_scale,
3007 x_zp,
3008 std::nullopt);
3009
3010 AT_DISPATCH_QINT_TYPES(self.scalar_type(), "quantized_std_kernel_impl_cpu", [&]() {
3011 scalar_t* X_data = self.data_ptr<scalar_t>();
3012 scalar_t* Y_data = result.data_ptr<scalar_t>();
3013
3014 at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
3015 for (const auto i : c10::irange(start, end)) {
3016 scalar_t* X_ptr = X_data + i * N;
3017 scalar_t* Y_ptr = Y_data + i;
3018 scalar_t::underlying* X_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(X_ptr);
3019 scalar_t::underlying* Y_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(Y_ptr);
3020 auto x_sum_shifted = hsum(X_ptr_underlying, N);
3021 auto x_sum_sq_shifted = hsum_sq(X_ptr_underlying, N);
3022 // Use double for intermediate variables to avoid accuracy issue
3023 // Mean with zero point
3024 double x_mean_shifted_div_scale_x = static_cast<double>(x_sum_shifted) / N;
3025 double x_mean_unbiased_shifted_div_scale_x = static_cast<double>(x_sum_shifted) / den;
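          // The expression below is sum((q - mean_q)^2) / den expanded as
          // sum(q^2)/den - 2 * mean_q * (sum(q)/den) + mean_q^2 * N / den, evaluated
          // on the zero-point-shifted integers (the shift cancels in the variance).
          // Multiplying the resulting standard deviation by x_scale converts it back
          // to the dequantized domain.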
3026 // variance / x_scale^2
3027 double x_var_div_scale_x_sq =
3028 std::max(static_cast<double>(x_sum_sq_shifted) / den -
3029 2 * x_mean_shifted_div_scale_x * x_mean_unbiased_shifted_div_scale_x +
3030 x_mean_shifted_div_scale_x * x_mean_shifted_div_scale_x * N / den, (double)0.0);
3031 double y_float = std::sqrt(x_var_div_scale_x_sq) * x_scale;
3032 *Y_ptr_underlying = at::native::quantize_val<scalar_t>(
3033 x_scale, x_zp, y_float)
3034 .val_;
3035 }
3036 });
3037 });
3038 }
3039
3040 // For group norm of channels_last input
3041 void quantized_groupnorm_nhwc_kernel(
3042 const Tensor& X, // input tensor
3043 const Tensor& gamma, // weight (optional)
3044 const Tensor& beta, // bias (optional)
3045 bool affine_per_channel, // must be true for group/instance norm
3046 int num_channels, // only used if affine_per_channel is set
3047 int num_groups, // only used if affine_per_channel is set
3048 int64_t M, // number of groups = Bs * G
3049 int64_t N, // number of elements in each group = C * H * W / G
3050 double eps,
3051 Tensor* Y) {
3052 AT_DISPATCH_QINT_TYPES(X.scalar_type(), "quantized_norm_nhwc_kernel_impl_cpu", [&]() {
3053 using qVec = vec::Vectorized<scalar_t>;
3054 using fVec = vec::Vectorized<float>;
3055
3056 int64_t G = num_groups;
3057 int64_t Bs = M / G;
3058 int64_t C = num_channels;
3059
3060 TORCH_INTERNAL_ASSERT(X.numel() == M * N, "Unexpected num elements in X");
3061 TORCH_INTERNAL_ASSERT(
3062 !gamma.defined() ||
3063 (!affine_per_channel && gamma.numel() == N) ||
3064 (affine_per_channel && gamma.numel() == C),
3065 "Unexpected size of gamma");
3066 TORCH_INTERNAL_ASSERT(
3067 !beta.defined() ||
3068 (!affine_per_channel && beta.numel() == N) ||
3069 (affine_per_channel && beta.numel() == C),
3070 "Unexpected size of beta");
3071
3072 scalar_t* X_data = X.data_ptr<scalar_t>();
3073 const float* gamma_data = gamma.defined() ? gamma.const_data_ptr<float>() : nullptr;
3074 const float* beta_data = beta.defined() ? beta.const_data_ptr<float>() : nullptr;
3075 scalar_t* Y_data = Y->data_ptr<scalar_t>();
3076 const bool gamma_null = gamma_data == nullptr;
3077 const bool beta_null = beta_data == nullptr;
3078 int64_t x_zp = X.q_zero_point();
3079 float x_scale = X.q_scale();
3080 fVec x_zp_vec((float)x_zp);
3081 fVec one_vec(1.0f);
3082 fVec zero_vec(0.0f);
3083 float x_fake_scale = 1.0f;
3084 fVec x_fake_scale_vec(x_fake_scale);
3085 fVec x_fake_scale_zp_neg_premul_vec = x_fake_scale_vec * x_zp_vec.neg();
3086 int64_t y_zp = Y->q_zero_point();
3087 float y_scale = Y->q_scale();
3088 float y_inv_scale = 1.0f / y_scale;
3089
3090 constexpr int kFloatVLen = fVec::size();
3091 int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs();
3092 int64_t channels_per_group = C / G;
3093 int64_t HxW = N / channels_per_group;
3094 int64_t kNumIntVecInHxW = channels_per_group / kIntVLen;
3095 int64_t kNonVecRemInHxW = channels_per_group % kIntVLen;
3096 int64_t kNumIntVecOnChannel = C / kIntVLen;
3097 int64_t kNonVecRemOnChannel = C % kIntVLen;
3098
3099 // Buffer for x and x^2
3100 Tensor buffer = at::empty({M, 2 * channels_per_group}, X.options().dtype(at::kFloat));
3101 float* buffer_data = buffer.mutable_data_ptr<float>();
3102
3103      // We can parallelize with either of the following 2 impls:
3104      //
3105      // impl-1: parallelize over Bs * G. Only one omp session is needed, but memory
3106      // access per thread is non-contiguous.
3107      //
3108      // impl-2: parallelize over Bs * HxW. Memory access per thread is contiguous,
3109      // but it requires an extra temp buffer of size {T, Bs, 2C}.
3110      //
3111      // Generally impl-2 performs better when HxW is large enough;
3112      // the threshold was found by tests.
3113 constexpr int64_t feature_map_threshold = 512;
3114 if (HxW < feature_map_threshold) {
3115 // Impl-1: Parallel for each group
3116 //
3117 // Parallel for each group, M = Bs * G
3118 at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
3119 int64_t n{0} /* batch index */, g{0} /* group index in each batch */;
3120 data_index_init(begin, n, N, g, G);
3121 for (const auto grpIdx : c10::irange(begin, end)) { // For each group
3122
3123 // Step 1: calculate mean and variance.
3124 int64_t l_sum_shifted = 0;
3125 int64_t l_sum_sq_shifted = 0;
3126 for (const auto hw : c10::irange(HxW)) {
3127 scalar_t* X_ptr = X_data + n * N * G + g * channels_per_group + hw * C;
3128 scalar_t::underlying* X_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(X_ptr);
3129 l_sum_shifted += hsum(X_ptr_underlying, channels_per_group);
3130 l_sum_sq_shifted += hsum_sq(X_ptr_underlying, channels_per_group);
3131 }
3132
3133 // mean(dqX) / scale_x + x_zp
3134 float l_mean_shifted_div_scale_x = static_cast<float>(l_sum_shifted) / N;
3135 // mean(dqX) / scale_x
3136 float layer_mean_div_scale_x = l_mean_shifted_div_scale_x - x_zp;
3137 // var(dqX) / scale_x^2
3138 float layer_var_div_scale_x_sq =
3139 std::max(static_cast<float>(l_sum_sq_shifted) / N -
3140 l_mean_shifted_div_scale_x * l_mean_shifted_div_scale_x, 0.0f);
3141 // scale_x / sqrt(var(dqX) + eps)
3142 float scale_x_div_layer_std = x_scale /
3143 std::sqrt(layer_var_div_scale_x_sq * x_scale * x_scale + eps);
3144
3145 // Step 2: calculate scale and bias
3146 float* scale_ptr = buffer_data + grpIdx * 2 * channels_per_group;
3147 float* bias_ptr = scale_ptr + channels_per_group;
3148 for (const auto d : c10::irange(channels_per_group)) {
3149 const int64_t chIdx = g * channels_per_group + d;
3150 scale_ptr[d] = scale_x_div_layer_std * (gamma_null ? 1.0f : gamma_data[chIdx]);
3151 bias_ptr[d] = -scale_ptr[d] * layer_mean_div_scale_x + (beta_null ? 0.0f : beta_data[chIdx]);
3152 }
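          // Folding the normalization into per-channel scale and bias lets Step 3
          // apply a single multiply-add per element: Y = dqX * scale + bias, where
          // dqX is the fake-dequantized (q - x_zp) value, scale = gamma * x_scale / std
          // and bias = -scale * (mean(q) - x_zp) + beta.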
3153
3154 // Step 3: applying scale and bias
3155 for (const auto hwIdx : c10::irange(HxW)) {
3156 const scalar_t* X_ptr = X_data + n * N * G + g * channels_per_group + hwIdx * C;
3157 scalar_t* Y_ptr = Y_data + n * N * G + g * channels_per_group + hwIdx * C;
3158 // vectorized
3159 for (const auto vecIdx : c10::irange(kNumIntVecInHxW)) {
3160 int64_t vecStartIdx = vecIdx * kIntVLen;
3161 auto qXVec = qVec::loadu(X_ptr + vecStartIdx);
3162 auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec,
3163 x_fake_scale_zp_neg_premul_vec);
3164 for (size_t fvecIdx = 0; fvecIdx < dqXVec.size(); ++fvecIdx) {
3165 auto scaleVec = fVec::loadu(scale_ptr + vecStartIdx + fvecIdx * kFloatVLen);
3166 auto biasVec = fVec::loadu(bias_ptr + vecStartIdx + fvecIdx * kFloatVLen);
3167 dqXVec[fvecIdx] = dqXVec[fvecIdx] * scaleVec + biasVec;
3168 }
3169 qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale)
3170 .store(Y_ptr + vecStartIdx);
3171 }
3172 // Remaining scalar
3173 for (int64_t remIdx = kNumIntVecInHxW * kIntVLen;
3174 remIdx < kNonVecRemInHxW + kNumIntVecInHxW * kIntVLen;
3175 ++remIdx) {
3176 auto qXVal = X_ptr[remIdx];
3177 float dqXVal = at::native::dequantize_val(x_fake_scale, x_zp, qXVal);
3178 float dqY = dqXVal * scale_ptr[remIdx] + bias_ptr[remIdx];
3179 Y_ptr[remIdx] = at::native::quantize_val<scalar_t>(y_scale, y_zp, dqY);
3180 }
3181 } // loop over HxW
3182
3183 data_index_step(n, N, g, G);
3184 } // for each group
3185 }); // parallel_for
    } else { // HxW >= feature_map_threshold
3187 // impl-2: parallel on Bs * HxW.
3188 //
3189 // Buffer for x and x^2
3190 // To avoid thread conflict, we use a temp buffer of {T, Bs, 2*C}
3191 int num_threads = at::get_num_threads();
3192 Tensor buffer = at::empty({num_threads, Bs, 2 * C}, X.options().dtype(at::kFloat)).zero_();
3193 float* buffer_data = buffer.mutable_data_ptr<float>();
3194 Tensor mean = at::empty(M, X.options().dtype(at::kFloat));
3195 float* mean_data = mean.mutable_data_ptr<float>();
3196 Tensor rstd = at::empty(M, X.options().dtype(at::kFloat));
3197 float* rstd_data = rstd.mutable_data_ptr<float>();
3198
3199 // Step 1: Accumulate on C dimension
3200 at::parallel_for(0, Bs * HxW, 1, [&](int64_t begin, int64_t end) {
3201 int tid = at::get_thread_num();
3202 float* buffer_ptr = buffer_data + tid * Bs * 2 * C;
3203
3204 int64_t n{0} /* batch index */, m{0} /* HxW index */;
3205 data_index_init(begin, n, Bs, m, HxW);
3206 for (const auto nhwIdx : c10::irange(begin, end)) {
3207 float* mean_ptr = buffer_ptr + n * 2 * C;
3208 float* rstd_ptr = mean_ptr + C;
3209 scalar_t* X_ptr = X_data + nhwIdx * C;
3210 scalar_t::underlying* X_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(X_ptr);
3211 for (int chIdx = 0; chIdx < C; ++chIdx) {
3212 auto x = X_ptr_underlying[chIdx];
3213 mean_ptr[chIdx] += x;
3214 rstd_ptr[chIdx] += x * x;
3215 }
3216 data_index_step(n, Bs, m, HxW);
3217 }
3218 });
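    // At this point buffer[t][n][c] holds thread t's partial sum of q over its
    // share of HxW positions for batch n and channel c, and buffer[t][n][C + c]
    // the matching partial sum of q^2; Step 2 reduces them over threads and
    // over the channels of each group.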
3219
3220 // Step 2: Calculate mean and rstd
3221 for (const auto n : c10::irange(Bs)) {
3222 for (const auto g : c10::irange(G)) {
3223 float mean_val{0}, rstd_val{0};
3224 for (const auto t : c10::irange(num_threads)) {
3225 float* buffer_ptr = buffer_data + t * Bs * 2 * C + n * 2 * C;
3226 for (const auto d : c10::irange(channels_per_group)) {
3227 mean_val += buffer_ptr[g * channels_per_group + d];
3228 rstd_val += buffer_ptr[g * channels_per_group + d + C];
3229 } // for d
3230 } // for t
3231
3232 // mean / scale_x + x_zp
3233 float l_mean_shifted_div_scale_x = mean_val / N;
3234 // mean / scale_x
3235 float layer_mean_div_scale_x = l_mean_shifted_div_scale_x - x_zp;
3236 // var / scale_x^2
3237 float layer_var_div_scale_x_sq =
3238 std::max(rstd_val / N -
3239 l_mean_shifted_div_scale_x * l_mean_shifted_div_scale_x, 0.0f);
3240 // scale_x / sqrt(var + eps)
3241 float scale_x_div_layer_std = x_scale /
3242 std::sqrt(layer_var_div_scale_x_sq * x_scale * x_scale + eps);
3243 mean_data[n * G + g] = layer_mean_div_scale_x;
3244 rstd_data[n * G + g] = scale_x_div_layer_std;
3245
3246 } // for g
3247 } // for n
3248
    // Step 3: Calculate scale and bias
    //
    //   We could fuse steps 3 and 4 into a single session, but this way is better:
    //     a. channels_per_group might be too small for vectorization;
    //     b. it avoids recomputing scale/bias, since every HxW plane shares the same scale/bias
    //
3255 for (const auto n : c10::irange(Bs)) {
3256 for (const auto g : c10::irange(G)) {
3257 float* scale_ptr = buffer_data + n * 2 * C;
3258 float* bias_ptr = scale_ptr + C;
3259 float mean_val = mean_data[n * G + g];
3260 float rstd_val = rstd_data[n * G + g];
3261 for (const auto d : c10::irange(channels_per_group)) {
3262 const int64_t chIdx = g * channels_per_group + d;
3263 scale_ptr[chIdx] = rstd_val * (gamma_null ? 1.0f : gamma_data[chIdx]);
3264 bias_ptr[chIdx] = -scale_ptr[chIdx] * mean_val + (beta_null ? 0.0f : beta_data[chIdx]);
3265 } // for d
3266 } // for g
3267 } // for n
3268
    // Step 4: Apply scale and bias
    //
    //   Parallelize over the outer dimensions Bs and HxW,
    //   and vectorize over C.
3273 //
3274 at::parallel_for(0, Bs * HxW, 1, [&](int64_t begin, int64_t end) {
3275 int64_t n{0}, m{0};
3276 data_index_init(begin, n, Bs, m, HxW);
3277 for (const auto nhwIdx : c10::irange(begin, end)) {
3278 const scalar_t* X_ptr = X_data + nhwIdx * C;
3279 scalar_t* Y_ptr = Y_data + nhwIdx * C;
3280 float* scale_ptr = buffer_data + n * 2 * C;
3281 float* bias_ptr = scale_ptr + C;
3282 // Vectorized
3283 for (const auto vecIdx : c10::irange(kNumIntVecOnChannel)) {
3284 int64_t vecStartIdx = vecIdx * kIntVLen;
3285 auto qXVec = qVec::loadu(X_ptr + vecStartIdx);
3286 auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec,
3287 x_fake_scale_zp_neg_premul_vec);
3288 for (size_t fvecIdx = 0; fvecIdx < dqXVec.size(); ++fvecIdx) {
3289 auto scaleVec = fVec::loadu(scale_ptr + vecStartIdx + fvecIdx * kFloatVLen);
3290 auto biasVec = fVec::loadu(bias_ptr + vecStartIdx + fvecIdx * kFloatVLen);
3291 dqXVec[fvecIdx] = dqXVec[fvecIdx] * scaleVec + biasVec;
3292 }
3293 qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale)
3294 .store(Y_ptr + vecStartIdx);
3295 }
3296 // Remaining scalar
3297 for (int64_t remIdx = kNumIntVecOnChannel * kIntVLen;
3298 remIdx < kNonVecRemOnChannel + kNumIntVecOnChannel * kIntVLen;
3299 ++remIdx) {
3300 auto qXVal = X_ptr[remIdx];
3301 float dqXVal = at::native::dequantize_val(x_fake_scale, x_zp, qXVal);
3302 float dqY = dqXVal * scale_ptr[remIdx] + bias_ptr[remIdx];
3303 Y_ptr[remIdx] = at::native::quantize_val<scalar_t>(y_scale, y_zp, dqY);
3304 }
3305
3306 data_index_step(n, Bs, m, HxW);
3307 } // for idx on nhw
3308 }); // parallel_for on nhw
3309
  } // HxW >= feature_map_threshold
3311
3312 }); // AT_DISPATCH_QINT_TYPES
3313 }
3314
3315 #ifdef USE_FBGEMM
3316 void quantize_tensor_per_tensor_affine_cpu(
3317 const Tensor& rtensor,
3318 Tensor& qtensor,
3319 double scale,
3320 int64_t zero_point) {
3321 AT_DISPATCH_QINT_TYPES(
3322 qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_cpu", [&]() {
3323 check_tensor_memory_format(rtensor, qtensor);
3324 const float* rd = rtensor.const_data_ptr<float>();
3325 auto qd = reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
3326 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
3327 fbgemm::TensorQuantizationParams qparams;
3328 qparams.scale = scale;
3329 qparams.zero_point = zero_point;
3330 qparams.precision = CHAR_BIT * sizeof(underlying_t);
3331 int num_tasks = at::get_num_threads();
3332 at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
3333 for (const auto task_id : c10::irange(begin, end)) {
3334 fbgemm::Quantize<underlying_t, false /*LEGACY*/>(
3335 // NOLINTNEXTLINE(bugprone-argument-comment)
3336 rd, /*src=*/
3337 // NOLINTNEXTLINE(bugprone-argument-comment)
3338 qd, /*dst=*/
3339 rtensor.numel(), /*len*/
3340 // NOLINTNEXTLINE(bugprone-argument-comment)
3341 qparams, /*qparams=*/
3342 task_id, /*thread_id*/
3343 num_tasks /*num_threads*/);
3344 }
3345 });
3346 });
3347 }
3348
3349 void dequantize_tensor_per_tensor_affine_cpu(
3350 const Tensor& qtensor,
3351 Tensor& rtensor,
3352 double scale,
3353 int64_t zero_point) {
3354 AT_DISPATCH_QINT_TYPES(
3355 qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_cpu", [&]() {
3356 check_tensor_memory_format(qtensor, rtensor);
3357 const auto* qd =
3358 reinterpret_cast<const underlying_t*>(qtensor.data_ptr<scalar_t>());
3359 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
3360 fbgemm::TensorQuantizationParams qparams;
3361 qparams.scale = scale;
3362 qparams.zero_point = zero_point;
3363 qparams.precision = CHAR_BIT * sizeof(underlying_t);
3364 float* rd = rtensor.data_ptr<float>();
3365 int num_tasks = at::get_num_threads();
3366 at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
3367 for (const auto task_id : c10::irange(begin, end)) {
3368 fbgemm::Dequantize<underlying_t>(
3369 // NOLINTNEXTLINE(bugprone-argument-comment)
3370 qd, /*src=*/
3371 // NOLINTNEXTLINE(bugprone-argument-comment)
3372 rd, /*dst=*/
3373 // NOLINTNEXTLINE(bugprone-argument-comment)
3374 qtensor.numel(), /*len=*/
3375 // NOLINTNEXTLINE(bugprone-argument-comment)
3376 qparams, /*qparams=*/
3377 task_id, /*thread_id*/
3378 num_tasks /*num_threads*/);
3379 }
3380 });
3381 });
3382 }
3383 #else // USE_FBGEMM
3384
3385 #if defined(__ARM_NEON__) || defined(__aarch64__)
3386
3387 const static int PARALLEL_THRESHOLD = 1 << 20;
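// Tensors with at least 2^20 elements are chunked across threads via
// at::parallel_for below; smaller tensors are (de)quantized on the calling
// thread, where the parallelization overhead would not pay off.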
3388
3389 // Generic template defaults to naive quantize implementation
3390 template <typename T>
3391 void quantize_tensor_arm(
3392 const float* __restrict__ in,
3393 T* __restrict__ out,
3394 const int64_t N,
3395 const float scale,
3396 const int32_t zero_point) {
3397 for (const auto i : c10::irange(N)) {
3398 out[i] = at::native::quantize_val<T>(scale, zero_point, in[i]);
3399 }
3400 }
3401
3402 namespace quantize_tensor_arm_intrinsics {
3403 template <typename Tx8>
3404 C10_ALWAYS_INLINE Tx8 vqmov(int16x8_t vraw);
3405
3406 template <>
3407 C10_ALWAYS_INLINE uint8x8_t vqmov<uint8x8_t>(int16x8_t vraw) {
3408 return vqmovun_s16(vraw);
3409 }
3410
3411 template <>
3412 C10_ALWAYS_INLINE int8x8_t vqmov<int8x8_t>(int16x8_t vraw) {
3413 return vqmovn_s16(vraw);
3414 }
3415
3416 template <typename T, typename Tx8>
3417 C10_ALWAYS_INLINE void vst1(T* out, Tx8 vout);
3418
3419 template <>
3420 C10_ALWAYS_INLINE void vst1<uint8_t, uint8x8_t>(uint8_t* out, uint8x8_t vout) {
3421 vst1_u8(out, vout);
3422 }
3423
3424 template <>
3425 C10_ALWAYS_INLINE void vst1<int8_t, int8x8_t>(int8_t* out, int8x8_t vout) {
3426 vst1_s8(out, vout);
3427 }
3428 } // namespace quantize_tensor_arm_intrinsics
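// The helpers above select the saturating narrow (vqmovun_s16 vs. vqmovn_s16)
// and store (vst1_u8 vs. vst1_s8) intrinsics for the unsigned and signed 8-bit
// cases, so quantize_tensor_arm_q8 below can be shared between quint8 and
// qint8.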
3429
3430 // Specialized implementation from caffe2::Int8Quantize.
// There may be a slight accuracy difference between this and the
// implementation of quantize_val.
3433 // TODO Update quantize_tensor_arm implementation to follow quantize_val,
3434 // i.e. f = Round(value/scale + zero_point)
3435 // TODO Make quantize_tensor_arm work for int32 datatype too.
3436 template <typename scalar_t, typename underlying_t, typename underlying_x8_t>
3437 void quantize_tensor_arm_q8(
3438 const float* __restrict__ in,
3439 scalar_t* __restrict__ out,
3440 const int64_t N,
3441 const float scale,
3442 const int32_t zero_point) {
3443 const float inv_scale = 1.0f / scale;
3444 uint32_t i = 0;
3445 underlying_t* out_underlying = reinterpret_cast<underlying_t*>(out);
3446 const float32x4_t vinv_scale = vdupq_n_f32(inv_scale);
3447 #if defined(__ARM_NEON__)
3448 // magic float and magic int to take care of rounding
3449 // int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000
3450 // Some detail:
3451 // 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you
3452 // add a small number to a large number, the result rounds to the precision of
  // the least significant bit of the large number. An IEEE-754 single-precision
  // number has a 23-bit mantissa, so after adding 2**23 the ulp is 1 and the
  // addition rounds to the nearest integer (ties to even). Then we cast to int
  // and subtract the same number (0x4B400000 is the integer representation of
  // 12582912.0f) to get only the mantissa. This works if -2**22 < x < 2**22,
  // and it preserves the sign for negative numbers.
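  // Worked example (under the default round-to-nearest-even mode):
  // magic_round(3.4f) computes 3.4 + 12582912.0 = 12582915.4, which rounds to
  // 12582915.0f with bit pattern 0x4B400003; subtracting 0x4B400000 leaves 3.
  // Below, the zero point is folded into voffset, so the lanes end up holding
  // round(x / scale) + zero_point before the saturating narrow.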
3459 const int32x4_t voffset = vdupq_n_s32(zero_point - 0x4B400000);
3460 const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
3461 for (i = 0; i + 8 <= N; i += 8) {
3462 const float32x4_t vin0123 = vld1q_f32(in);
3463 in += 4;
3464 const float32x4_t vin4567 = vld1q_f32(in);
3465 in += 4;
3466 const int32x4_t vraw0123 = vaddq_s32(
3467 voffset,
3468 vreinterpretq_s32_f32(
3469 vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
3470 const int32x4_t vraw4567 = vaddq_s32(
3471 voffset,
3472 vreinterpretq_s32_f32(
3473 vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
3474 const int16x8_t vraw01234567 =
3475 vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
3476 const underlying_x8_t vout01234567 =
3477 quantize_tensor_arm_intrinsics::vqmov<underlying_x8_t>(vraw01234567);
3478 quantize_tensor_arm_intrinsics::vst1<underlying_t, underlying_x8_t>(
3479 out_underlying, vout01234567);
3480 out_underlying += 8;
3481 }
3482 for (; i < N; ++i) {
3483 (*out_underlying++) =
3484 at::native::quantize_val_arm<underlying_t>(scale, zero_point, (*in++));
3485 }
3486 #else
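  // On AArch64, vcvtnq_s32_f32 converts with round-to-nearest (ties to even)
  // directly, so the magic-number trick is unnecessary; the zero point is then
  // added with a saturating vqaddq_s16 on the narrowed 16-bit values.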
3487 const int16x8_t vzero_point = vdupq_n_s16((int16_t)(uint16_t)zero_point);
3488 for (i = 0; i + 8 <= N; i += 8) {
3489 const float32x4_t vin0123 = vld1q_f32(in);
3490 in += 4;
3491 const float32x4_t vin4567 = vld1q_f32(in);
3492 in += 4;
3493 const int32x4_t v0123_rounded = vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale));
3494 const int32x4_t v4567_rounded = vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale));
3495 const int16x8_t v01234567_packed = vqaddq_s16(
3496 vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded), vzero_point);
3497 const underlying_x8_t vout01234567 =
3498 quantize_tensor_arm_intrinsics::vqmov<underlying_x8_t>(
3499 v01234567_packed);
3500 quantize_tensor_arm_intrinsics::vst1<underlying_t, underlying_x8_t>(
3501 out_underlying, vout01234567);
3502 out_underlying += 8;
3503 }
3504 for (; i < N; ++i) {
3505 (*out_underlying++) =
3506 at::native::quantize_val_arm<underlying_t>(scale, zero_point, (*in++));
3507 }
3508 #endif
3509 }
3510
3511 template <>
3512 void quantize_tensor_arm<c10::quint8>(
3513 const float* __restrict__ in,
3514 c10::quint8* __restrict__ out,
3515 const int64_t N,
3516 const float scale,
3517 const int32_t zero_point) {
3518 quantize_tensor_arm_q8<c10::quint8, uint8_t, uint8x8_t>(
3519 in, out, N, scale, zero_point);
3520 }
3521
3522 template <>
3523 void quantize_tensor_arm<c10::qint8>(
3524 const float* __restrict__ in,
3525 c10::qint8* __restrict__ out,
3526 const int64_t N,
3527 const float scale,
3528 const int32_t zero_point) {
3529 quantize_tensor_arm_q8<c10::qint8, int8_t, int8x8_t>(
3530 in, out, N, scale, zero_point);
3531 }
3532
3533 #if defined(__aarch64__)
3534 #define VMOVL_HIGH_U8(x) vmovl_high_u8(x)
3535 #define VMOVL_HIGH_S8(x) vmovl_high_s8(x)
3536 #define VMOVL_HIGH_U16(x) vmovl_high_u16(x)
3537 #define VMOVL_HIGH_S16(x) vmovl_high_s16(x)
3538 #else // vmovl_high intrinsic not supported
3539 #define VMOVL_HIGH_U8(x) vmovl_u8(vget_high_u8(x))
3540 #define VMOVL_HIGH_S8(x) vmovl_s8(vget_high_s8(x))
3541 #define VMOVL_HIGH_U16(x) vmovl_u16(vget_high_u16(x))
3542 #define VMOVL_HIGH_S16(x) vmovl_s16(vget_high_s16(x))
3543 #endif
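// The VMOVL_HIGH_* macros widen the upper half of a vector to the next wider
// element type: AArch64 provides single vmovl_high_* instructions, while
// 32-bit NEON emulates them with vget_high_* followed by vmovl_*.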
3544
3545 // Generic template defaults to naive dequantize implementation
3546 template <typename T>
3547 void dequantize_tensor_arm(
3548 const T* __restrict__ in,
3549 float* __restrict__ out,
3550 const int64_t N,
3551 const float scale,
3552 const int32_t zero_point) {
3553 for (int i = 0; i < N; ++i) {
3554 out[i] = dequantize_val<T>(scale, zero_point, in[i]);
3555 }
3556 }
3557
3558 template <>
3559 void dequantize_tensor_arm<c10::qint8>(
3560 const c10::qint8* __restrict__ in,
3561 float* __restrict__ out,
3562 const int64_t N,
3563 const float scale,
3564 const int32_t zero_point) {
3565 const int8_t* in_underlying = reinterpret_cast<const int8_t*>(in);
3566
3567 const float32x4_t scale_fp32x4 = vdupq_n_f32(scale);
3568 // Zero point is restricted to be in bounds of a signed 8 bit integer
3569 const int8x8_t zero_point_s8x8 = vget_low_s8(vdupq_n_s8(static_cast<int8_t>(zero_point)));
3570
3571 int i;
3572 for (i = 0; i + 16 <= N; i += 16) {
3573 const int8x16_t vin_s8 = vld1q_s8(in_underlying);
3574
3575 // Extract upper or lower values to int16x8 and subtract zero point
3576 // Each input element and the zero point are restricted to be in bounds of
3577 // a signed 8 bit integer, so the difference will fit in a signed 16 bit
3578 // integer
3579 const int16x8_t minus_zp_low_s16 = vsubl_s8(vget_low_s8(vin_s8), zero_point_s8x8); // 0 ... 7
3580 const int16x8_t minus_zp_high_s16 = vsubl_s8(vget_high_s8(vin_s8), zero_point_s8x8); // 8 ... 15
3581
3582 const int32x4_t minus_zp_low_low = vmovl_s16(vget_low_s16(minus_zp_low_s16)); // 0 ... 3
3583 const int32x4_t minus_zp_low_high = VMOVL_HIGH_S16(minus_zp_low_s16); // 4 ... 7
3584 const int32x4_t minus_zp_high_low = vmovl_s16(vget_low_s16(minus_zp_high_s16)); // 8 ... 11
3585 const int32x4_t minus_zp_high_high = VMOVL_HIGH_S16(minus_zp_high_s16); // 12 ... 15
3586
3587 // Store * scale int32->fp32
3588 vst1q_f32(out, vmulq_f32(vcvtq_f32_s32(minus_zp_low_low), scale_fp32x4));
3589 vst1q_f32(out + 4, vmulq_f32(vcvtq_f32_s32(minus_zp_low_high), scale_fp32x4));
3590 vst1q_f32(out + 8, vmulq_f32(vcvtq_f32_s32(minus_zp_high_low), scale_fp32x4));
3591 vst1q_f32(out + 12, vmulq_f32(vcvtq_f32_s32(minus_zp_high_high), scale_fp32x4));
3592
3593 out += 16;
3594 in += 16;
3595 in_underlying += 16;
3596 }
3597
3598 for (; i < N; ++i) { // use default dequantize for remaining vals
3599 (*out++) = dequantize_val<c10::qint8>(scale, zero_point, (*in++));
3600 }
3601 }
3602
3603 template <>
3604 void dequantize_tensor_arm<c10::quint8>(
3605 const c10::quint8* __restrict__ in,
3606 float* __restrict__ out,
3607 const int64_t N,
3608 const float scale,
3609 const int32_t zero_point) {
3610 const uint8_t* in_underlying = reinterpret_cast<const uint8_t*>(in);
3611
3612 const float32x4_t scale_fp32x4 = vdupq_n_f32(scale);
3613 // Zero point is restricted to be in bounds of an unsigned 8 bit integer
3614 const uint8x8_t zero_point_u8x8 = vget_low_u8(vdupq_n_u8(static_cast<uint8_t>(zero_point)));
3615
3616 int i;
3617 for (i = 0; i + 16 <= N; i += 16) {
3618 const uint8x16_t vin_u8 = vld1q_u8(in_underlying);
3619
3620 // Extract upper or lower values to uint16x8 and subtract zero point
3621 // Each input element and the zero point are restricted to be in bounds of
3622 // an unsigned 8 bit integer, so the difference will fit in a signed 16 bit
3623 // integer
3624 const int16x8_t minus_zp_low_s16 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(vin_u8), zero_point_u8x8)); // 0 ... 7
3625 const int16x8_t minus_zp_high_s16 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(vin_u8), zero_point_u8x8)); // 8 ... 15
3626
3627 const int32x4_t minus_zp_low_low = vmovl_s16(vget_low_s16(minus_zp_low_s16)); // 0 ... 3
3628 const int32x4_t minus_zp_low_high = VMOVL_HIGH_S16(minus_zp_low_s16); // 4 ... 7
3629 const int32x4_t minus_zp_high_low = vmovl_s16(vget_low_s16(minus_zp_high_s16)); // 8 ... 11
3630 const int32x4_t minus_zp_high_high = VMOVL_HIGH_S16(minus_zp_high_s16); // 12 ... 15
3631
3632 // Store * scale int32->fp32
3633 vst1q_f32(out, vmulq_f32(vcvtq_f32_s32(minus_zp_low_low), scale_fp32x4));
3634 vst1q_f32(out + 4, vmulq_f32(vcvtq_f32_s32(minus_zp_low_high), scale_fp32x4));
3635 vst1q_f32(out + 8, vmulq_f32(vcvtq_f32_s32(minus_zp_high_low), scale_fp32x4));
3636 vst1q_f32(out + 12, vmulq_f32(vcvtq_f32_s32(minus_zp_high_high), scale_fp32x4));
3637
3638 out += 16;
3639 in += 16;
3640 in_underlying += 16;
3641 }
3642
3643 for (; i < N; ++i) { // use default dequantize for remaining vals
3644 (*out++) = dequantize_val<c10::quint8>(scale, zero_point, (*in++));
3645 }
3646 }
3647
3648 #endif // defined(__ARM_NEON__) || defined(__aarch64__)
3649
3650 void quantize_tensor_per_tensor_affine_cpu(
3651 const Tensor& rtensor,
3652 Tensor& qtensor,
3653 double scale,
3654 int64_t zero_point) {
3655 check_tensor_memory_format(rtensor, qtensor);
3656 const float* rdata = rtensor.const_data_ptr<float>();
3657 int numel = rtensor.numel();
3658 #if defined(__ARM_NEON__) || defined(__aarch64__)
3659 AT_DISPATCH_QINT_TYPES(
3660 qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_cpu", [&]() {
3661 scalar_t* qdata = qtensor.data_ptr<scalar_t>();
3662 auto quantize_range = [&](int64_t begin, int64_t end) {
3663 quantize_tensor_arm<scalar_t>(
3664 rdata + begin, qdata + begin, end - begin, scale, zero_point);
3665 };
3666 if (numel >= PARALLEL_THRESHOLD) {
3667 at::parallel_for(0, numel, 1, quantize_range);
3668 } else {
3669 quantize_range(0, numel);
3670 }
3671 });
3672 #else
3673 // Fallback path
3674 AT_DISPATCH_QINT_TYPES(
3675 qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_cpu", [&]() {
3676 scalar_t* qdata = qtensor.data_ptr<scalar_t>();
3677 for (const auto i : c10::irange(numel)) {
3678 qdata[i] = quantize_val<scalar_t>(scale, zero_point, rdata[i]);
3679 }
3680 });
3681 #endif // defined(__ARM_NEON__) || defined(__aarch64__)
3682 }
3683
3684 void dequantize_tensor_per_tensor_affine_cpu(
3685 const Tensor& qtensor,
3686 Tensor& rtensor,
3687 double scale,
3688 int64_t zero_point) {
3689 check_tensor_memory_format(qtensor, rtensor);
3690 float* rdata = rtensor.data_ptr<float>();
3691 int numel = qtensor.numel();
3692 #if defined(__ARM_NEON__) || defined(__aarch64__)
3693 AT_DISPATCH_QINT_TYPES(
3694 qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_cpu", [&]() {
3695 const scalar_t* qdata = qtensor.const_data_ptr<scalar_t>();
3696 auto dequantize_range = [&](int64_t begin, int64_t end) {
3697 dequantize_tensor_arm<scalar_t>(
3698 qdata + begin, rdata + begin, end - begin, scale, zero_point);
3699 };
3700 if (numel >= PARALLEL_THRESHOLD) {
3701 at::parallel_for(0, numel, 1, dequantize_range);
3702 } else {
3703 dequantize_range(0, numel);
3704 }
3705 });
3706 #else
3707 // Fallback path
3708 AT_DISPATCH_QINT_TYPES(
3709 qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_cpu", [&]() {
3710 const scalar_t* qdata = qtensor.const_data_ptr<scalar_t>();
3711 for (const auto i : c10::irange(numel)) {
3712 rdata[i] = dequantize_val<scalar_t>(scale, zero_point, qdata[i]);
3713 }
3714 });
3715 #endif // defined(__ARM_NEON__) || defined(__aarch64__)
3716 }
3717 #endif // USE_FBGEMM
3718
3719 // TODO: add fbgemm for per channel
3720 // Generic template defaults to naive quantize implementation
3721 template <typename T>
3722 void quantize_tensor_per_channel_impl(
3723 const Tensor& rtensor,
3724 Tensor& qtensor,
3725 const Tensor& scales,
3726 const Tensor& zero_points,
3727 int64_t axis) {
  // TODO: channels-last kernel can be made faster.
  // For contiguous tensors, e.g. NCHW, an arbitrary axis can be used.
  // For channels_last/3d, however, only axis == 0 or 1 is supported.
  // Since the current channels_last implementation does not cover
  // per-channel quant with an arbitrary axis value, it is better
  // to check and fail.
3734 int64_t batches = size_to_dim_(axis, rtensor.sizes());
3735 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
3736 int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
3737 int64_t channels = rtensor.size(axis);
3738 auto scales_data = scales.data_ptr<double>();
3739 auto zero_points_data = zero_points.data_ptr<int64_t>();
3740 const float* in = rtensor.const_data_ptr<float>();
3741 auto out = qtensor.data_ptr<T>();
3742 if (axis == 1 &&
3743 (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
3744 rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
3745 // This code handles per channel quant when axis = 1 and
3746 // channels_last contig.
3747 // If axis = 0 and channels_last contig, implementation for channels
3748 // first (NCHW) works.
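    // In channels-last layout the channel index is innermost, so the linear
    // offset is i = (b * elements_per_channel + e) * channels + c; in the NCHW
    // branch below it is i = (b * channels + c) * elements_per_channel + e.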
3749 for (const auto b : c10::irange(batches)) {
3750 for (const auto e : c10::irange(elements_per_channel)) {
3751 for (const auto c : c10::irange(channels)) {
3752 auto i = b * channels * elements_per_channel + e * channels + c;
3753 out[i] = at::native::quantize_val<T>(
3754 scales_data[c], zero_points_data[c], in[i]);
3755 }
3756 }
3757 }
3758 } else {
3759 for (const auto b : c10::irange(batches)) {
3760 for (const auto c : c10::irange(channels)) {
3761 for (const auto e : c10::irange(elements_per_channel)) {
3762 auto i = b * channels * elements_per_channel +
3763 c * elements_per_channel + e;
3764 out[i] = at::native::quantize_val<T>(
3765 scales_data[c], zero_points_data[c], in[i]);
3766 }
3767 }
3768 }
3769 }
3770 }
3771
3772 #if defined(__ARM_NEON__) || defined(__aarch64__)
3773 // Specialized implementation from caffe2::Int8Quantize.
// There may be a slight accuracy difference between this and the
// implementation of quantize_val.
3776 // TODO Update quantize_tensor_per_channel_impl implementation to follow
3777 // quantize_val, i.e. f = Round(value/scale + zero_point)
3778 // TODO Make quantize_tensor_per_channel_impl work for other datatypes too
3779 // (int8, int32).
3780 template <>
3781 void quantize_tensor_per_channel_impl<c10::quint8>(
3782 const Tensor& rtensor,
3783 Tensor& qtensor,
3784 const Tensor& scales,
3785 const Tensor& zero_points,
3786 int64_t axis) {
3787 int64_t batches = size_to_dim_(axis, rtensor.sizes());
3788 int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
3789 int64_t channels = rtensor.size(axis);
3790 auto scales_data = scales.data_ptr<double>();
3791 auto zero_points_data = zero_points.data_ptr<int64_t>();
3792 const float* in = rtensor.const_data_ptr<float>();
3793 auto out = (uint8_t*)qtensor.data_ptr<c10::quint8>();
3794 #if defined(__ARM_NEON__)
3795 // magic float and magic int to take care of rounding
3796 // int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000
3797 // Some detail:
3798 // 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you
3799 // add a small number to a large number, the result rounds to the precision of
  // the least significant bit of the large number. An IEEE-754 single-precision
  // number has a 23-bit mantissa, so after adding 2**23 the ulp is 1 and the
  // addition rounds to the nearest integer (ties to even). Then we cast to int
  // and subtract the same number (0x4B400000 is the integer representation of
  // 12582912.0f) to get only the mantissa. This works if -2**22 < x < 2**22,
  // and it preserves the sign for negative numbers.
3806 const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
3807 // Copy reciprocal of scales (double) into float array
3808 // Copy zero_points with magic int (int64_t) into int32_t array
3809 std::vector<float> inv_scales(channels);
3810 std::vector<int32_t> zero_points_int32t(channels);
3811 for (const auto i : c10::irange(channels)) {
3812 inv_scales[i] = 1.0f / (float)scales_data[i];
3813 zero_points_int32t[i] = (int32_t)(uint32_t)zero_points_data[i] - 0x4B400000;
3814 }
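  // Folding the magic constant into zero_points_int32t ahead of time lets the
  // vector loops below undo the magic bias and apply the per-channel zero
  // point with a single vaddq_s32 per four lanes.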
3815 if (axis == 1 &&
3816 (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
3817 rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
3818 // This code handles per channel quant when axis = 1 and
3819 // channels_last contig.
3820 // If axis = 0 and channels_last contig, implementation for channels
3821 // first (NCHW) works.
3822 for (C10_UNUSED const auto b : c10::irange(batches)) {
3823 for (C10_UNUSED const auto e : c10::irange(elements_per_channel)) {
3824 uint32_t c = 0;
3825 while (c + 8 < channels) {
3826 const int32x4_t voffset0123 = vld1q_s32(&zero_points_int32t[c]);
3827 const float32x4_t vinv_scale0123 = vld1q_f32(&inv_scales[c]);
3828 c += 4;
3829 const int32x4_t voffset4567 = vld1q_s32(&zero_points_int32t[c]);
3830 const float32x4_t vinv_scale4567 = vld1q_f32(&inv_scales[c]);
3831 c += 4;
3832 const float32x4_t vin0123 = vld1q_f32(in);
3833 in += 4;
3834 const float32x4_t vin4567 = vld1q_f32(in);
3835 in += 4;
3836 const int32x4_t vraw0123 = vaddq_s32(
3837 voffset0123,
3838 vreinterpretq_s32_f32(
3839 vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale0123))));
3840 const int32x4_t vraw4567 = vaddq_s32(
3841 voffset4567,
3842 vreinterpretq_s32_f32(
3843 vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale4567))));
3844 const int16x8_t vraw01234567 =
3845 vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
3846 const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
3847 vst1_u8(out, vout01234567);
3848 out += 8;
3849 }
3850 for (; c < channels; ++c) {
3851 (*out++) = at::native::quantize_val_arm<uint8_t>(
3852 scales_data[c], zero_points_data[c], (*in++));
3853 }
3854 }
3855 }
3856 } else {
3857 for (C10_UNUSED const auto b : c10::irange(batches)) {
3858 for (const auto c : c10::irange(channels)) {
3859 uint32_t e = 0;
3860 const int32x4_t voffset = vdupq_n_s32(zero_points_int32t[c]);
3861 const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]);
3862 for (; e + 8 < elements_per_channel; e += 8) {
3863 const float32x4_t vin0123 = vld1q_f32(in);
3864 in += 4;
3865 const float32x4_t vin4567 = vld1q_f32(in);
3866 in += 4;
3867 const int32x4_t vraw0123 = vaddq_s32(
3868 voffset,
3869 vreinterpretq_s32_f32(
3870 vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
3871 const int32x4_t vraw4567 = vaddq_s32(
3872 voffset,
3873 vreinterpretq_s32_f32(
3874 vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
3875 const int16x8_t vraw01234567 =
3876 vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
3877 const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
3878 vst1_u8(out, vout01234567);
3879 out += 8;
3880 }
3881 for (; e < elements_per_channel; ++e) {
3882 (*out++) = at::native::quantize_val_arm<uint8_t>(
3883 scales_data[c], zero_points_data[c], (*in++));
3884 }
3885 }
3886 }
3887 }
3888 #else // defined(__ARM_NEON__)
  // Copy reciprocal of scales (double) into float array
3890 // Copy zero_points (int64_t) into int16_t array
3891 std::vector<float> inv_scales(channels);
3892 std::vector<int16_t> zero_points_int16t(channels);
3893 for (const auto i : c10::irange(channels)) {
3894 inv_scales[i] = 1.0f / (float)scales_data[i];
3895 zero_points_int16t[i] = (int16_t)(uint16_t)zero_points_data[i];
3896 }
3897 if (axis == 1 &&
3898 (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
3899 rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
3900 // This code handles per channel quant when axis = 1 and
3901 // channels_last contig.
3902 // If axis = 0 and channels_last contig, implementation for channels
3903 // first (NCHW) works.
3904 for (const auto b C10_UNUSED : c10::irange(batches)) {
3905 for (const auto e C10_UNUSED : c10::irange(elements_per_channel)) {
3906 uint32_t c = 0;
3907 while (c + 8 < channels) {
3908 const int16x8_t vzero_point = vld1q_s16(&zero_points_int16t[c]);
3909 const float32x4_t vinv_scale0123 = vld1q_f32(&inv_scales[c]);
3910 c += 4;
3911 const float32x4_t vinv_scale4567 = vld1q_f32(&inv_scales[c]);
3912 c += 4;
3913 const float32x4_t vin0123 = vld1q_f32(in);
3914 in += 4;
3915 const float32x4_t vin4567 = vld1q_f32(in);
3916 in += 4;
3917 const int32x4_t v0123_rounded =
3918 vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale0123));
3919 const int32x4_t v4567_rounded =
3920 vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale4567));
3921 const int16x8_t v01234567_packed = vqaddq_s16(
3922 vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded),
3923 vzero_point);
3924 const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
3925 vst1_u8(out, vout01234567);
3926 out += 8;
3927 }
3928 for (; c < channels; ++c) {
3929 (*out++) = at::native::quantize_val_arm<uint8_t>(
3930 scales_data[c], zero_points_data[c], (*in++));
3931 }
3932 }
3933 }
3934 } else {
3935 for (const auto b C10_UNUSED : c10::irange(batches)) {
3936 for (const auto c C10_UNUSED : c10::irange(channels)) {
3937 uint32_t e = 0;
3938 const int16x8_t vzero_point = vdupq_n_s16(zero_points_int16t[c]);
3939 const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]);
3940 for (; e + 8 < elements_per_channel; e += 8) {
3941 const float32x4_t vin0123 = vld1q_f32(in);
3942 in += 4;
3943 const float32x4_t vin4567 = vld1q_f32(in);
3944 in += 4;
3945 const int32x4_t v0123_rounded =
3946 vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale));
3947 const int32x4_t v4567_rounded =
3948 vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale));
3949 const int16x8_t v01234567_packed = vqaddq_s16(
3950 vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded),
3951 vzero_point);
3952 const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
3953 vst1_u8(out, vout01234567);
3954 out += 8;
3955 }
3956 for (; e < elements_per_channel; ++e) {
3957 (*out++) = at::native::quantize_val_arm<uint8_t>(
3958 scales_data[c], zero_points_data[c], (*in++));
3959 }
3960 }
3961 }
3962 }
3963 #endif // defined(__ARM_NEON__)
3964 }
3965 #endif // defined(__ARM_NEON__) || defined(__aarch64__)
3966
3967 void quantize_tensor_per_channel_affine_cpu(
3968 const Tensor& rtensor,
3969 Tensor& qtensor,
3970 const Tensor& scales,
3971 const Tensor& zero_points,
3972 int64_t axis) {
3973 TORCH_CHECK(
3974 rtensor.is_contiguous() || (axis <= 1),
3975 "If tensor is channels_last contig then per channel quantization "
3976 "is supported only for axis = 0 or 1.");
3977 AT_DISPATCH_QINT_TYPES(
3978 qtensor.scalar_type(), "quantize_tensor_per_channel_affine_cpu", [&]() {
3979 check_tensor_memory_format(rtensor, qtensor);
3980 quantize_tensor_per_channel_impl<scalar_t>(
3981 rtensor, qtensor, scales, zero_points, axis);
3982 });
3983 }
3984
3985 template<typename T, typename N, typename Q>
3986 void dequantize_per_channel_affine_kernel(
3987 const Tensor& qtensor,
3988 Tensor& rtensor,
3989 const Tensor& scales,
3990 const Tensor& zero_points,
3991 int64_t axis,
3992 int bit_width=8) {
3993
  // For contiguous tensors, e.g. NCHW, an arbitrary axis can be used.
  // For channels_last/3d, however, only axis == 0 or 1 is supported.
  // Since the current channels_last implementation does not cover
  // per-channel quant with an arbitrary axis value, it is better
  // to check and fail.
3999 TORCH_CHECK(rtensor.is_contiguous() || (axis <=1),
4000 "If tensor is channels_last contig then per channel quantization "
4001 "is supported only for axis = 0 or 1.");
4002 int64_t batches = size_to_dim_(axis, rtensor.sizes());
4003 int64_t elements_per_channel =
4004 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
4005 size_from_dim_(axis + 1, rtensor.sizes());
4006 int64_t channel = rtensor.size(axis);
4007 auto scales_data = scales.data_ptr<T>();
4008 auto zero_points_data = zero_points.data_ptr<N>();
4009 check_tensor_memory_format(qtensor, rtensor);
4010 const auto* qd = qtensor.const_data_ptr<Q>();
4011 float* rd = rtensor.data_ptr<float>();
4012 const auto elem_per_byte = 8 / bit_width;
4013 if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
4014 rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
4015 for (const auto b : c10::irange(batches)) {
4016 for (const auto e : c10::irange(elements_per_channel)) {
4017 for (const auto c : c10::irange(channel)) {
4018 auto i = b * channel * elements_per_channel + e * channel + c;
4019 // We need to convert the qint8 value to float to ensure the
4020 // subtraction subexpression returns a float
4021 auto qvalue = qd[i / elem_per_byte].val_;
4022 if (bit_width < 8) {
4023 qvalue >>= (i % elem_per_byte) * bit_width;
4024 qvalue &= (1 << bit_width) - 1;
4025 }
4026 rd[i] = (static_cast<float>(qvalue) - zero_points_data[c]) * scales_data[c];
4027 }
4028 }
4029 }
4030 } else {
4031 for (const auto b : c10::irange(batches)) {
4032 for (const auto c : c10::irange(channel)) {
4033 for (const auto e : c10::irange(elements_per_channel)) {
4034 auto i = b * channel * elements_per_channel +
4035 c * elements_per_channel + e;
4036 // We need to convert the qint8 value to float to ensure the
4037 // subtraction subexpression returns a float
4038 auto qvalue = qd[i / elem_per_byte].val_;
4039 if (bit_width < 8) {
4040 qvalue >>= (i % elem_per_byte) * bit_width;
4041 qvalue &= (1 << bit_width) - 1;
4042 }
4043 rd[i] = (static_cast<float>(qvalue) - zero_points_data[c]) * scales_data[c];
4044 }
4045 }
4046 }
4047 }
4048 }
4049
4050 void dequantize_tensor_per_channel_affine_cpu(
4051 const Tensor& qtensor,
4052 Tensor& rtensor,
4053 const Tensor& scales,
4054 const Tensor& zero_points,
4055 int64_t axis) {
4056 AT_DISPATCH_QINT_TYPES(
4057 qtensor.scalar_type(), "dequantize_tensor_per_channel_affine_cpu", [&]() {
4058 dequantize_per_channel_affine_kernel<double, int64_t, scalar_t>(qtensor, rtensor, scales, zero_points, axis);
4059 });
4060 }
4061
4062 // quantize stubs for floating point scale and zero_point.
4063 void quantize_tensor_per_channel_float_qparams_cpu(
4064 const Tensor& rtensor,
4065 Tensor& qtensor,
4066 const Tensor& scales,
4067 const Tensor& zero_points,
4068 int64_t axis) {
  // For contiguous tensors, e.g. NCHW, an arbitrary axis can be used.
  // For channels_last/3d, however, only axis == 0 or 1 is supported.
  // Since the current channels_last implementation does not cover
  // per-channel quant with an arbitrary axis value, it is better
  // to check and fail.
4074 TORCH_CHECK(rtensor.is_contiguous() || (axis <=1),
4075 "If tensor is channels_last contig then per channel quantization "
4076 "is supported only for axis = 0 or 1.");
4077 AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
4078 qtensor.scalar_type(), "quantize_tensor_per_channel_float_qparams_cpu", [&]() {
4079 int64_t batches = size_to_dim_(axis, rtensor.sizes());
4080 int64_t elements_per_channel =
4081 size_from_dim_(axis + 1, rtensor.sizes());
4082 int64_t channel = rtensor.size(axis);
4083 auto scales_data = scales.data_ptr<float>();
4084 auto zero_points_data = zero_points.data_ptr<float>();
4085 check_tensor_memory_format(rtensor, qtensor);
4086 const float* rdata = rtensor.const_data_ptr<float>();
4087 auto qdata = reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
4088 const auto elem_per_byte = CHAR_BIT / bit_width;
4089 int qvalue = 0;
4090 if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
4091 rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
4092 for (const auto b : c10::irange(batches)) {
4093 for (const auto e : c10::irange(elements_per_channel)) {
4094 for (const auto c : c10::irange(channel)) {
4095 auto i = b * channel * elements_per_channel + e * channel + c;
4096 qvalue = quantize_val_float_qparams(
4097 scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max);
4098 if (i % elem_per_byte == 0) {
4099 qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
4100 } else {
4101 qdata[i / elem_per_byte] |= static_cast<underlying_t>(qvalue << ((i % elem_per_byte) * bit_width));
4102 }
4103 }
4104 }
4105 }
4106 } else {
4107 for (const auto b : c10::irange(batches)) {
4108 for (const auto c : c10::irange(channel)) {
4109 for (const auto e : c10::irange(elements_per_channel)) {
4110 auto i = b * channel * elements_per_channel +
4111 c * elements_per_channel + e;
4112 qvalue = quantize_val_float_qparams(
4113 scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max);
4114 if (i % elem_per_byte == 0) {
4115 qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
4116 } else {
4117 qdata[i / elem_per_byte] |= static_cast<underlying_t>(qvalue << ((i % elem_per_byte) * bit_width));
4118 }
4119 }
4120 }
4121 }
4122 }
4123 });
4124 }
4125
4126 void dequantize_tensor_per_channel_float_qparams_cpu(
4127 const Tensor& qtensor,
4128 Tensor& rtensor,
4129 const Tensor& scales,
4130 const Tensor& zero_points,
4131 int64_t axis) {
4132 AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
4133 qtensor.scalar_type(), "dequantize_tensor_per_channel_float_qparams_cpu", [&]() {
4134 dequantize_per_channel_affine_kernel<float, float, scalar_t>(qtensor, rtensor, scales, zero_points, axis, bit_width);
4135 });
4136 }
4137
4138 void quantize_tensor_per_tensor_affine_sub_byte_cpu(
4139 const Tensor& rtensor,
4140 Tensor& qtensor,
4141 float scale,
4142 float zero_point) {
4143 // TODO Use fbgemm kernel to pack values
4144 AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
4145 qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_sub_byte_cpu", [&]() {
4146 check_tensor_memory_format(rtensor, qtensor);
4147 const float* const rdata = rtensor.const_data_ptr<float>();
4148 auto qdata = reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
4149 auto numel = rtensor.numel();
4150 const auto elem_per_byte = CHAR_BIT / bit_width;
4151 for (const auto i : c10::irange(numel)) {
4152 float inv_scale = scale == 0 ? 1.0f : 1.0f / scale;
4153 int64_t qvalue = lrintf(std::nearbyint(rdata[i] * inv_scale) + zero_point);
4154 qvalue = std::max(quant_min, std::min(qvalue, quant_max));
4155
          // We pack sub-byte values and align them to a byte.
          // E.g. for 4 bits, index 0 is packed into the lower 4 bits
          // and index 1 into the upper 4 bits of the same byte.
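          // For example, with bit_width = 4 (elem_per_byte = 2): qvalue 0x3 at
          // i = 0 initializes the byte to 0x03, and qvalue 0xA at i = 1 is
          // OR-ed in as 0xA << 4, leaving qdata[0] == 0xA3.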
4159 if (i % elem_per_byte == 0) {
4160 qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
4161 } else {
4162 qdata[i / elem_per_byte] |= static_cast<underlying_t>(qvalue << ((i % elem_per_byte) * bit_width));
4163 }
4164 } // for numel
4165 });
4166 }
4167
4168 void dequantize_tensor_per_tensor_affine_sub_byte_cpu(
4169 const Tensor& qtensor,
4170 Tensor& rtensor,
4171 float scale,
4172 float zero_point) {
4173 // TODO Use fbgemm kernel to pack values
4174 AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
4175 qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_sub_byte_cpu", [&]() {
4176 check_tensor_memory_format(rtensor, qtensor);
4177 auto rdata = rtensor.data_ptr<float>();
4178 const underlying_t* qdata = reinterpret_cast<const underlying_t*>(qtensor.const_data_ptr<scalar_t>());
4179 auto numel = rtensor.numel();
4180 const auto elem_per_byte = CHAR_BIT / bit_width;
4181
4182 for (const auto i : c10::irange(numel)) {
4183 underlying_t qvalue = qdata[i / elem_per_byte];
4184 qvalue >>= (i % elem_per_byte) * bit_width;
4185 qvalue &= (1 << bit_width) - 1;
4186 rdata[i] = (static_cast<float>(qvalue) - zero_point) * scale;
4187 }
4188 });
4189 }
4190
4191 // This function expects quantized_val input to already be quantized
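// The TensorIterator loop below receives raw byte pointers and strides:
// data[0]/strides[0] address the quantized destination and data[1]/strides[1]
// the boolean mask; every element whose mask is true is overwritten with
// quantized_val.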
4192 template <typename scalar_t>
4193 void cpu_masked_fill_kernel_quantized_cpu(TensorIterator& iter, scalar_t quantized_val) {
4194 auto loop = [&](char** data, const int64_t* strides, int64_t n) {
4195 char* dst = data[0];
4196 char* mask = data[1];
4197 for (const auto i : c10::irange(n)) {
4198 bool mask_value = *reinterpret_cast<bool*>(mask + strides[1] * i);
4199
4200 if (mask_value) {
4201 *(scalar_t*)(dst + strides[0] * i) = quantized_val;
4202 }
4203 }
4204 };
4205 iter.for_each(loop);
4206 }
4207
4208 void masked_fill_kernel_quantized_cpu(TensorIterator& iter, const Scalar& value, double scale, int zero_point) {
4209 AT_DISPATCH_QINT_TYPES(iter.dtype(), "masked_fill", [&] {
4210 float float_val = value.to<float>();
4211 auto quantized_val = quantize_val<scalar_t>(scale, zero_point, float_val);
4212 auto mask_dtype = iter.input_dtype(0);
4213 TORCH_CHECK(mask_dtype == ScalarType::Bool, "masked_fill only supports boolean masks, "
4214 "but got mask with dtype ", mask_dtype);
4215 cpu_masked_fill_kernel_quantized_cpu<scalar_t>(iter, quantized_val);
4216 });
4217 }
4218
// Currently, we do not support accumulate=True for quantized tensors; an
// exception is thrown in _index_put_impl_quantized_cpu_.
4220 void index_put_kernel_quantized_cpu(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate, double scale, int zero_point) {
4221 // NOTE: duplicate indices are only supported if accumulate is true.
4222 AT_DISPATCH_QINT_TYPES(iter.dtype(), "index_put", [&] {
4223 // See Note [Enabling Deterministic Operations]
4224 // Parallel cpu_index_kernel with accumulation is nondeterministic, so we
4225 // must enable serial execution if deterministic algorithms are enabled.
4226 const bool is_deterministic = at::globalContext().deterministicAlgorithms();
4227 at::native::cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [scale, zero_point](char* dst, char* src, int64_t offset) {
4228 *(scalar_t*)(dst + offset) = quantize_val<scalar_t>(scale, zero_point, *(float*)src);
4229 }, /*serial_execution=*/is_deterministic);
4230 });
4231 }
4232 } // anonymous namespace
4233
4234 // Some quantization tests are flaky on Windows with AVX512. If --continue-through-error
4235 // is used, only one fails. But if the failing test is skipped, another one fails.
4236 // If the second test is also skipped, a third one fails.
4237 // So, until Quantization support for Windows is fixed for AVX512,
// AVX2 kernels are used instead. Ref: GH 56992.
4239 #if defined(_WIN32)
4240 REGISTER_DISPATCH(dequantize_tensor_per_channel_affine_stub,
4241 &dequantize_tensor_per_channel_affine_cpu);
4242 REGISTER_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub,
4243 &dequantize_tensor_per_channel_float_qparams_cpu);
4244 REGISTER_DISPATCH(fake_quant_per_channel_cachemask_stub,
4245 &fake_quant_per_channel_cachemask_cpu);
4246 REGISTER_DISPATCH(qavg_pool2d_nhwc_stub, &qavg_pool2d_nhwc_kernel);
4247 REGISTER_DISPATCH(qavg_pool3d_nhwc_stub, &qavg_pool3d_nhwc_kernel);
4248 #else
4249 // These kernels are dispatched to AVX512
4250 ALSO_REGISTER_AVX512_DISPATCH(dequantize_tensor_per_channel_affine_stub,
4251 &dequantize_tensor_per_channel_affine_cpu);
4252 ALSO_REGISTER_AVX512_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub,
4253 &dequantize_tensor_per_channel_float_qparams_cpu);
4254 ALSO_REGISTER_AVX512_DISPATCH(fake_quant_per_channel_cachemask_stub,
4255 &fake_quant_per_channel_cachemask_cpu);
4256 ALSO_REGISTER_AVX512_DISPATCH(qavg_pool2d_nhwc_stub, &qavg_pool2d_nhwc_kernel);
4257 ALSO_REGISTER_AVX512_DISPATCH(qavg_pool3d_nhwc_stub, &qavg_pool3d_nhwc_kernel);
#endif // defined(_WIN32)
4259
4260 // The kernels below are dispatched to AVX2 because they don't perform as well
4261 // with AVX512. We might revisit this decision in the near future.
4262 REGISTER_DISPATCH(dequantize_tensor_per_tensor_affine_stub,
4263 &dequantize_tensor_per_tensor_affine_cpu);
4264 REGISTER_DISPATCH(fake_quant_grad_learnable_tensor_stub,
4265 &fake_quantize_learnable_tensor_grad_kernel_cpu);
4266 REGISTER_DISPATCH(fake_quant_tensor_cachemask_stub,
4267 &fake_quantize_tensor_cachemask_kernel);
4268 REGISTER_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_stub,
4269 &fake_quantize_tensor_cachemask_tensor_qparams_kernel);
4270 REGISTER_DISPATCH(qadaptive_avg_pool2d_nhwc_stub,
4271 &qadaptive_avg_pool2d_nhwc_kernel);
4272 REGISTER_DISPATCH(qadaptive_avg_pool3d_ndhwc_stub,
4273 &qadaptive_avg_pool3d_ndhwc_kernel);
4274 REGISTER_DISPATCH(qadd_relu_stub, &qadd_kernel<true>);
4275 REGISTER_DISPATCH(qadd_scalar_relu_stub, &qadd_scalar_kernel<true>);
4276 REGISTER_DISPATCH(qadd_scalar_stub, &qadd_scalar_kernel<false>);
4277 REGISTER_DISPATCH(qadd_stub, &qadd_kernel<false>);
4278
4279 REGISTER_DISPATCH(qbatch_norm_relu_stub, &q_batch_norm_kernel<true>);
4280 REGISTER_DISPATCH(qbatch_norm_stub, &q_batch_norm_kernel<false>);
4281 REGISTER_DISPATCH(qcat_nhwc_stub, &qcat_nhwc_kernel<false>);
4282 REGISTER_DISPATCH(qcat_relu_nhwc_stub, &qcat_nhwc_kernel<true>);
4283 REGISTER_DISPATCH(qclamp_stub, &qclamp_kernel);
4284 REGISTER_DISPATCH(qclamp_min_stub, &qclamp_min_kernel);
4285 REGISTER_DISPATCH(qclamp_max_stub, &qclamp_max_kernel);
4286 REGISTER_DISPATCH(qelu_stub, &qelu_kernel);
4287 REGISTER_DISPATCH(qhardsigmoid_stub, &qhardsigmoid_kernel);
4288 REGISTER_DISPATCH(qhardswish_stub, &qhardswish_kernel);
4289 REGISTER_DISPATCH(qmaxpool_2d_nhwc_stub, &qmaxpool_2d_nhwc_kernel);
4290 REGISTER_DISPATCH(qmaxpool_3d_nthwc_stub, &qmaxpool_3d_nthwc_kernel);
4291 REGISTER_DISPATCH(qmul_relu_stub, &qmul_kernel<true>);
4292 REGISTER_DISPATCH(qmul_stub, &qmul_kernel<false>);
4293 REGISTER_DISPATCH(qrelu_leaky_stub, &leaky_qrelu_out_kernel);
4294 REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel);
4295 REGISTER_DISPATCH(qprelu_stub, &qprelu_out_kernel);
4296 REGISTER_DISPATCH(qgelu_stub, &qgelu_kernel);
4297 REGISTER_DISPATCH(qsigmoid_stub, &qsigmoid_kernel);
4298 REGISTER_DISPATCH(qtanh_stub, &qtanh_kernel);
4299 REGISTER_DISPATCH(qthreshold_stub, &qthreshold_kernel);
4300 REGISTER_DISPATCH(qtopk_stub, &qtopk_kernel);
4301 REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub,
4302 &fake_quantize_learnable_channel_grad_kernel_cpu);
4303 REGISTER_DISPATCH(
4304 quantize_tensor_per_tensor_affine_stub,
4305 &quantize_tensor_per_tensor_affine_cpu);
4306 REGISTER_DISPATCH(
4307 quantize_tensor_per_channel_affine_stub,
4308 &quantize_tensor_per_channel_affine_cpu);
4309 REGISTER_DISPATCH(
4310 quantize_tensor_per_channel_float_qparams_stub,
4311 &quantize_tensor_per_channel_float_qparams_cpu);
4312 REGISTER_DISPATCH(quantized_normalize_stub, &quantized_normalize_kernel);
4313 REGISTER_DISPATCH(quantized_groupnorm_nhwc_stub, &quantized_groupnorm_nhwc_kernel);
4314 REGISTER_DISPATCH(qupsample_bilinear2d_nhwc_stub,
4315 &qupsample_bilinear2d_nhwc_kernel);
4316 REGISTER_DISPATCH(
4317 quantize_tensor_per_tensor_affine_sub_byte_stub,
4318 &quantize_tensor_per_tensor_affine_sub_byte_cpu);
4319 REGISTER_DISPATCH(
4320 dequantize_tensor_per_tensor_affine_sub_byte_stub,
4321 &dequantize_tensor_per_tensor_affine_sub_byte_cpu);
4322 REGISTER_DISPATCH(
4323 masked_fill_kernel_quantized_stub,
4324 &masked_fill_kernel_quantized_cpu);
4325 REGISTER_DISPATCH(
4326 index_put_kernel_quantized_stub,
4327 &index_put_kernel_quantized_cpu);
4328 REGISTER_DISPATCH(qmean_inner_dim_stub, &qmean_inner_dim_kernel);
4329 REGISTER_DISPATCH(qstd_inner_dim_stub, &qstd_inner_dim_kernel);
4330 } // namespace at::native
4331 // NOLINTEND(*-c-arrays)
4332