xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/core/List.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/Parallel.h>
6 #include <ATen/native/Activation.h>
7 #include <ATen/native/TopKImpl.h>
8 #include <ATen/native/TensorIterator.h>
9 #include <ATen/native/UpSample.h>
10 #include <ATen/native/cpu/IndexKernelUtils.h>
11 #include <ATen/native/cpu/Loops.h>
12 #include <ATen/native/quantized/AffineQuantizer.h>
13 #include <ATen/native/quantized/FakeQuantAffine.h>
14 #include <ATen/native/quantized/IndexKernel.h>
15 #include <ATen/native/quantized/cpu/QuantizedOps.h>
16 #include <ATen/native/cpu/utils.h>
17 #include <c10/util/irange.h>
18 #include <ATen/native/cpu/utils.h>
19 
20 #ifndef AT_PER_OPERATOR_HEADERS
21 #include <ATen/Functions.h>
22 #else
23 #include <ATen/ops/_empty_affine_quantized.h>
24 #include <ATen/ops/empty.h>
25 #endif
26 
27 #include <cmath>
28 #ifdef USE_FBGEMM
29 #include <fbgemm/QuantUtils.h>
30 #endif
31 #ifdef _OPENMP
32 #include <omp.h>
33 #endif
34 #if defined(__ARM_NEON__) || defined(__aarch64__)
35 #include <ATen/quantized/Quantizer.h>
36 #include <arm_neon.h>
37 #endif
38 
39 
40 // NOLINTBEGIN(*-c-arrays)
41 namespace at::native {
42 namespace {
43 
44 void check_tensor_memory_format(const Tensor& ref, const Tensor& other) {
45   TORCH_CHECK(
46       ref.is_contiguous(ref.suggest_memory_format()),
47       "Quantized tensor should be contiguous");
48   TORCH_CHECK(
49       other.is_contiguous(ref.suggest_memory_format()),
50       "Float tensor should be contiguous "
51       "in same memory format as quantized tensor");
52 }
53 
54 // ****************** HEY YOU! YES YOU! Read this! ********************
55 //
56 // Please read the README.md in this directory before editing this file
57 
58 template <bool ReLUFused = false>
59 Tensor qcat_nhwc_kernel(
60     const MaterializedITensorListRef& qxs,
61     int64_t dim,
62     double scale,
63     int64_t zero_point) {
64   const at::Tensor& qx0 = qxs[0];
65   int64_t C_out = 0;
66   std::vector<int64_t> Cs_in;
67   // Prefix sum of input channels for fast indexing
68   std::vector<int64_t> Cs_sum;
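  // e.g. for inputs with channel counts {3, 5, 4}, Cs_sum is {0, 3, 8}: the
  // output offset for tensor t at flattened NHW index i is i * C_out + Cs_sum[t].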
69   std::vector<double> scales;
70   std::vector<int64_t> zero_pts;
71   std::vector<void*> data_ptrs;
72   std::vector<bool> is_fast_path;
73 
74   for (const at::Tensor& qx : qxs) {
75     TORCH_CHECK(
76         qx.dim() == qx0.dim(),
77         "Tensors must have the same number of dimensions: got ",
78         qx.dim(),
79         " and ",
80         qx0.dim());
81 #define CHECK_DIM(d)                                            \
82   TORCH_CHECK(                                                  \
83       qx.size(d) == qx0.size(d),                                \
84       "Sizes of tensors must match except in dimension 1. Got ", \
85       qx.size(d),                                               \
86       " and ",                                                  \
87       qx0.size(d));
88     CHECK_DIM(0);
89     CHECK_DIM(2);
90     CHECK_DIM(3);
91     TORCH_CHECK(
92         qx.scalar_type() == qx0.scalar_type(),
93         "Expected object of scalar type ",
94         toString(qx0.scalar_type()),
95         " but got scalar type ",
96         toString(qx.scalar_type()));
97     Cs_in.push_back(qx.size(1));
98     Cs_sum.push_back(C_out);
99     C_out += qx.size(1);
100     scales.push_back(qx.q_scale());
101     zero_pts.push_back(qx.q_zero_point());
102     data_ptrs.push_back(qx.data_ptr());
103     is_fast_path.push_back(
104         qx.q_scale() == scale &&
105         qx.q_zero_point() == zero_point);
106   }
107 
108   const int64_t N = qx0.size(0);
109   const int64_t H = qx0.size(2);
110   const int64_t W = qx0.size(3);
111   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
112   float inv_scale = 1.0 / scale;
113 
114   auto output = at::_empty_affine_quantized(
115       {N, C_out, H, W},
116       qx0.options().memory_format(MemoryFormat::ChannelsLast),
117       scale,
118       zero_point,
119       std::nullopt);
120 
121   // N, H, and W are explicitly captured here because there's a bug in GCC5
122   // and clang5 which causes an internal compiler error if they're not
123   AT_DISPATCH_QINT_TYPES(output.scalar_type(), "qcat_nhwc", [&, N, H, W]() {
124     using Vec = Vectorized<scalar_t>;
125     at::parallel_for(0, N * H * W, 0, [&](int64_t begin, int64_t end) {
126       for (const auto i : c10::irange(begin, end)) {
127         // loop over input tensors
128         for (const auto tidx : c10::irange(Cs_in.size())) {
129           scalar_t::underlying* optr =
130               reinterpret_cast<scalar_t::underlying*>(output.data_ptr()) +
131               i * C_out + Cs_sum[tidx];
132 
133           auto curr_C = Cs_in[tidx];
134           float curr_scale = scales[tidx];
135           int64_t curr_zero_pt = zero_pts[tidx];
136 
137           scalar_t::underlying* iptr =
138               reinterpret_cast<scalar_t::underlying*>(data_ptrs[tidx]) +
139               i * curr_C;
140 
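          // Fast path: this input already uses the output's scale and zero
          // point, so (absent a fused ReLU) its quantized values are already
          // the desired output values and can be copied verbatim.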
141           if (is_fast_path[tidx] && !ReLUFused) {
142             std::memcpy(optr, iptr, curr_C * sizeof(typename scalar_t::underlying));
143             continue;
144           }
145 
146           constexpr auto VLEN = Vec::size();
147           int64_t c = 0;
148 
149           // Vectorized loop
150           if (c + VLEN <= curr_C) {
151             auto curr_scale_vec = Vectorized<float>(curr_scale);
152             auto curr_zero_pt_vec = Vectorized<float>((float)curr_zero_pt);
153             auto scale_neg_zp_premul = curr_scale_vec * curr_zero_pt_vec.neg();
154             for (; c + VLEN <= curr_C; c += VLEN) {
155               auto inp_vec = Vec::loadu(iptr + c);
156               auto float_values = inp_vec.dequantize(
157                   curr_scale_vec, curr_zero_pt_vec, scale_neg_zp_premul);
158               Vec::float_vec_return_type retvals;
159               for (int i = 0; i < Vec::float_num_vecs(); ++i) {
160                 if constexpr (ReLUFused) {
161                   retvals[i] =
162                       vec::maximum(float_values[i], Vectorized<float>(0.0f));
163                 } else {
164                   retvals[i] = float_values[i];
165                 }
166               }
167               auto quantized =
168                   Vec::quantize(retvals, scale, zero_point, inv_scale);
169               quantized.store(optr + c);
170             }
171           }
172 
173           // Vectorized tail loop for a remaining channel count between 8 and 32 (avx2)
174           constexpr auto kVLEN = Vectorized<float>::size();
175           int64_t elem_size = curr_C - c;
176           if ((VLEN == 4 * kVLEN) && elem_size >= kVLEN) {
177             auto curr_scale_vec = Vectorized<float>(curr_scale);
178             auto curr_zero_pt_vec = Vectorized<float>((float)curr_zero_pt);
179             auto scale_neg_zp_premul = curr_scale_vec * curr_zero_pt_vec.neg();
180             int64_t vec_num = elem_size / kVLEN;
181             std::array<typename scalar_t::underlying, VLEN> buf_in{};
182             memcpy(buf_in.data(), iptr + c, vec_num * kVLEN);
183             auto inp_vec = Vec::loadu(buf_in.data());
184             auto float_values = inp_vec.dequantize(
185                 curr_scale_vec, curr_zero_pt_vec, scale_neg_zp_premul);
186             Vec::float_vec_return_type retvals;
187             for (int i = 0; i < vec_num; ++i) {
188               if constexpr (ReLUFused) {
189                 retvals[i] =
190                     vec::maximum(float_values[i], Vectorized<float>(0.0f));
191               } else {
192                 retvals[i] = float_values[i];
193               }
194             }
195             auto quantized =
196                 Vec::quantize(retvals, scale, zero_point, inv_scale);
197             quantized.store(optr + c, vec_num * kVLEN);
198             c += vec_num * kVLEN;
199           }
200 
201           // Scalar loop
202           for (; c < curr_C; ++c) {
203             auto float_val = at::native::dequantize_val(
204                 curr_scale,
205                 curr_zero_pt,
206                 reinterpret_cast<scalar_t*>(iptr)[c]);
207             if constexpr (ReLUFused) {
208               float_val = std::max(0.0f, float_val);
209             }
210             optr[c] = at::native::quantize_val<scalar_t>(
211                           scale, zero_point, float_val)
212                           .val_;
213           } // for c
214         } // for tidx
215       } // for i
216     });
217   });
218 
219   return output;
220 }
221 
222 // horizontal sum over a range of uint8_t
223 int64_t hsum(const uint8_t* A, int len) {
224   int64_t row_sum = 0;
225   int i = 0;
226 
227 #ifdef CPU_CAPABILITY_AVX2
228   __m256i sum_v = _mm256_setzero_si256();
229   __m256i one_epi16_v = _mm256_set1_epi16(1);
230   __m256i one_epi8_v = _mm256_set1_epi8(1);
231   // vectorized
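  // Each iteration folds 32 bytes into 8 int32 partial sums: maddubs(src, 1)
  // adds adjacent pairs of bytes into int16 lanes, and madd(., 1) then adds
  // adjacent int16 pairs into int32 lanes.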
232   for (; i < len / 32 * 32; i += 32) {
233     __m256i src_v = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
234     sum_v = _mm256_add_epi32(
235       sum_v,
236       _mm256_madd_epi16(
237         // first argument is unsigned, second is signed
238         _mm256_maddubs_epi16(src_v, one_epi8_v),
239       one_epi16_v)
240     );
241   }
242 
243   alignas(64) int32_t temp[8];
244   _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v);
245   for (const auto k : c10::irange(8)) {
246     row_sum += temp[k];
247   }
248 #elif defined(CPU_CAPABILITY_AVX512)
249   __m512i sum_v = _mm512_setzero_si512();
250   __m512i one_epi16_v = _mm512_set1_epi16(1);
251   __m512i one_epi8_v = _mm512_set1_epi8(1);
252   // vectorized
253   for (; i < len / 64 * 64; i += 64) {
254     __m512i src_v = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i));
255     sum_v = _mm512_add_epi32(
256       sum_v,
257       _mm512_madd_epi16(
258         // first argument is unsigned, second is signed
259         _mm512_maddubs_epi16(src_v, one_epi8_v),
260       one_epi16_v)
261     );
262   }
263 
264   alignas(64) int32_t temp[16];
265   _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v);
266   for (const auto k : c10::irange(16)) {
267     row_sum += temp[k];
268   }
269 #endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
270 
271   // scalar
272   for (; i < len; ++i) {
273     row_sum += A[i];
274   }
275 
276   return row_sum;
277 }
278 
279 // horizontal sum over a range of int8_t
280 int64_t hsum(const int8_t* A, int len) {
281   int64_t row_sum = 0;
282   int i = 0;
283 
284 #ifdef CPU_CAPABILITY_AVX2
285   __m256i sum_v = _mm256_setzero_si256();
286   __m256i one_epi16_v = _mm256_set1_epi16(1);
287   __m256i one_epi8_v = _mm256_set1_epi8(1);
288   // vectorized
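  // maddubs treats its first operand as unsigned and its second as signed, so
  // for int8 data the all-ones vector is passed first and the signed source
  // second; the result is still the pairwise sum of the source bytes.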
289   for (; i < len / 32 * 32; i += 32) {
290     __m256i src_v = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
291     sum_v = _mm256_add_epi32(
292       sum_v,
293       _mm256_madd_epi16(
294         // first argument is unsigned, second is signed
295         _mm256_maddubs_epi16(one_epi8_v, src_v),
296       one_epi16_v)
297     );
298   }
299 
300   alignas(64) int32_t temp[8];
301   _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v);
302   for (const auto k : c10::irange(8)) {
303     row_sum += temp[k];
304   }
305 #elif defined(CPU_CAPABILITY_AVX512)
306   __m512i sum_v = _mm512_setzero_si512();
307   __m512i one_epi16_v = _mm512_set1_epi16(1);
308   __m512i one_epi8_v = _mm512_set1_epi8(1);
309   // vectorized
310   for (; i < len / 64 * 64; i += 64) {
311     __m512i src_v = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i));
312     sum_v = _mm512_add_epi32(
313       sum_v,
314       _mm512_madd_epi16(
315         // first argument is unsigned, second is signed
316         _mm512_maddubs_epi16(one_epi8_v, src_v),
317       one_epi16_v)
318     );
319   }
320 
321   alignas(64) int32_t temp[16];
322   _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v);
323   for (const auto k : c10::irange(16)) {
324     row_sum += temp[k];
325   }
326 #endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
327 
328   // scalar
329   for (; i < len; ++i) {
330     row_sum += A[i];
331   }
332 
333   return row_sum;
334 }
335 
336 // horizontal sum over a range of int32_t
337 int64_t hsum(const int32_t* A, int len) {
338   int64_t row_sum = 0;
339   int i = 0;
340 
341 #ifdef CPU_CAPABILITY_AVX2
342   __m256i sum_epi64 = _mm256_setzero_si256();
343   // vectorized
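  // Each 256-bit load is split into halves that are sign-extended to int64
  // lanes before accumulating, so the running sum cannot overflow 32 bits.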
344   for (; i < len / 8 * 8; i += 8) {
345     __m256i src_epi32 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
346     // widen
347     __m128i src_lo_epi32 = _mm256_castsi256_si128(src_epi32);
348     __m128i src_hi_epi32 = _mm256_extracti128_si256(src_epi32, 1);
349     __m256i src_lo_epi64 = _mm256_cvtepi32_epi64(src_lo_epi32);
350     __m256i src_hi_epi64 = _mm256_cvtepi32_epi64(src_hi_epi32);
351     // add
352     sum_epi64 = _mm256_add_epi64(sum_epi64, src_lo_epi64);
353     sum_epi64 = _mm256_add_epi64(sum_epi64, src_hi_epi64);
354   }
355 
356   alignas(64) int64_t temp[4];
357   _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_epi64);
358   for (const auto k : c10::irange(4)) {
359     row_sum += temp[k];
360   }
361 #elif defined(CPU_CAPABILITY_AVX512)
362   __m512i sum_epi64 = _mm512_setzero_si512();
363   // vectorized
364   for (; i < len / 16 * 16; i += 16) {
365     __m512i src_epi32 = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i));
366     // widen
367     __m256i src_lo_epi32 = _mm512_castsi512_si256(src_epi32);
368     __m256i src_hi_epi32 = _mm512_extracti32x8_epi32(src_epi32, 1);
369     __m512i src_lo_epi64 = _mm512_cvtepi32_epi64(src_lo_epi32);
370     __m512i src_hi_epi64 = _mm512_cvtepi32_epi64(src_hi_epi32);
371     // add
372     sum_epi64 = _mm512_add_epi64(sum_epi64, src_lo_epi64);
373     sum_epi64 = _mm512_add_epi64(sum_epi64, src_hi_epi64);
374   }
375 
376   alignas(64) int64_t temp[8];
377   _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_epi64);
378   for (const auto k : c10::irange(8)) {
379     row_sum += temp[k];
380   }
381 #endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
382 
383   // scalar
384   for (; i < len; ++i) {
385     row_sum += A[i];
386   }
387 
388   return row_sum;
389 }
390 
391 // horizontal sum of squares over a range of uint8_t
392 int64_t hsum_sq(const uint8_t* A, int len) {
393   int64_t row_sum = 0;
394   int i = 0;
395 
396 #ifdef CPU_CAPABILITY_AVX2
397   // vectorized
398   __m256i sum_v_epu32 = _mm256_setzero_si256();
399   alignas(64) int32_t temp[8];
400   int overflow_threshold = 262144; // 2147483647(max of int32)/(256*256)*8 = 262144
401   int loop = len / overflow_threshold + 1;
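  // A squared uint8 is at most 255^2, so each 32-bit lane can only hold a
  // limited number of them; the outer loop flushes the vector accumulator into
  // the 64-bit row_sum (and zeroes it) every overflow_threshold elements to
  // avoid int32 overflow.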
402   for(int j=0; j<=loop; j++){
403     for (; ((i < overflow_threshold * j) && (i < len / 16 * 16)); i += 16) {
404       // (i15, ..., i0)
405       __m128i src_epu8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(A + i));
406       __m256i src_epu16 = _mm256_cvtepu8_epi16(src_epu8);
407       // (i15 ^ 2, ..., i0 ^ 2)
408       __m256i sq_epu16 = _mm256_mullo_epi16(src_epu16, src_epu16);
409       // (i7 ^ 2, ..., i0 ^ 2)
410       __m128i sq_lo_epu16 = _mm256_castsi256_si128(sq_epu16);
411       // (i15 ^ 2, ..., i8 ^ 2)
412       __m128i sq_hi_epu16 = _mm256_extractf128_si256(sq_epu16, 1);
413       // widen to epu32
414       __m256i sq_lo_epu32 = _mm256_cvtepu16_epi32(sq_lo_epu16);
415       __m256i sq_hi_epu32 = _mm256_cvtepu16_epi32(sq_hi_epu16);
416       // add to running sum
417       sum_v_epu32 = _mm256_add_epi32(sum_v_epu32, sq_lo_epu32);
418       sum_v_epu32 = _mm256_add_epi32(sum_v_epu32, sq_hi_epu32);
419     }
420     _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v_epu32);
421     for (const auto k : c10::irange(8)) {
422       row_sum += temp[k];
423     }
424     sum_v_epu32 = _mm256_setzero_si256();
425   }
426 #elif defined(CPU_CAPABILITY_AVX512)
427   __m512i sum_v_epu32 = _mm512_setzero_si512();
428   alignas(64) int32_t temp[16];
429   int overflow_threshold = 262144; // 2147483647(max of int32)/(256*256)*8 = 262144
430   int loop = len / overflow_threshold + 1;
431   for(int j=0; j<=loop; j++){
432     for (; ((i < overflow_threshold * j) && (i < len / 32 * 32)); i += 32) {
433       // (i31, ..., i0)
434       __m256i src_epu8 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
435       __m512i src_epu16 = _mm512_cvtepu8_epi16(src_epu8);
436       // (i31 ^ 2, ..., i0 ^ 2)
437       __m512i sq_epu16 = _mm512_mullo_epi16(src_epu16, src_epu16);
438       // (i15 ^ 2, ..., i0 ^ 2)
439       __m256i sq_lo_epu16 = _mm512_castsi512_si256(sq_epu16);
440       // (i31 ^ 2, ..., i16 ^ 2)
441       __m256i sq_hi_epu16 = _mm512_extracti32x8_epi32(sq_epu16, 1);
442       // widen to epu32
443       __m512i sq_lo_epu32 = _mm512_cvtepu16_epi32(sq_lo_epu16);
444       __m512i sq_hi_epu32 = _mm512_cvtepu16_epi32(sq_hi_epu16);
445       // add to running sum
446       sum_v_epu32 = _mm512_add_epi32(sum_v_epu32, sq_lo_epu32);
447       sum_v_epu32 = _mm512_add_epi32(sum_v_epu32, sq_hi_epu32);
448     }
449     _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v_epu32);
450     for (const auto k : c10::irange(16)) {
451       row_sum += temp[k];
452     }
453     sum_v_epu32 = _mm512_setzero_si512();
454   }
455 #endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
456 
457   // scalar
458   for (; i < len; ++i) {
459     row_sum += A[i] * A[i];
460   }
461 
462   return row_sum;
463 }
464 
465 // horizontal sum of squares over a range of int8_t
466 int64_t hsum_sq(const int8_t* A, int len) {
467   int64_t row_sum = 0;
468   int i = 0;
469 
470 #ifdef CPU_CAPABILITY_AVX2
471   // vectorized
472   __m256i sum_v_epi32 = _mm256_setzero_si256();
473   alignas(64) int32_t temp[8];
474 
475   int overflow_threshold = 1048576; //2147483647/(128*128)*8 = 1048576
476   int loop = len / overflow_threshold + 1;
477 
478   for(int j=0; j<=loop; j++){
479     for (; ((i < overflow_threshold * j) && (i < len / 16 * 16)); i += 16) {
480       // (i15, ..., i0)
481       __m128i src_epi8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(A + i));
482       __m256i src_epi16 = _mm256_cvtepi8_epi16(src_epi8);
483       // (i15 ^ 2, ..., i0 ^ 2)
484       __m256i sq_epi16 = _mm256_mullo_epi16(src_epi16, src_epi16);
485       // (i7 ^ 2, ..., i0 ^ 2)
486       __m128i sq_lo_epi16 = _mm256_castsi256_si128(sq_epi16);
487       // (i15 ^ 2, ..., i8 ^ 2)
488       __m128i sq_hi_epi16 = _mm256_extractf128_si256(sq_epi16, 1);
489       // widen to epi32
490       __m256i sq_lo_epi32 = _mm256_cvtepi16_epi32(sq_lo_epi16);
491       __m256i sq_hi_epi32 = _mm256_cvtepi16_epi32(sq_hi_epi16);
492       // add to running sum
493       sum_v_epi32 = _mm256_add_epi32(sum_v_epi32, sq_lo_epi32);
494       sum_v_epi32 = _mm256_add_epi32(sum_v_epi32, sq_hi_epi32);
495     }
496     _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v_epi32);
497 
498     for (const auto k : c10::irange(8)) {
499       row_sum += temp[k];
500     }
501     sum_v_epi32 = _mm256_setzero_si256();
502   }
503 #elif defined(CPU_CAPABILITY_AVX512)
504   // vectorized
505   __m512i sum_v_epi32 = _mm512_setzero_si512();
506   alignas(64) int32_t temp[16];
507 
508   int overflow_threshold = 1048576; //2147483647/(128*128)*8 = 1048576
509   int loop = len / overflow_threshold + 1;
510 
511   for(int j=0; j<=loop; j++){
512     for (; ((i < overflow_threshold * j) && (i < len / 32 * 32)); i += 32) {
513       // (i31, ..., i0)
514       __m256i src_epi8 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
515       __m512i src_epi16 = _mm512_cvtepi8_epi16(src_epi8);
516       // (i31 ^ 2, ..., i0 ^ 2)
517       __m512i sq_epi16 = _mm512_mullo_epi16(src_epi16, src_epi16);
518       // (i15 ^ 2, ..., i0 ^ 2)
519       __m256i sq_lo_epi16 = _mm512_castsi512_si256(sq_epi16);
520       // (i31 ^ 2, ..., i16 ^ 2)
521       __m256i sq_hi_epi16 = _mm512_extracti32x8_epi32(sq_epi16, 1);
522       // widen to epi32
523       __m512i sq_lo_epi32 = _mm512_cvtepi16_epi32(sq_lo_epi16);
524       __m512i sq_hi_epi32 = _mm512_cvtepi16_epi32(sq_hi_epi16);
525       // add to running sum
526       sum_v_epi32 = _mm512_add_epi32(sum_v_epi32, sq_lo_epi32);
527       sum_v_epi32 = _mm512_add_epi32(sum_v_epi32, sq_hi_epi32);
528     }
529     _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v_epi32);
530 
531     for (const auto k : c10::irange(16)) {
532       row_sum += temp[k];
533     }
534     sum_v_epi32 = _mm512_setzero_si512();
535   }
536 #endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
537 
538   // scalar
539   for (; i < len; ++i) {
540     row_sum += A[i] * A[i];
541   }
542 
543   return row_sum;
544 }
545 
546 // horizontal sum of squares over a range of int32_t
547 // floats throughout are necessary to prevent overflow
548 float hsum_sq(const int32_t* A, int len) {
549   float row_sum = 0;
550   int i = 0;
551 
552 #ifdef CPU_CAPABILITY_AVX2
553   __m256 sum_ps = _mm256_setzero_ps();
554   // vectorized
555   for (; i < len / 8 * 8; i += 8) {
556     __m256i src_epi32 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
557     __m256 src_ps = _mm256_cvtepi32_ps(src_epi32);
558     sum_ps = _mm256_add_ps(sum_ps, _mm256_mul_ps(src_ps, src_ps));
559   }
560 
561   alignas(64) float temp[8];
562   _mm256_store_ps(temp, sum_ps);
563   for (const auto k : c10::irange(8)) {
564     row_sum += static_cast<float>(temp[k]);
565   }
566 #elif defined(CPU_CAPABILITY_AVX512)
567   __m512 sum_ps = _mm512_setzero_ps();
568   // vectorized
569   for (; i < len / 16 * 16; i += 16) {
570     __m512i src_epi32 = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i));
571     __m512 src_ps = _mm512_cvtepi32_ps(src_epi32);
572     sum_ps = _mm512_add_ps(sum_ps, _mm512_mul_ps(src_ps, src_ps));
573   }
574 
575   alignas(64) float temp[16];
576   _mm512_store_ps(temp, sum_ps);
577   for (const auto k : c10::irange(16)) {
578     row_sum += static_cast<float>(temp[k]);
579   }
580 #endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
581 
582   // scalar
583   for (; i < len; ++i) {
584     int64_t cur = static_cast<int64_t>(A[i]);
585     row_sum += (float)cur * (float)cur;
586   }
587 
588   return row_sum;
589 }
590 
591 void qrelu_kernel(const Tensor& qx, Tensor& qy) {
592   const auto zero_point = qx.q_zero_point();
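  // Quantized ReLU never leaves the quantized domain: x = q_scale * (q - zero_point)
  // with q_scale > 0, so max(x, 0) in float is equivalent to clamping the raw
  // value at the zero point, i.e. max(q, zero_point).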
593   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qrelu", [&]() {
594     qy = at::_empty_affine_quantized(
595         qx.sizes(),
596         at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
597         qx.q_scale(),
598         qx.q_zero_point(),
599         std::nullopt);
600     using Vec = Vectorized<scalar_t>;
601     auto zero_point_vec = Vec(scalar_t(zero_point));
602     auto iter = TensorIterator::unary_op(qy, qx);
603     cpu_kernel_vec(
604         iter,
605         [&](scalar_t value) -> scalar_t {
606           return scalar_t(std::max<underlying_t>(value.val_, zero_point));
607         },
608         [&](Vec value) -> Vec { return value.relu(zero_point_vec); });
609   });
610 }
611 
612 static void leaky_qrelu_out_kernel(Tensor& out, const Tensor& qx,
613                                    const Scalar& negval_) {
614   int64_t i_zp = qx.q_zero_point();
615   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
616   float i_scale = qx.q_scale();
617 
618   int64_t o_zp = out.q_zero_point();
619   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
620   float o_scale = out.q_scale();
621   float o_inv_scale = 1.0f / o_scale;
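  // Roughly: each element is dequantized with the input parameters, transformed
  // in float, and requantized with the output parameters:
  //   x   = i_scale * (q_x - i_zp)
  //   y   = x > 0 ? x : negval * x
  //   q_y = round(y * o_inv_scale) + o_zp   (clamped to the type's range)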
622 
623   float negval = negval_.to<float>();
624 
625   AT_DISPATCH_QINT_TYPES(out.scalar_type(), "leaky_qrelu", [&] {
626     using Vec = Vectorized<float>;  // Naive implementation uses dequant/quant loop.
627     using qVec = Vectorized<scalar_t>;
628     Vec zero_vec = Vec(0.0f);
629     Vec one_vec = Vec(1.0f);
630 
631     Vec i_scale_vec = Vec((float)i_scale);
632     Vec i_zp_vec = Vec((float)i_zp);
633     Vec i_scale_zp_neg_premul_vec = i_scale_vec * i_zp_vec.neg();
634 
635     Vec negval_vec = Vec(negval);
636 
637     auto iter = TensorIterator::unary_op(out, qx);
638 
639     cpu_kernel_vec(
640         iter,
641         [&](scalar_t value_qx) -> scalar_t {
642           auto value_dx = at::native::dequantize_val(i_scale, i_zp, value_qx);
643           auto value_dy = value_dx > 0 ? value_dx : value_dx * negval;
644           return at::native::quantize_val<scalar_t>(o_scale, o_zp, value_dy);
645         },
646         [&](qVec qx_vec) -> qVec {
647           /* Vectorized implementation creates a multiplicand vector, which has
648            * "alpha" for all negative dx values and ones-vector for all
649            * positive values of dx. The multiplicand then is multiplied by the
650            * input.
651            */
652           auto dx_vec_vec = qx_vec.dequantize(i_scale_vec, i_zp_vec,
653                                               i_scale_zp_neg_premul_vec);
654           for (auto & dx_vec : dx_vec_vec) {
655             const auto multiplicand = Vec::blendv(negval_vec, one_vec,
656                                                   dx_vec > zero_vec);
657             dx_vec *= multiplicand;
658           }
659           return qVec::quantize(dx_vec_vec, o_scale, o_zp, o_inv_scale);
660         });
661   });
662 }
663 
664 static void qprelu_out_kernel(Tensor& out,
665                               const Tensor& qx,
666                               const Tensor& qw) {
667   int32_t i_zp = static_cast<int32_t>(qx.q_zero_point());
668   float i_scale = static_cast<float>(qx.q_scale());
669 
670   int32_t w_zp = static_cast<int32_t>(qw.q_zero_point());
671   float w_scale = static_cast<float>(qw.q_scale());
672 
673   int32_t o_zp = static_cast<int32_t>(out.q_zero_point());
674   float o_scale = static_cast<float>(out.q_scale());
675   float o_inv_scale = 1.0f / o_scale;
676 
677   float multiplier = i_scale * w_scale * o_inv_scale;
678 
679   int64_t input_ndim = qx.dim();
680   TORCH_CHECK(input_ndim > 0, "qprelu: zero-dim input tensor is not allowed.");
681 
682   // This logic is present in at::prelu and repeated here, as this path can be
683   // hit via quantized::prelu, which is registered under quantized/cpu/qprelu.cpu
684   auto qw_nd = qw;
685   if (input_ndim != qw_nd.dim()) {
686     DimVector dim_w(input_ndim, 1);
687     if (input_ndim > 1) {
688       dim_w[1] = qw.numel();
689     }
690     // This will always be a view in CPU/CUDA, but some backends
691     // like MKLDNN do not support views
692     qw_nd = qw_nd.reshape(dim_w);
693   }
694 
695   auto iter = TensorIteratorConfig()
696     .add_output(out)
697     .add_input(qx)
698     .add_input(qw_nd)
699     .build();
700 
701   AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qprelu", [&] {
702     using qVec = Vectorized<scalar_t>;
703     qVec i_zp_vec = qVec(static_cast<scalar_t>(i_zp));
704     qVec w_zp_vec = qVec(static_cast<scalar_t>(w_zp));
705 
706     // Quantized one as weight
707     auto qw_one = at::native::quantize_val<scalar_t>(w_scale, w_zp, 1.0f);
708     qVec vec_qw_one = qVec(qw_one);
709     auto vec_qw_one_sub_zp = vec_qw_one.widening_subtract(w_zp_vec)[0];
710     int32_t qw_one_sub_zp = qw_one.val_ - w_zp;
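    // PReLU(x) = max(x, 0) + w * min(x, 0). The positive branch is multiplied
    // by a "1.0" quantized with the weight's scale so that both products carry
    // the common factor i_scale * w_scale, letting a single multiplier
    // (i_scale * w_scale / o_scale) requantize their sum.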
711 
712     cpu_kernel_vec(
713       iter,
714       [=](scalar_t val_qx, scalar_t val_qw) -> scalar_t {
715         int32_t qx_pos = std::max(static_cast<int32_t>(val_qx.val_), i_zp);
716         int32_t qx_neg = std::min(static_cast<int32_t>(val_qx.val_), i_zp);
717         int32_t qx_pos_sub_zp = qx_pos - i_zp;
718         int32_t qx_neg_sub_zp = qx_neg - i_zp;
719         int32_t qw_sub_zp = val_qw.val_ - w_zp;
720         auto qy_sub_zp = qx_pos_sub_zp * qw_one_sub_zp + qx_neg_sub_zp * qw_sub_zp;
721         return at::native::requantize_from_int<scalar_t>(
722             multiplier, o_zp, qy_sub_zp);
723       },
724       [=](qVec vec_qx, qVec vec_qw) -> qVec {
725         auto vec_qx_pos = vec_qx.maximum(i_zp_vec);
726         auto vec_qx_neg = vec_qx.minimum(i_zp_vec);
727         qVec::int_vec_return_type qx_pos_sub_zp = vec_qx_pos.widening_subtract(i_zp_vec);
728         qVec::int_vec_return_type qx_neg_sub_zp = vec_qx_neg.widening_subtract(i_zp_vec);
729         qVec::int_vec_return_type qw_sub_zp = vec_qw.widening_subtract(w_zp_vec);
730         qVec::int_vec_return_type qy_sub_zp;
731         for (const auto i : c10::irange(qVec::int_num_vecs())) {
732           qy_sub_zp[i] = qx_pos_sub_zp[i] * vec_qw_one_sub_zp + qx_neg_sub_zp[i] * qw_sub_zp[i];
733         }
734         return qVec::requantize_from_int(qy_sub_zp, multiplier, o_zp);
735       });
736   });
737 
738 }
739 
740 void qgelu_kernel(const Tensor& qx, Tensor& qy, GeluType approximate) {
741   int64_t zero_point = qx.q_zero_point();
742   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
743   float scale = qx.q_scale();
744   auto scale_vec = Vectorized<float>(scale);
745   auto zero_point_vec = Vectorized<float>((float)zero_point);
746   auto scale_neg_zp_premul_vec = scale_vec * zero_point_vec.neg();
747   int64_t output_zero_point = zero_point;
748   float output_scale = scale;
749   float inv_output_scale = 1.0 / output_scale;
750   const auto kAlphaVec = Vectorized<float>(M_SQRT1_2);
751   const auto kBetaVec = Vectorized<float>(M_SQRT2 * M_2_SQRTPI * 0.5);
752   const auto kKappaVec = Vectorized<float>(0.044715);
753   const auto kOneVec = Vectorized<float>(1);
754   const auto kPointFiveVec = Vectorized<float>(0.5);
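  // Constants for the two GELU forms: kAlpha = 1/sqrt(2) for the exact erf form
  // 0.5 * x * (1 + erf(x / sqrt(2))); kBeta = sqrt(2/pi) and kKappa = 0.044715
  // for the tanh approximation
  // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).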
755 
756   if (approximate == GeluType::Tanh) {
757     AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
758       qy = at::_empty_affine_quantized(
759           qx.sizes(),
760           at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
761           output_scale,
762           output_zero_point,
763           std::nullopt);
764       auto iter = TensorIterator::unary_op(qy, qx);
765 
766       using Vec = Vectorized<scalar_t>;
767       cpu_kernel_vec(
768           iter,
769           [&](scalar_t value_qx) -> scalar_t {
770             const auto value_dx =
771                 at::native::dequantize_val(scale, zero_point, value_qx);
772 
773             const auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
774             const auto kKappa = 0.044715;
775             const auto x_cube = value_dx * value_dx * value_dx;
776             const auto inner = kBeta * (value_dx + kKappa * x_cube);
777             const auto value_dy = 0.5 * value_dx * (1.0 + std::tanh(inner));
778 
779             return at::native::quantize_val<scalar_t>(
780                 output_scale, output_zero_point, value_dy);
781           },
782           [&](Vec value_qx) -> Vec {
783             auto value_dx = value_qx.dequantize(
784                 scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
785             for (auto & value : value_dx) {
786               auto value_cube = value * value * value;
787               auto inner = kBetaVec * (value + kKappaVec * value_cube);
788               value = kPointFiveVec * value * (kOneVec + inner.tanh());
789             }
790             return Vec::quantize(
791                 value_dx, output_scale, output_zero_point, inv_output_scale);
792           });
793     });
794   } else {
795     AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
796       qy = at::_empty_affine_quantized(
797           qx.sizes(),
798           at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
799           output_scale,
800           output_zero_point,
801           std::nullopt);
802       auto iter = TensorIterator::unary_op(qy, qx);
803 
804       using Vec = Vectorized<scalar_t>;
805       cpu_kernel_vec(
806           iter,
807           [&](scalar_t value_qx) -> scalar_t {
808             const auto value_dx =
809                 at::native::dequantize_val(scale, zero_point, value_qx);
810             const auto value_dy =
811                 value_dx * 0.5 * (1 + std::erf(value_dx * M_SQRT1_2));
812             return at::native::quantize_val<scalar_t>(
813                 output_scale, output_zero_point, value_dy);
814           },
815           [&](Vec value_qx) -> Vec {
816             auto value_dx = value_qx.dequantize(
817                 scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
818             for (auto & value : value_dx) {
819               value = value * kPointFiveVec * (kOneVec + (value * kAlphaVec).erf());
820             }
821             return Vec::quantize(
822                 value_dx, output_scale, output_zero_point, inv_output_scale);
823           });
824     });
825   }
826 }
827 
828 
829 void qsigmoid_kernel(
830     const Tensor& qx, Tensor& qy, double output_scale, int64_t output_zero_point ) {
831   int64_t zero_point = qx.q_zero_point();
832   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
833   float scale = qx.q_scale();
834   auto scale_vec = Vectorized<float>(scale);
835   auto zero_point_vec = Vectorized<float>((float)zero_point);
836 
837   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qsigmoid", [&]() {
838     float inv_output_scale = 1.0 / output_scale;
839 
840     qy = at::_empty_affine_quantized(
841         qx.sizes(),
842         at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
843         output_scale,
844         output_zero_point,
845         std::nullopt);
846     auto iter = TensorIterator::unary_op(qy, qx);
847 
848     using Vec = Vectorized<scalar_t>;
849     cpu_kernel_vec(
850         iter,
851         [&](scalar_t value_qx) -> scalar_t {
852           const auto value_dx =
853               at::native::dequantize_val(scale, zero_point, value_qx);
854           const auto value_dy = 1.0f / (1.0 + std::exp((-value_dx)));
855           return at::native::quantize_val<scalar_t>(
856               output_scale, output_zero_point, value_dy);
857         },
858         [&](Vec value_qx) -> Vec {
859           auto value_dx = value_qx.dequantize(scale_vec, zero_point_vec);
860           for (auto & value : value_dx) {
861             value = value.neg();
862             value = value.exp();
863             value = Vectorized<float>(1.0f) + value;
864             value = value.reciprocal();
865           }
866           return Vec::quantize(
867               value_dx, output_scale, output_zero_point, inv_output_scale);
868         });
869   });
870 }
871 
872 void qhardsigmoid_kernel(const Tensor& qx, Tensor& qy) {
873   int64_t zero_point = qx.q_zero_point();
874   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
875   float scale = qx.q_scale();
876   auto scale_vec = Vectorized<float>(scale);
877   auto zero_point_vec = Vectorized<float>((float)zero_point);
878   auto scale_neg_zp_premul_vec = scale_vec * zero_point_vec.neg();
879 
880   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qhardsigmoid", [&]() {
881 
882     // - Output scale is set to 1.0 / 2^(BIT_NUM)
883     float output_scale = 0.00390625;  // 1.0 / 2^8
884     if (SCALAR_TYPE == at::kQInt32) {
885       output_scale = 2.3283064365386963e-10;  // 1.0 / 2^32
886     }
887     float inv_output_scale = 1.0 / output_scale;
888 
889     // The default zero-point is zero.  As a one-off optimization for
890     // kQInt8, we set the zero-point to -128 to maximize precision in the
891     // [0, 1] output range. kQInt32 can be handled in a future PR if needed.
892     int64_t output_zero_point = 0;
893     if (SCALAR_TYPE == at::kQInt8) {
894       output_zero_point = -128;
895     }
896 
897     qy = at::_empty_affine_quantized(
898         qx.sizes(),
899         at::device(kCPU).dtype(SCALAR_TYPE),
900         output_scale,
901         output_zero_point,
902         qx.suggest_memory_format());
903     auto iter = TensorIterator::unary_op(qy, qx);
904 
905     using qVec = Vectorized<scalar_t>;
906     using fVec = Vectorized<float>;
907     fVec kZeroVec(0.0f);
908     fVec kThreeVec(3.0f);
909     fVec kSixVec(6.0f);
910 
911     // Naive implementation: uses dequantize/execute/quantize routine
912     cpu_kernel_vec(
913         iter,
914         [&](scalar_t qx) -> scalar_t {
915           auto x = at::native::dequantize_val(scale, zero_point, qx);
916           const auto y = std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
917           return at::native::quantize_val<scalar_t>(
918               output_scale, output_zero_point, y);
919         },
920         [&](qVec value_qx) -> qVec {
921           auto value_dx = value_qx.dequantize(
922               scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
923           for (auto & value : value_dx) {
924             value =
925                 vec::minimum(
926                     vec::maximum(value + kThreeVec, kZeroVec),
927                     kSixVec) /
928                 kSixVec;
929           }
930           return qVec::quantize(
931               value_dx, output_scale, output_zero_point, inv_output_scale);
932         });
933   });
934 }
935 
936 void qclamp_kernel(
937     const Tensor& qx,
938     const Scalar& min_scalar,
939     const Scalar& max_scalar,
940     Tensor& qy) {
941   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qclamp", [&]() {
942     qy = at::_empty_affine_quantized(
943         qx.sizes(),
944         at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
945         qx.q_scale(),
946         qx.q_zero_point(),
947         std::nullopt);
948     using Vec = Vectorized<scalar_t>;
949     auto iter = TensorIterator::unary_op(qy, qx);
950     auto min = min_scalar.to<float>();
951     auto max = max_scalar.to<float>();
952     scalar_t min_q = at::native::quantize_val<scalar_t>(
953         qx.q_scale(), qx.q_zero_point(), min);
954     scalar_t max_q = at::native::quantize_val<scalar_t>(
955         qx.q_scale(), qx.q_zero_point(), max);
956     auto min_vec = Vec(min_q);
957     auto max_vec = Vec(max_q);
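    // qy shares qx's scale and zero point and the affine map is monotonically
    // increasing, so clamping in float is equivalent to clamping the raw
    // integer values against the quantized bounds; no dequantization is needed.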
958     cpu_kernel_vec(
959         iter,
960         [&](scalar_t value) -> scalar_t {
961           underlying_t min_clamped =
962               std::max<underlying_t>(value.val_, min_q.val_);
963           return scalar_t(std::min<underlying_t>(min_clamped, max_q.val_));
964         },
965         [&](Vec val) -> Vec {
966           auto min_clamped = val.maximum(min_vec);
967           return min_clamped.minimum(max_vec);
968         });
969   });
970 }
971 
972 void qclamp_min_kernel(const Tensor& qx, const Scalar& min_scalar, Tensor& qy) {
973   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qclamp", [&]() {
974     qy = at::_empty_affine_quantized(
975         qx.sizes(),
976         at::device(kCPU)
977             .dtype(SCALAR_TYPE)
978             .memory_format(qx.suggest_memory_format()),
979         qx.q_scale(),
980         qx.q_zero_point(),
981         std::nullopt);
982     using Vec = Vectorized<scalar_t>;
983     auto iter = TensorIterator::unary_op(qy, qx);
984     auto min = min_scalar.to<float>();
985     scalar_t min_q = at::native::quantize_val<scalar_t>(
986         qx.q_scale(), qx.q_zero_point(), min);
987     auto min_vec = Vec(min_q);
988     cpu_kernel_vec(
989         iter,
990         [&](scalar_t value) -> scalar_t {
991           return scalar_t(std::max<underlying_t>(value.val_, min_q.val_));
992         },
993         [&](Vec val) -> Vec { return val.maximum(min_vec); });
994   });
995 }
996 
997 void qclamp_max_kernel(const Tensor& qx, const Scalar& max_scalar, Tensor& qy) {
998   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qclamp", [&]() {
999     qy = at::_empty_affine_quantized(
1000         qx.sizes(),
1001         at::device(kCPU)
1002             .dtype(SCALAR_TYPE)
1003             .memory_format(qx.suggest_memory_format()),
1004         qx.q_scale(),
1005         qx.q_zero_point(),
1006         std::nullopt);
1007     using Vec = Vectorized<scalar_t>;
1008     auto iter = TensorIterator::unary_op(qy, qx);
1009     auto max = max_scalar.to<float>();
1010     scalar_t max_q = at::native::quantize_val<scalar_t>(
1011         qx.q_scale(), qx.q_zero_point(), max);
1012     auto max_vec = Vec(max_q);
1013     cpu_kernel_vec(
1014         iter,
1015         [&](scalar_t value) -> scalar_t {
1016           return scalar_t(std::min<underlying_t>(value.val_, max_q.val_));
1017         },
1018         [&](Vec val) -> Vec { return val.minimum(max_vec); });
1019   });
1020 }
1021 
1022 void qthreshold_kernel(
1023   // TODO: For future tasks, since output quantization parameters are set equal to
1024   // the input ones, it might make sense to implement this completely in the
1025   // quantized domain.
1026    const Tensor& qx,
1027    const Scalar& threshold_scalar,
1028    const Scalar& value_scalar,
1029    Tensor& qy) {
1030 
1031   // defines input and output scales and zero_points
1032   int64_t input_zero_point = qx.q_zero_point();
1033   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1034   float input_scale = qx.q_scale();
1035   int64_t output_zero_point = qy.q_zero_point();
1036   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1037   float output_scale = qy.q_scale();
1038   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1039   float inv_output_scale = 1.0 / output_scale;
1040 
1041   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qthreshold", [&]() {
1042     qy = at::_empty_affine_quantized(
1043       qx.sizes(),
1044       at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
1045       qx.q_scale(),
1046       qx.q_zero_point(),
1047       std::nullopt);
1048 
1049     // vectorized
1050     using Vec = Vectorized<float>;
1051     using qVec = Vectorized<scalar_t>;
1052     // defines the iterator
1053     auto iter = TensorIterator::unary_op(qy, qx);
1054     // defines the vectorized versions
1055     Vec input_scale_vec = Vec(input_scale);
1056     Vec input_zero_point_vec = Vec(input_zero_point);
1057     Vec input_scale_neg_zp_premul_vec = input_scale_vec * input_zero_point_vec.neg();
1058     // defines the floating-point versions of threshold and value
1059     float threshold_float = threshold_scalar.to<float>();
1060     float value_float = value_scalar.to<float>();
1061     Vec threshold_vec = Vec(threshold_float);
1062     Vec value_vec = Vec(value_float);
1063 
1064     // Naive implementation: uses dequantize/execute/quantize routine
1065     cpu_kernel_vec(
1066         iter,
1067         [&](scalar_t value_qx) -> scalar_t {
1068           // dequantize
1069           const auto x = at::native::dequantize_val(input_scale, input_zero_point, value_qx);
1070           // Applies the Threshold operation
1071           const auto y = x > threshold_float ? x : value_float;
1072           // quantize
1073           return at::native::quantize_val<scalar_t>(output_scale, output_zero_point, y);
1074         },
1075         [&](qVec value_qx) -> qVec {
1076           // dequantize
1077           auto dx_vec = value_qx.dequantize(
1078             input_scale_vec, input_zero_point_vec, input_scale_neg_zp_premul_vec);
1079           for (auto & value : dx_vec) {
1080             // check if any elements are below threshold
1081             const auto cmp_to_threshold = value > threshold_vec;
1082             if (cmp_to_threshold.zero_mask()) {
1083               // blend
1084               value = Vec::blendv(value_vec, value, cmp_to_threshold);
1085             }
1086           }
1087           // quantize
1088           return qVec::quantize(dx_vec, output_scale, output_zero_point, inv_output_scale);
1089         });
1090   });
1091 }
1092 
1093 
1094 void qhardswish_kernel(const Tensor& qx, Tensor& qy) {
1095   const auto i_scale = qx.q_scale();
1096   const auto i_zero_point = qx.q_zero_point();
1097 
1098   const auto o_scale = qy.q_scale();
1099   const auto o_zero_point = qy.q_zero_point();
1100   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1101   const float o_inv_scale = 1.0 / o_scale;
1102 
1103   using fVec = Vectorized<float>;
1104   fVec i_scale_vec(i_scale);
1105   fVec i_zero_point_vec(i_zero_point);
1106   fVec i_scale_neg_zp_premul_vec = i_scale_vec * i_zero_point_vec.neg();
1107   fVec zero_vec(0.0f);
1108   fVec three_vec(3.0f);
1109   fVec six_vec(6.0f);
1110 
1111   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qhardswish", [&]() {
1112     using qVec = Vectorized<scalar_t>;
1113     auto iter = TensorIterator::unary_op(qy, qx);
1114     cpu_kernel_vec(
1115         iter,
1116         [&](scalar_t value) -> scalar_t {
1117           const auto x =
1118               at::native::dequantize_val(i_scale, i_zero_point, value);
1119           const auto y = x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
1120           return at::native::quantize_val<scalar_t>(o_scale, o_zero_point, y);
1121         },
1122         [&](qVec value) -> qVec {
1123           auto value_dx = value.dequantize(i_scale_vec, i_zero_point_vec,
1124                                            i_scale_neg_zp_premul_vec);
1125           for (auto & value : value_dx) {
1126             value = value * vec::minimum(
1127               vec::maximum(value + three_vec, zero_vec),
1128               six_vec
1129             ) / six_vec;
1130           }
1131           return qVec::quantize(value_dx, o_scale, o_zero_point, o_inv_scale);
1132         });
1133   });
1134 }
1135 
1136 
1137 void qtanh_kernel(const Tensor& qx, Tensor& qy) {
1138   int64_t zero_point = qx.q_zero_point();
1139   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1140   float scale = qx.q_scale();
1141   auto scale_vec = Vectorized<float>(scale);
1142   auto zero_point_vec = Vectorized<float>((float)zero_point);
1143   auto scale_neg_zp_premul_vec = scale_vec * zero_point_vec.neg();
1144 
1145   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qtanh", [&]() {
1146     // Naive implementation: uses dequantize/execute/quantize routine
1147     // - Output scale is set to 2.0 / 2^(BIT_NUM)
1148     // - For signed types output zero point is set to 0
1149     // - For unsigned types output zero point is set to (qmax + qmin) / 2.0
1150     float output_scale = 0.0078125;  // 2.0 / 256
1151     int64_t output_zero_point = 0;
1152     if (SCALAR_TYPE == at::kQInt32) {
1153       output_scale = 4.656612873077393e-10;  // 2.0 / 2^32
1154     } else if (SCALAR_TYPE == at::kQUInt8) {
1155       output_zero_point = 128;
1156     }
1157     float inv_output_scale = 1.0 / output_scale;
1158 
1159     qy = at::_empty_affine_quantized(
1160         qx.sizes(),
1161         at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
1162         output_scale,
1163         output_zero_point,
1164         std::nullopt);
1165     auto iter = TensorIterator::unary_op(qy, qx);
1166 
1167     using Vec = Vectorized<scalar_t>;
1168     cpu_kernel_vec(
1169         iter,
1170         [&](scalar_t value_qx) -> scalar_t {
1171           const auto value_dx =
1172               at::native::dequantize_val(scale, zero_point, value_qx);
1173           return at::native::quantize_val<scalar_t>(
1174               output_scale, output_zero_point, std::tanh(value_dx));
1175         },
1176         [&](Vec value_qx) -> Vec {
1177           const auto value_dx = value_qx.dequantize(
1178               scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
1179           Vec::float_vec_return_type retvals;
1180           for (const auto idx : c10::irange(Vec::float_num_vecs())) {
1181             retvals[idx] = value_dx[idx].tanh();
1182           }
1183           return Vec::quantize(
1184               retvals, output_scale, output_zero_point, inv_output_scale);
1185         });
1186   });
1187 }
1188 
1189 void qelu_kernel(
1190     const Tensor& qx,
1191     const Scalar& alpha,
1192     const Scalar& scale,
1193     const Scalar& input_scale,
1194     Tensor& qy) {
1195   // scale and input_scale arguments refer to a generalized ELU formula
1196   // if x >= 0, ELU(x) = x * scale
1197   // if x <= 0, ELU(x) = (exp(x * input_scale) - 1) * scale
1198   // in the normal ELU formula, both are equal to 1
1199   // they are NOT related to the quantization scale term
1200 
1201   int64_t i_zp = qx.q_zero_point();
1202   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1203   float i_scale = qx.q_scale();
1204 
1205   // In a future PR, we can improve on output scale and zero_point
1206   // selection.
1207   int64_t o_zp = qy.q_zero_point();
1208   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1209   float o_scale = qy.q_scale();
1210   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1211   float inv_o_scale = 1.0 / o_scale;
1212 
1213   float alpha_float = alpha.to<float>();
1214   float scale_coef = scale.to<float>();
1215   float input_scale_coef = input_scale.to<float>();
1216 
1217   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qelu_kernel", [&] {
1218 
1219     auto iter = TensorIterator::unary_op(qy, qx);
1220 
1221     // vectorized
1222     using Vec = Vectorized<float>;
1223     using qVec = Vectorized<scalar_t>;
1224 
1225     Vec zero_vec = Vec(0.0f);
1226     Vec one_vec = Vec(1.0f);
1227     Vec alpha_vec = Vec(alpha_float);
1228     Vec scale_coef_vec = Vec(scale_coef);
1229     Vec input_scale_coef_vec = Vec(input_scale_coef);
1230     Vec i_scale_vec = Vec(i_scale);
1231     Vec i_zero_point_vec = Vec((float)i_zp);
1232     Vec i_scale_neg_zp_premul_vec = i_scale_vec * i_zero_point_vec.neg();
1233 
1234     cpu_kernel_vec(
1235       iter,
1236       [&](scalar_t value_qx) -> scalar_t {
1237         // dequantize
1238         const auto x = at::native::dequantize_val(i_scale, i_zp, value_qx);
1239         // ELU
1240         const auto y = x >= 0
1241           ? x * scale_coef
1242           : ((std::exp(x * input_scale_coef) - 1) * alpha_float * scale_coef);
1243 
1244         // quantize
1245         return at::native::quantize_val<scalar_t>(o_scale, o_zp, y);
1246       },
1247       [&](qVec value_qx) -> qVec {
1248         // dequantize
1249         auto dx_vec_vec = value_qx.dequantize(i_scale_vec, i_zero_point_vec,
1250                                             i_scale_neg_zp_premul_vec);
1251         for (auto & value : dx_vec_vec) {
1252           // quickly check if any elements are below zero
1253           const auto cmp_to_zero = value > zero_vec;
1254 
1255           if (cmp_to_zero.zero_mask()) {
1256 
1257             Vec dx_vec_copy_neg_elu = value * one_vec;
1258             // calculate the negative part of ELU on the copy
1259             dx_vec_copy_neg_elu = dx_vec_copy_neg_elu * input_scale_coef_vec;
1260             dx_vec_copy_neg_elu = dx_vec_copy_neg_elu.exp();
1261             dx_vec_copy_neg_elu = dx_vec_copy_neg_elu - one_vec;
1262             dx_vec_copy_neg_elu = dx_vec_copy_neg_elu * alpha_vec;
1263             // blend
1264             value = Vec::blendv(dx_vec_copy_neg_elu, value,
1265                                         value > zero_vec);
1266           }
1267 
1268           value = value * scale_coef_vec;
1269         }
1270         // quantize
1271         return qVec::quantize(dx_vec_vec, o_scale, o_zp, inv_o_scale);
1272       }
1273     );
1274 
1275   });
1276 }
1277 
1278 // Note: out is assumed to be the same size as self and other.
1279 // Note: Addition is only supported when self and out are of the same dtype.
1280 // Note: other is already assumed to be in int32, i.e., it's
1281 // round(float/self_scale)
1282 template <bool ReLUFused = false>
1283 void qadd_scalar_kernel(Tensor& out, const Tensor& self, const Scalar& other) {
1284   int64_t zero_point = out.q_zero_point();
1285   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1286   float scale = out.q_scale();
1287   float inv_scale = 1.0f / scale;
1288   int64_t self_zero_point = self.q_zero_point();
1289   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1290   float self_scale = self.q_scale();
1291 
1292   float multiplier = self_scale * inv_scale;
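  // `other` is already expressed on the input's grid (round(value / self_scale)),
  // so the float result is self_scale * (q_a - self_zero_point + other);
  // requantize_from_int then scales by multiplier (= self_scale / scale) and
  // shifts by the output zero point.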
1293 
1294   AT_DISPATCH_QINT_TYPES(self.scalar_type(), "qadd_scalar", [&]() {
1295     using Vec = Vectorized<scalar_t>;
1296     auto iter = TensorIterator::unary_op(out, self);
1297     auto other_val = other.to<int32_t>();
1298     auto other_vec = Vectorized<c10::qint32>(static_cast<c10::qint32>(other_val));
1299     cpu_kernel_vec(
1300         iter,
1301         [&](scalar_t a) -> scalar_t {
1302           int32_t a_sub_z = static_cast<int32_t>(a.val_) -
1303               static_cast<int32_t>(self_zero_point);
1304           int32_t c = a_sub_z + other_val;
1305           scalar_t res = at::native::requantize_from_int<scalar_t>(
1306               multiplier, zero_point, c);
1307           if constexpr (ReLUFused) {
1308             res.val_ = std::max<scalar_t::underlying>(res.val_, zero_point);
1309           }
1310           return res;
1311         },
1312         [&](Vec a) -> Vec {
1313           Vec::int_vec_return_type a_sub_z =
1314               a.widening_subtract(Vec(static_cast<scalar_t>(self_zero_point)));
1315           Vec::int_vec_return_type c;
1316           for (const auto i : c10::irange(Vec::int_num_vecs())) {
1317             c[i] = a_sub_z[i] + other_vec;
1318           }
1319           Vec rv = Vec::requantize_from_int(c, multiplier, zero_point);
1320           if constexpr (ReLUFused) {
1321             rv = rv.maximum(Vec(static_cast<scalar_t>(zero_point)));
1322           }
1323           return rv;
1324         });
1325   });
1326 }
1327 // Note: out is assumed to be the same size as self and other.
1328 // Note: Addition is only supported when self, other, out are of the same dtype.
1329 template <bool ReLUFused = false>
1330 void qadd_kernel(Tensor& out, const Tensor& self, const Tensor& other) {
1331   int64_t zero_point = out.q_zero_point();
1332   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1333   float scale = out.q_scale();
1334   float inv_scale = 1.0f / scale;
1335   int64_t self_zero_point = self.q_zero_point();
1336   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1337   float self_scale = self.q_scale();
1338   int64_t other_zero_point = other.q_zero_point();
1339   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1340   float other_scale = other.q_scale();
1341 
1342   // Broadcast out the parameters here to amortize out that cost across
1343   // loop iterations.
1344   // TODO: we can optimize dequantization by doing a premultiplication
1345   // of the zero point by scale and doing FMA on scale*x_q - (scale*zero_point)
1346   auto self_zero_point_vec = Vectorized<float>((float)self_zero_point);
1347   auto self_scale_vec = Vectorized<float>(self_scale);
1348   auto other_zero_point_vec = Vectorized<float>((float)other_zero_point);
1349   auto other_scale_vec = Vectorized<float>(other_scale);
1350 
1351   auto self_scale_neg_zp_premul_vec = self_scale_vec * self_zero_point_vec.neg();
1352   auto other_scale_zp_premul_vec = other_scale_vec * other_zero_point_vec.neg();
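  // The vectorized dequantize below effectively uses the identity
  //   scale * (q - zp) = fma(q, scale, -scale * zp),
  // so the negated, pre-multiplied zero points computed above turn each
  // dequantization into a single FMA per lane.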
1353 
1354   auto iter = TensorIterator::borrowing_binary_op(out, self, other);
1355 
1356   AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qadd", [&]() {
1357     using Vec = Vectorized<scalar_t>;
1358     cpu_kernel_vec(
1359         iter,
1360         [&](scalar_t a, scalar_t b) -> scalar_t {
1361           const auto da =
1362               at::native::dequantize_val(self_scale, self_zero_point, a);
1363           const auto db =
1364               at::native::dequantize_val(other_scale, other_zero_point, b);
1365           float c = da + db;
1366           if (ReLUFused) {
1367             c = std::max<float>(c, 0.0);
1368           }
1369           return at::native::quantize_val<scalar_t>(scale, zero_point, c);
1370         },
1371         [&](Vec a, Vec b) -> Vec {
1372           const auto da = a.dequantize(
1373               self_scale_vec, self_zero_point_vec, self_scale_neg_zp_premul_vec);
1374           const auto db = b.dequantize(
1375               other_scale_vec, other_zero_point_vec, other_scale_zp_premul_vec);
1376           Vec::float_vec_return_type retvals;
1377           for (const auto i : c10::irange(Vec::float_num_vecs())) {
1378             auto c = da[i] + db[i];
1379             if constexpr (ReLUFused) {
1380               c = vec::maximum(c, Vectorized<float>(0.0f));
1381             }
1382             retvals[i] = c;
1383           }
1384           // TODO: fbgemm::Quantize doesn't support taking in the
1385           // pre-broadcasted parameters. We might be able to save some cycles by
1386           // enabling that in the API.
1387           // TODO: specialize fbgemm::Quantize for a single vector and make it
1388           // inlineable. This could help with interleaving as suggested by the
1389           // TensorIterator implementations
1390           auto rv = Vec::quantize(retvals, scale, zero_point, inv_scale);
1391           return rv;
1392         });
1393   });
1394 }
1395 
1396 // Note: out is assumed to be the same size as self and other.
1397 // Note: Multiplication is only supported when self, other, out are of the same
1398 // dtype.
1399 template <bool ReLUFused = false>
1400 void qmul_kernel(Tensor& out, const Tensor& self, const Tensor& other) {
1401   int64_t zero_point = out.q_zero_point();
1402   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1403   float scale = out.q_scale();
1404   float inv_scale = 1.0f / scale;
1405   int64_t self_zero_point = self.q_zero_point();
1406   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1407   float self_scale = self.q_scale();
1408   int64_t other_zero_point = other.q_zero_point();
1409   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1410   float other_scale = other.q_scale();
1411 
1412   float multiplier = self_scale * other_scale * inv_scale;
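  // Multiplying the zero-point-shifted integers gives
  //   (self_q - self_zp) * (other_q - other_zp)
  //     = (self_real * other_real) / (self_scale * other_scale),
  // so rescaling the int32 product by multiplier = self_scale * other_scale / out_scale
  // and adding out_zp produces the product in the output's quantized domain.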
1413 
1414   auto iter = TensorIterator::borrowing_binary_op(out, self, other);
1415 
1416   AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qmul", [&]() {
1417     using Vec = Vectorized<scalar_t>;
1418     cpu_kernel_vec(
1419         iter,
1420         [&](scalar_t a, scalar_t b) -> scalar_t {
1421           int32_t a_sub_z = static_cast<int32_t>(a.val_) -
1422               static_cast<int32_t>(self_zero_point);
1423           int32_t b_sub_z = static_cast<int32_t>(b.val_) -
1424               static_cast<int32_t>(other_zero_point);
1425           int32_t c = a_sub_z * b_sub_z;
1426           scalar_t res = at::native::requantize_from_int<scalar_t>(
1427               multiplier, zero_point, c);
1428           if constexpr (ReLUFused) {
1429             res.val_ = std::max<scalar_t::underlying>(res.val_, zero_point);
1430           }
1431           return res;
1432         },
1433         [&](Vec a, Vec b) -> Vec {
1434           Vec::int_vec_return_type a_sub_zp =
1435               a.widening_subtract(Vec(static_cast<scalar_t>(self_zero_point)));
1436           Vec::int_vec_return_type b_sub_zp =
1437               b.widening_subtract(Vec(static_cast<scalar_t>(other_zero_point)));
1438           Vec::int_vec_return_type c;
1439           for (const auto i : c10::irange(Vec::int_num_vecs())) {
1440             c[i] = a_sub_zp[i] * b_sub_zp[i];
1441           }
1442           Vec rv = Vec::requantize_from_int(c, multiplier, zero_point);
1443           if constexpr (ReLUFused) {
1444             rv = rv.maximum(Vec(static_cast<scalar_t>(zero_point)));
1445           }
1446           return rv;
1447         });
1448   });
1449 }
1450 
1451 template <typename scalar_t, typename scalar_t_underlying>
1452 void _qmaxpool_2d_nhwc_kernel(
1453     const Tensor& qx,
1454     int64_t iC, // input/output channels
1455     int64_t iH,
1456     int64_t iW, // input sizes
1457     int64_t oH,
1458     int64_t oW, // output sizes
1459     int64_t kH,
1460     int64_t kW, // kernel size
1461     int64_t sH,
1462     int64_t sW, // strides
1463     int64_t pH,
1464     int64_t pW, // padding
1465     int64_t dH,
1466     int64_t dW, // dilation
1467     Tensor& qy) {
1468     scalar_t* idata = static_cast<scalar_t*>(qx.data_ptr());
1469     scalar_t* odata = static_cast<scalar_t*>(qy.data_ptr());
1470 
1471     int64_t nBatch = qx.size(0);
1472     at::parallel_for(0, nBatch * oH * oW, 0, [&](int64_t begin, int64_t end) {
1473       int64_t b{0}, row{0}, col{0};
1474       data_index_init(begin, b, nBatch, row, oH, col, oW);
1475 
1476       for (const auto i : c10::irange(begin, end)) {
1477         auto* i_p = reinterpret_cast<scalar_t_underlying*>(idata + b * iW * iH * iC);
1478         auto* o_p = reinterpret_cast<scalar_t_underlying*>(odata + i * iC);
1479 
1480         // Loop over reduction block
1481         int64_t h_start = row * sH - pH;
1482         int64_t w_start = col * sW - pW;
1483         int64_t h_end = std::min(h_start + (kH - 1) * dH + 1, iH);
1484         int64_t w_end = std::min(w_start + (kW - 1) * dW + 1, iW);
1485         while (h_start < 0)
1486           h_start += dH;
1487         while (w_start < 0)
1488           w_start += dW;
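        // Advancing the window start by whole dilation steps (instead of
        // clamping to 0) keeps the sampled positions on the dilation grid
        // while skipping taps that fall in the padding region.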
1489 
1490         int64_t c = 0;
1491 
1492         // Interleaved vector loop 4x
1493         constexpr auto vec_width = Vectorized<scalar_t>::size();
1494         for (; c + 4 * vec_width <= iC; c += 4 * vec_width) {
1495           Vectorized<scalar_t> acc{
1496               scalar_t(std::numeric_limits<scalar_t_underlying>::lowest())};
1497           // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1498           Vectorized<scalar_t> accs[4] = {acc, acc, acc, acc};
1499           int64_t tcntr = 0;
1500           int64_t x, y;
1501           for (y = h_start; y < h_end; y += dH) {
1502             for (x = w_start; x < w_end; x += dW) {
1503               for (const auto i : c10::irange(4)) {
1504                 tcntr = y * iW + x;
1505                 auto vals = Vectorized<scalar_t>::loadu(
1506                     i_p + tcntr * iC + c + Vectorized<scalar_t>::size() * i);
1507                 accs[i] = vec::maximum(accs[i], vals);
1508               }
1509             } // for x
1510           } // for y
1511           for (const auto i : c10::irange(4)) {
1512             accs[i].store(o_p + c + Vectorized<scalar_t>::size() * i);
1513           }
1514         } // for c
1515 
1516         // Vector loop
1517         for (; c + vec_width <= iC; c += vec_width) {
1518           Vectorized<scalar_t> acc{
1519               scalar_t(std::numeric_limits<scalar_t_underlying>::lowest())};
1520           int64_t tcntr = 0;
1521           int64_t x, y;
1522           for (y = h_start; y < h_end; y += dH) {
1523             for (x = w_start; x < w_end; x += dW) {
1524               tcntr = y * iW + x;
1525               auto vals = Vectorized<scalar_t>::loadu(i_p + tcntr * iC + c);
1526               acc = vec::maximum(acc, vals);
1527             } // for x
1528           } // for y
1529           acc.store(o_p + c);
1530         } // for c
1531 
1532         for (; c < iC; ++c) {
1533           auto max_val = std::numeric_limits<scalar_t_underlying>::lowest();
1534           int64_t tcntr = 0;
1535           int64_t x, y;
1536           for (y = h_start; y < h_end; y += dH) {
1537             for (x = w_start; x < w_end; x += dW) {
1538               tcntr = y * iW + x;
1539               auto val = *(i_p + tcntr * iC + c);
1540               max_val = std::max(max_val, val);
1541             } // for x
1542           } // for y
1543 
1544           o_p[c] = max_val;
1545         } // for c
1546 
1547         data_index_step(b, nBatch, row, oH, col, oW);
1548       }
1549     });
1550 }
1551 
1552 void qmaxpool_2d_nhwc_kernel(
1553     const Tensor& qx,
1554     int64_t iC, // input/output channels
1555     int64_t iH,
1556     int64_t iW, // input sizes
1557     int64_t oH,
1558     int64_t oW, // output sizes
1559     int64_t kH,
1560     int64_t kW, // kernel size
1561     int64_t sH,
1562     int64_t sW, // strides
1563     int64_t pH,
1564     int64_t pW, // padding
1565     int64_t dH,
1566     int64_t dW, // dilation
1567     Tensor& qy) {
1568   if (qx.scalar_type() == ScalarType::Byte) {
1569     AT_DISPATCH_INTEGRAL_TYPES(qx.scalar_type(), "max_pool2d_nhwc", [&]() {
1570       _qmaxpool_2d_nhwc_kernel<scalar_t, scalar_t>(qx, iC, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, qy);
1571     });
1572   } else {
1573     AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "max_pool2d_nhwc", [&]() {
1574       _qmaxpool_2d_nhwc_kernel<scalar_t, scalar_t::underlying>(qx, iC, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, qy);
1575     });
1576   }
1577 }
1578 
1579 void qmaxpool_3d_nthwc_kernel(
1580     const Tensor& qx,
1581     int64_t iC, // input/output channels
1582     int64_t iT,
1583     int64_t iH,
1584     int64_t iW, // input sizes
1585     int64_t oT,
1586     int64_t oH,
1587     int64_t oW, // output sizes
1588     int64_t kT,
1589     int64_t kH,
1590     int64_t kW, // kernel size
1591     int64_t sT,
1592     int64_t sH,
1593     int64_t sW, // strides
1594     int64_t pT,
1595     int64_t pH,
1596     int64_t pW, // padding
1597     int64_t dT,
1598     int64_t dH,
1599     int64_t dW, // dilation
1600     Tensor& qy) {
1601   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "max_pool3d_nthwc", [&]() {
1602     scalar_t* idata = static_cast<scalar_t*>(qx.data_ptr());
1603     scalar_t* odata = static_cast<scalar_t*>(qy.data_ptr());
1604     int64_t nBatch = qx.size(0);
1605     at::parallel_for(0, nBatch * oT * oH * oW, 0, [&](int64_t begin, int64_t end) {
1606       int64_t b{0}, time{0}, row{0}, col{0};
1607 
1608       data_index_init(begin, b, nBatch, time, oT, row, oH, col, oW);
1609 
1610       for (const auto i : c10::irange(begin, end)) {
1611         auto* i_p = reinterpret_cast<scalar_t::underlying*>(idata + b * iT * iW * iH * iC);
1612         auto* o_p = reinterpret_cast<scalar_t::underlying*>(odata + i * iC);
1613 
1614         // Loop over reduction block
1615         int64_t t_start = time * sT - pT;
1616         int64_t h_start = row * sH - pH;
1617         int64_t w_start = col * sW - pW;
1618         int64_t t_end = std::min(t_start + (kT - 1) * dT + 1, iT);
1619         int64_t h_end = std::min(h_start + (kH - 1) * dH + 1, iH);
1620         int64_t w_end = std::min(w_start + (kW - 1) * dW + 1, iW);
1621         while (t_start < 0)
1622           t_start += dT;
1623         while (h_start < 0)
1624           h_start += dH;
1625         while (w_start < 0)
1626           w_start += dW;
1627 
1628         int64_t c = 0;
1629         constexpr auto vec_width = Vectorized<scalar_t>::size();
1630         // Vector loop
1631         for (; c + vec_width <= iC; c += vec_width) {
1632           Vectorized<scalar_t> acc{
1633               scalar_t(std::numeric_limits<scalar_t::underlying>::lowest())};
1634           int64_t tcntr = 0;
1635           int64_t t, x, y;
1636           for (t = t_start; t < t_end; t += dT) {
1637             for (y = h_start; y < h_end; y += dH) {
1638               for (x = w_start; x < w_end; x += dW) {
1639                 tcntr = t * iH * iW + y * iW + x;
1640                 auto vals = Vectorized<scalar_t>::loadu(i_p + tcntr * iC + c);
1641                 acc = vec::maximum(acc, vals);
1642               } // for x
1643             } // for y
1644           } // for t
1645           acc.store(o_p + c);
1646         } // for c
1647 
1648         for (; c < iC; ++c) {
1649           auto max_val = std::numeric_limits<scalar_t::underlying>::lowest();
1650           int64_t tcntr = 0;
1651           int64_t t, x, y;
1652           for (t = t_start; t < t_end; t += dT) {
1653             for (y = h_start; y < h_end; y += dH) {
1654               for (x = w_start; x < w_end; x += dW) {
1655                 tcntr = t * iH * iW + y * iW + x;
1656                 auto val = *(i_p + tcntr * iC + c);
1657                 max_val = std::max(max_val, val);
1658               } // for x
1659             } // for y
1660           } // for t
1661           o_p[c] = max_val;
1662         } // for c
1663         data_index_step(b, nBatch, time, oT, row, oH, col, oW);
1664       }
1665 
1666     });
1667 
1668   });
1669 }
1670 
1671 template <typename T>
1672 void do_avg_pool_nhwc_on_AVX_n(
1673     const typename T::underlying* i_p,
1674     typename T::underlying* o_p,
1675     int& c_start,
1676     int input_zero_point_m_size,
1677     int output_zero_point,
1678     float multiplier,
1679     int dstart,
1680     int dend,
1681     int hstart,
1682     int hend,
1683     int wstart,
1684     int wend,
1685     int dsize,
1686     int hsize,
1687     int wsize,
1688     int csize) {
1689 #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
1690   // buffer for channel accumulator, used to interchange channel-loop
1691   // to inner-most, so that memory access of the input tensor data is
1692   // continuous.
1693 #ifdef CPU_CAPABILITY_AVX2
1694   constexpr int cb_size = 16;
1695 #else
1696   constexpr int cb_size = 8;
1697 #endif
1698   constexpr int vec_width = Vectorized<T>::size() / 4;
1699   constexpr int cb_step = cb_size * vec_width;
1700   Vectorized<int32_t> acc_buffer[cb_size];
1701   Vectorized<float> acc_buffer_fp[cb_size];
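  // Each acc_buffer lane is seeded with input_zero_point_m_size
  // (= -in_zp * window size) in the loop below, so after accumulating the raw
  // window values it holds sum(q - in_zp); the int32 sums are then converted
  // to fp32 and quantized with `multiplier`, which the callers compute so that
  // it also folds in the 1/size averaging.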
1702 
1703 #ifdef CPU_CAPABILITY_AVX2
1704   if (vec_width == 8) {
1705 #else
1706   if (vec_width == 16) {
1707 #endif
1708     for (int c = c_start; c < csize; c += cb_step) {
1709       int cend = std::min(cb_size, (csize - c) / vec_width);
1710       // initialize loop
1711       for (const auto ic : c10::irange(cend)) {
1712         acc_buffer[ic] = Vectorized<int32_t>(input_zero_point_m_size);
1713       }
1714       // compute loop
1715       for (const auto id : c10::irange(dstart, dend)) {
1716         for (const auto ih : c10::irange(hstart, hend)) {
1717           for (const auto iw : c10::irange(wstart, wend)) {
1718             const int i_idx =
1719                 (id * wsize * hsize + ih * wsize + iw) *
1720                     csize +
1721                 c;
1722             for (const auto ic : c10::irange(cend)) {
1723               auto vals = vec::convert_to_int32<typename T::underlying>(
1724                   i_p + i_idx + ic * vec_width);
1725               acc_buffer[ic] = acc_buffer[ic] + vals;
1726             }
1727           }
1728         }
1729       }
1730       // convert int32 accumulative to fp32
1731       vec::convert((int*)acc_buffer, (float*)acc_buffer_fp, cend * vec_width);
1732 
1733       // First quantize using AVX2 or AVX512 with 32 lanes, then 8 lanes, and
1734       // finally fall back to the scalar path.
1735 #ifdef CPU_CAPABILITY_AVX2
1736       QuantizeAvx2<typename T::underlying>(
1737           (float*)acc_buffer_fp,
1738           o_p + c,
1739           cend * vec_width,
1740           multiplier,
1741           output_zero_point);
1742 #else
1743       QuantizeAvx512<typename T::underlying>(
1744           (float*)acc_buffer_fp,
1745           o_p + c,
1746           cend * vec_width,
1747           multiplier,
1748           output_zero_point);
1749 #endif
1750     }
1751     c_start = csize / vec_width * vec_width;
1752   }
1753 #endif
1754 }
1755 
1756 template <typename T>
1757 void do_avg_pool_on_AVX_n(
1758     typename T::underlying* i_p,
1759     typename T::underlying* o_p,
1760     int64_t& c,
1761     int64_t channel_size,
1762     int64_t channel_multiplier,
1763     int32_t input_zero_point_m_size,
1764     int32_t output_zero_point,
1765     float multiplier,
1766     int64_t dstart,
1767     int64_t dend,
1768     int64_t hstart,
1769     int64_t hend,
1770     int64_t wstart,
1771     int64_t wend,
1772     int64_t stride_C,
1773     int64_t stride_D,
1774     int64_t stride_H,
1775     int64_t stride_W) {
1776 #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
1777   constexpr int vec_width = Vectorized<T>::size() / 4;
1778 #ifdef CPU_CAPABILITY_AVX2
1779   if (vec_width == 8) {
1780 #else
1781   if (vec_width == 16) {
1782 #endif
1783     for (; c + vec_width <= channel_size; c += vec_width) {
1784       int64_t tcntr = 0;
1785 
1786       Vectorized<int32_t> acc(input_zero_point_m_size);
1787       for (const auto id : c10::irange(dstart, dend)) {
1788         for (const auto ih : c10::irange(hstart, hend)) {
1789           for (const auto iw : c10::irange(wstart, wend)) {
1790             tcntr = id * stride_D + ih * stride_H + iw * stride_W;
1791             auto vals = vec::convert_to_int32<typename T::underlying>(
1792                 i_p + tcntr * channel_multiplier + c * stride_C);
1793             acc = acc + vals;
1794           }
1795         }
1796       }
1797       int32_t acc_int[vec_width];
1798       float acc_fp[vec_width];
1799       acc.store(acc_int);
1800       vec::convert(acc_int, acc_fp, vec_width);
1801       at::native::quantize_vec<T>(
1802           1.0f / multiplier,
1803           output_zero_point,
1804           acc_fp,
1805           reinterpret_cast<T*>(o_p + c),
1806           vec_width);
1807     }
1808   }
1809 #endif
1810 }
1811 
1812 template <typename T>
1813 void _qadaptive_avg_pool_kernel(
1814     const Tensor& qx,
1815     Tensor& qy,
1816     int64_t nBatch,
1817     int64_t sizeC,
1818     int64_t isizeD,  // Set to 1 for 2d
1819     int64_t isizeH,
1820     int64_t isizeW,
1821     int64_t osizeD,  // Set to 1 for 2d
1822     int64_t osizeH,
1823     int64_t osizeW,
1824     int64_t istrideB,
1825     int64_t istrideC,
1826     int64_t istrideD,  // Set to 1 for 2d
1827     int64_t istrideH,
1828     int64_t istrideW) {
1829 
1830   T* idata = static_cast<T*>(qx.data_ptr());
1831   T* odata = static_cast<T*>(qy.data_ptr());
1832 
1833   const float input_scale = qx.q_scale();
1834   const float output_scale = qy.q_scale();
1835   const int input_zero_point = qx.q_zero_point();
1836   const int output_zero_point = qy.q_zero_point();
1837 
1838   at::parallel_for(0, nBatch, 0, [&](int64_t batch_start, int64_t batch_end) {
1839     for (const auto b : c10::irange(batch_start, batch_end)) {
1840       auto* i_p = reinterpret_cast<typename T::underlying*>(
1841           idata + b * istrideB);
1842 
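      // Adaptive pooling: output index o along a dimension of input size
      // isize and output size osize covers input rows
      //   [floor(o * isize / osize), ceil((o + 1) * isize / osize)),
      // so the window sizes kD/kH/kW below vary by at most one between
      // neighboring outputs.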
1843       for (const auto od : c10::irange(osizeD)) {
1844         int istartD = (int)std::floor((float)(od * isizeD) / osizeD);
1845         int iendD = (int)std::ceil((float)((od + 1) * isizeD) / osizeD);
1846         int kD = iendD - istartD;
1847         for (const auto oh : c10::irange(osizeH)) {
1848           int istartH = (int)std::floor((float)(oh * isizeH) / osizeH);
1849           int iendH = (int)std::ceil((float)((oh + 1) * isizeH) / osizeH);
1850           int kH = iendH - istartH;
1851           for (const auto ow : c10::irange(osizeW)) {
1852             auto* o_p = reinterpret_cast<typename T::underlying*>(
1853                 odata +
1854                 b * osizeD * osizeH * osizeW * sizeC +
1855                 od * osizeH * osizeW * sizeC +
1856                 oh * osizeW * sizeC +
1857                 ow * sizeC);
1858             int istartW = (int)std::floor((float)(ow * isizeW) / osizeW);
1859             int iendW = (int)std::ceil((float)((ow + 1) * isizeW) / osizeW);
1860             int kW = iendW - istartW;
1861             int size = kD * kH * kW;
1862             float multiplier = input_scale / output_scale / size;
1863             int input_zero_point_m_size = -input_zero_point * size;
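            // Seeding the accumulator with -in_zp * size means that after the
            // raw window values are added it holds sum(x_q - in_zp), and
            // quantize_val(1 / multiplier, out_zp, acc) computes
            //   out_zp + round(acc * in_scale / (out_scale * size)),
            // i.e. the window average requantized to the output parameters.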
1864             int64_t c = 0;
1865             // For int8 or uint8 quantization, we implicitly use int32 for
1866             // accumulation. Otherwise, it falls back to the slow path.
1867             // TODO: support 16-bit, 32-bit, etc.
1868             auto* internal_i_p = i_p +
1869                                 istartD * istrideD +
1870                                 istartH * istrideH +
1871                                 istartW * istrideW;
1872 
1873             // Note: If AVX is not available, `do_avg_pool_on_AVX_n` is a no-op.
1874             //       In that case, the following loop takes over.
1875             // TODO: more vectorization with loop interleaving
1876             do_avg_pool_on_AVX_n<T>(
1877                 internal_i_p,
1878                 o_p,
1879                 c,
1880                 sizeC,
1881                 1,
1882                 input_zero_point_m_size,
1883                 output_zero_point,
1884                 multiplier,
1885                 0,
1886                 kD,
1887                 0,
1888                 kH,
1889                 0,
1890                 kW,
1891                 istrideC,
1892                 istrideD,
1893                 istrideH,
1894                 istrideW);
1895             // 1) The following loop handles the remaining channels
1896             // 2) It also handles the Non-AVX2 path
1897             for (; c < sizeC; ++c) {
1898               int32_t acc_int32 = input_zero_point_m_size;
1899               int64_t tcntr = 0;
1900               for (const auto id : c10::irange(kD)) {
1901                 for (const auto ih : c10::irange(kH)) {
1902                   for (const auto iw : c10::irange(kW)) {
1903                     tcntr = id * istrideD +
1904                         ih * istrideH +
1905                         iw * istrideW;
1906                     auto val = *(internal_i_p + tcntr + c * istrideC);
1907                     acc_int32 += val;
1908                   }
1909                 }
1910               }
1911               // clamp
1912               o_p[c] = at::native::quantize_val<T>(1.0f / multiplier,
1913                                                           output_zero_point,
1914                                                           acc_int32).val_;
1915             } // c
1916           } // ow
1917         } // oh
1918       } // od
1919     }
1920   });
1921 }
1922 
1923 void qadaptive_avg_pool2d_nhwc_kernel(
1924     const Tensor& qx,
1925     Tensor& qy,
1926     int64_t nBatch,
1927     int64_t sizeC,
1928     int64_t isizeH,
1929     int64_t isizeW,
1930     int64_t osizeH,
1931     int64_t osizeW,
1932     int64_t istrideB,
1933     int64_t istrideC,
1934     int64_t istrideH,
1935     int64_t istrideW) {
1936     AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "adaptive_avg_pool2d_nhwc", [&]() {
1937         _qadaptive_avg_pool_kernel<scalar_t>(
1938           qx,
1939           qy,
1940           nBatch,
1941           sizeC,
1942           /*isizeD=*/1,
1943           isizeH,
1944           isizeW,
1945           /*osizeD=*/1,
1946           osizeH,
1947           osizeW,
1948           istrideB,
1949           istrideC,
1950           /*istrideD=*/1,
1951           istrideH,
1952           istrideW);
1953       }
1954     );
1955 }
1956 
1957 void qadaptive_avg_pool3d_ndhwc_kernel(
1958     const Tensor& qx,
1959     Tensor& qy,
1960     int64_t nBatch,
1961     int64_t sizeC,
1962     int64_t isizeD,
1963     int64_t isizeH,
1964     int64_t isizeW,
1965     int64_t osizeD,
1966     int64_t osizeH,
1967     int64_t osizeW,
1968     int64_t istrideB,
1969     int64_t istrideC,
1970     int64_t istrideD,
1971     int64_t istrideH,
1972     int64_t istrideW) {
1973   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "adaptive_avg_pool3d_ndhwc", [&]() {
1974     _qadaptive_avg_pool_kernel<scalar_t>(
1975       qx,
1976       qy,
1977       nBatch,
1978       sizeC,
1979       isizeD,
1980       isizeH,
1981       isizeW,
1982       osizeD,
1983       osizeH,
1984       osizeW,
1985       istrideB,
1986       istrideC,
1987       istrideD,
1988       istrideH,
1989       istrideW);
1990     }
1991   );
1992 }
1993 
1994 template <typename T>
1995 void _qavg_pool_nhwc_kernel(
1996     const Tensor& qx,
1997     Tensor& qy,
1998     int64_t nBatch,
1999     int64_t nInputPlane,
2000     int64_t inputWidth,
2001     int64_t inputHeight,
2002     int64_t inputDepth,
2003     int64_t outputWidth,
2004     int64_t outputHeight,
2005     int64_t outputDepth,
2006     int kW,
2007     int kH,
2008     int kD,
2009     int dW,
2010     int dH,
2011     int dD,
2012     int padW,
2013     int padH,
2014     int padD,
2015     bool count_include_pad,
2016     std::optional<int64_t> divisor_override) {
2017   T* idata = static_cast<T*>(qx.data_ptr());
2018   T* odata = static_cast<T*>(qy.data_ptr());
2019   int strideC = 1;
2020   int strideW = strideC * nInputPlane;
2021   int istrideH = strideW * inputWidth;
2022   int istrideD = istrideH * inputHeight;
2023   int istrideB = istrideD * inputDepth;
2024 
2025   // lift these operations outside the loop to reduce access overheads
2026   float input_scale = qx.q_scale();
2027   float output_scale = qy.q_scale();
2028   int input_zero_point = qx.q_zero_point();
2029   int output_zero_point = qy.q_zero_point();
2030   int64_t divisor_override_factor =
2031       divisor_override.has_value() ? divisor_override.value() : 0;
2032 
2033   at::parallel_for(0, nBatch * outputDepth * outputHeight * outputWidth, 0, [&](int64_t begin, int64_t end) {
2034     int64_t b{0}, od{0}, oh{0}, ow{0};
2035     data_index_init(begin, b, nBatch, od, outputDepth, oh, outputHeight, ow, outputWidth);
2036 
2037     for (const auto i : c10::irange(begin, end)) {
2038       auto* i_p = reinterpret_cast<typename T::underlying*>(idata + b * istrideB);
2039       auto* o_p = reinterpret_cast<typename T::underlying*>(odata + i * strideW);
2040       int dstart = od * dD - padD;
2041       int hstart = oh * dH - padH;
2042       int wstart = ow * dW - padW;
2043 
2044       int dend = std::min(dstart + kD, (int)inputDepth + padD);
2045       int hend = std::min(hstart + kH, (int)inputHeight + padH);
2046       int wend = std::min(wstart + kW, (int)inputWidth + padW);
2047       int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
2048 
2049       dstart = std::max(dstart, 0);
2050       hstart = std::max(hstart, 0);
2051       wstart = std::max(wstart, 0);
2052       dend = std::min(dend, (int)inputDepth);
2053       hend = std::min(hend, (int)inputHeight);
2054       wend = std::min(wend, (int)inputWidth);
2055 
2056       int size = (dend - dstart) * (hend - hstart) * (wend - wstart);
2057       int divide_size = count_include_pad ? pool_size : size;
2058       int divide_factor =
2059           divisor_override_factor ? divisor_override_factor : divide_size;
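      // Divisor selection: pool_size counts the kernel window clipped to the
      // padded input, size counts only the in-bounds region; count_include_pad
      // picks between them, and an explicit divisor_override wins over both.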
2060       float multiplier = input_scale / output_scale / divide_factor;
2061       int input_zero_point_m_size = -input_zero_point * size;
2062 
2063       int c_start = 0;
2064 
2065       // For int8 quantization, we implicitly use int32 for accumulation.
2066       // Otherwise, it falls back to the slow path.
2067       // TODO: support 16-bit, 32-bit, etc.
2068       do_avg_pool_nhwc_on_AVX_n<T>(
2069           i_p,
2070           o_p,
2071           c_start,
2072           input_zero_point_m_size,
2073           output_zero_point,
2074           multiplier,
2075           dstart,
2076           dend,
2077           hstart,
2078           hend,
2079           wstart,
2080           wend,
2081           inputDepth,
2082           inputHeight,
2083           inputWidth,
2084           nInputPlane);
2085 
2086       // 1) The following loop handles the remaining channels
2087       // 2) It also handles the Non-AVX2 path
2088       for (const auto c: c10::irange(c_start, nInputPlane)) {
2089         int32_t acc_int32 = input_zero_point_m_size;
2090         for (const auto id : c10::irange(dstart, dend)) {
2091           for (const auto ih : c10::irange(hstart, hend)) {
2092             for (const auto iw : c10::irange(wstart, wend)) {
2093               auto val =
2094                   *(i_p + id * istrideD + ih * istrideH + iw * strideW +
2095                   c * strideC);
2096               acc_int32 += val;
2097             }
2098           }
2099         }
2100         double acc_fp = acc_int32 * 1.0;
2101         // clamp
2102         o_p[c] = at::native::quantize_val<T>(
2103             1.0f / multiplier, output_zero_point, acc_fp)
2104             .val_;
2105       } // c
2106 
2107       data_index_step(b, nBatch, od, outputDepth, oh, outputHeight, ow, outputWidth);
2108     }
2109   });
2110 }
2111 
2112 void qavg_pool2d_nhwc_kernel(
2113     const Tensor& qx,
2114     Tensor& qy,
2115     int64_t b,
2116     int64_t nInputPlane,
2117     int64_t inputWidth,
2118     int64_t inputHeight,
2119     int64_t outputWidth,
2120     int64_t outputHeight,
2121     int kW,
2122     int kH,
2123     int dW,
2124     int dH,
2125     int padW,
2126     int padH,
2127     bool count_include_pad,
2128     std::optional<int64_t> divisor_override) {
2129   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "avg_pool2d_nhwc", [&]() {
2130     _qavg_pool_nhwc_kernel<scalar_t>(
2131       qx,
2132       qy,
2133       b,
2134       nInputPlane,
2135       inputWidth,
2136       inputHeight,
2137       1,
2138       outputWidth,
2139       outputHeight,
2140       1,
2141       kW,
2142       kH,
2143       1,
2144       dW,
2145       dH,
2146       1,
2147       padW,
2148       padH,
2149       0,
2150       count_include_pad,
2151       divisor_override);
2152   });
2153 }
2154 
2155 void qavg_pool3d_nhwc_kernel(
2156     const Tensor& qx,
2157     Tensor& qy,
2158     int64_t b,
2159     int64_t nInputPlane,
2160     int64_t inputWidth,
2161     int64_t inputHeight,
2162     int64_t inputDepth,
2163     int64_t outputWidth,
2164     int64_t outputHeight,
2165     int64_t outputDepth,
2166     int kW,
2167     int kH,
2168     int kD,
2169     int dW,
2170     int dH,
2171     int dD,
2172     int padW,
2173     int padH,
2174     int padD,
2175     bool count_include_pad,
2176     std::optional<int64_t> divisor_override) {
2177   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "avg_pool3d_nhwc", [&]() {
2178     _qavg_pool_nhwc_kernel<scalar_t>(
2179       qx,
2180       qy,
2181       b,
2182       nInputPlane,
2183       inputWidth,
2184       inputHeight,
2185       inputDepth,
2186       outputWidth,
2187       outputHeight,
2188       outputDepth,
2189       kW,
2190       kH,
2191       kD,
2192       dW,
2193       dH,
2194       dD,
2195       padW,
2196       padH,
2197       padD,
2198       count_include_pad,
2199       divisor_override);
2200   });
2201 }
2202 
2203 template <typename T>
2204 int64_t do_quantized_bilinear_on_AVX_n(
2205     const typename T::underlying*& pos1,
2206     typename T::underlying*& pos2,
2207     int64_t input_width,
2208     int64_t output_height,
2209     int64_t output_width,
2210     int64_t channels,
2211     int32_t output_zero_point,
2212     int32_t input_zero_point,
2213     float inverse_scale,
2214     const float h0lambda,
2215     const float h1lambda,
2216     const float w0lambda,
2217     const float w1lambda,
2218     const int64_t h1p,
2219     const int64_t w1p) {
2220   int64_t c = 0;
2221 #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
2222   constexpr auto vec_width = Vectorized<T>::size() / 4;
2223 #ifdef CPU_CAPABILITY_AVX2
2224   if (vec_width == 8) {
2225 #else
2226   if (vec_width == 16) {
2227 #endif
2228     for (; c + vec_width <= channels; c += vec_width) {
2229       Vectorized<float> pos1_fp_v[4];
2230       Vectorized<int32_t> pos1_int_v[4];
2231       pos1_int_v[0] = vec::convert_to_int32<typename T::underlying>(pos1);
2232       pos1_int_v[1] = vec::convert_to_int32<typename T::underlying>(
2233           pos1 + w1p * channels);
2234       pos1_int_v[2] = vec::convert_to_int32<typename T::underlying>(
2235           pos1 + h1p * input_width * channels);
2236       pos1_int_v[3] = vec::convert_to_int32<typename T::underlying>(
2237           pos1 + (h1p * input_width + w1p) * channels);
2238       for (const auto i : c10::irange(4)) {
2239         int32_t pos1_int[vec_width];
2240         float pos1_fp[vec_width];
2241         pos1_int_v[i].store(pos1_int);
2242         vec::convert(pos1_int, pos1_fp, vec_width);
2243         pos1_fp_v[i] = Vectorized<float>::loadu(pos1_fp, vec_width);
2244       }
2245       Vectorized<float> h0lambda_v(h0lambda);
2246       Vectorized<float> h1lambda_v(h1lambda);
2247       Vectorized<float> w0lambda_v(w0lambda);
2248       Vectorized<float> w1lambda_v(w1lambda);
2249       Vectorized<float> input_zero_point_v(input_zero_point);
2250       Vectorized<float> result =
2251           h0lambda_v * (w0lambda_v * pos1_fp_v[0] + w1lambda_v * pos1_fp_v[1]) +
2252           h1lambda_v * (w0lambda_v * pos1_fp_v[2] + w1lambda_v * pos1_fp_v[3]) -
2253           input_zero_point_v;
2254       float result_fp[vec_width];
2255       result.store(result_fp);
2256       at::native::quantize_vec<T>(
2257           inverse_scale,
2258           output_zero_point,
2259           result_fp,
2260           reinterpret_cast<T*>(pos2),
2261           vec_width);
2262       pos1 += vec_width;
2263       pos2 += vec_width;
2264     }
2265   }
2266 #endif
2267   return c;
2268 }
2269 
2270 void qupsample_bilinear2d_nhwc_kernel(
2271     Tensor& output,
2272     const Tensor& input,
2273     int64_t input_height,
2274     int64_t input_width,
2275     int64_t output_height,
2276     int64_t output_width,
2277     int64_t nbatch,
2278     int64_t channels,
2279     bool align_corners,
2280     std::optional<double> scales_h,
2281     std::optional<double> scales_w) {
2282   AT_DISPATCH_QINT_TYPES(input.scalar_type(), "upsample_bilinear2d_nhwc", [&]() {
2283     auto* idata = static_cast<scalar_t*>(input.data_ptr());
2284     auto* odata = static_cast<scalar_t*>(output.data_ptr());
2285     float inverse_scale = output.q_scale() / input.q_scale();
2286     const auto rheight = area_pixel_compute_scale<float>(
2287         input_height, output_height, align_corners, scales_h);
2288     const auto rwidth = area_pixel_compute_scale<float>(
2289         input_width, output_width, align_corners, scales_w);
2290 
2291     auto input_q_zero_point = input.q_zero_point();
2292     auto output_q_zero_point = output.q_zero_point();
2293     at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
2294       int64_t b{0}, h2{0}, w2{0};
2295       data_index_init(begin, b, nbatch, h2, output_height, w2, output_width);
2296 
2297       for (C10_UNUSED const auto i : c10::irange(begin, end)) {
2298         auto* i_p = reinterpret_cast<typename scalar_t::underlying*>(
2299             idata + b * input_height * input_width * channels);
2300         auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(
2301             odata + b * output_height * output_width * channels);
2302 
2303         const auto h1r = area_pixel_compute_source_index<float>(
2304             rheight, h2, align_corners, /*cubic=*/false);
2305 
2306         const int64_t h1 = h1r;
2307         const int64_t h1p = (h1 < input_height - 1) ? 1 : 0;
2308         const float h1lambda = h1r - h1;
2309         const float h0lambda = static_cast<float>(1.) - h1lambda;
2310 
2311         const auto w1r = area_pixel_compute_source_index<float>(
2312             rwidth, w2, align_corners, /*cubic=*/false);
2313         const int64_t w1 = w1r;
2314         const int64_t w1p = (w1 < input_width - 1) ? 1 : 0;
2315 
2316         const float w1lambda = w1r - w1;
2317         const float w0lambda = static_cast<float>(1.) - w1lambda;
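        // Standard bilinear interpolation over the four neighboring inputs:
        //   result = h0lambda * (w0lambda * p(h1, w1) + w1lambda * p(h1, w1 + w1p))
        //          + h1lambda * (w0lambda * p(h1 + h1p, w1) + w1lambda * p(h1 + h1p, w1 + w1p))
        // The input zero point is subtracted from the interpolated value, and
        // quantize_val with inverse_scale = out_scale / in_scale (i.e. a divide
        // by inverse_scale) maps the result into the output's quantized domain.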
2318 
2319         int64_t c = 0;
2320         // We use float32 to do the computation
2321         const typename scalar_t::underlying* pos1 =
2322             i_p + (h1 * input_width + w1) * channels;
2323         typename scalar_t::underlying* pos2 =
2324             o_p + (h2 * output_width + w2) * channels;
2325         // We have to isolate this function out because MSVC does not
2326         // expand the macro correctly.
2327         c = do_quantized_bilinear_on_AVX_n<scalar_t>(
2328             pos1,
2329             pos2,
2330             input_width,
2331             output_height,
2332             output_width,
2333             channels,
2334             output_q_zero_point,
2335             input_q_zero_point,
2336             inverse_scale,
2337             h0lambda,
2338             h1lambda,
2339             w0lambda,
2340             w1lambda,
2341             h1p,
2342             w1p);
2343         // 1) The following loop handles the remaining channels
2344         // 2) It also handles the Non-AVX2 path
2345         for (; c < channels; ++c) {
2346           float result = h0lambda *
2347                   (w0lambda * pos1[0] + w1lambda * pos1[w1p * channels]) +
2348               h1lambda *
2349                   (w0lambda * pos1[h1p * input_width * channels] +
2350                    w1lambda * pos1[(h1p * input_width + w1p) * channels]);
2351           pos2[0] = at::native::quantize_val<scalar_t>(
2352                         inverse_scale,
2353                         output_q_zero_point,
2354                         result - input_q_zero_point)
2355                         .val_;
2356           pos1 += 1;
2357           pos2 += 1;
2358         } // c
2359 
2360         data_index_step(b, nbatch, h2, output_height, w2, output_width);
2361       }
2362     });
2363   });
2364 }
2365 
2366 void qtopk_kernel(Tensor& values,
2367     Tensor& indices,
2368     const Tensor& self,
2369     int64_t k,
2370     int64_t dim,
2371     bool largest,
2372     bool sorted) {
2373   auto sizes = self.sizes();
2374   auto iter = TensorIteratorConfig()
2375     .check_all_same_dtype(false)
2376     .resize_outputs(false)
2377     .declare_static_shape(sizes, /*squash_dims=*/dim)
2378     .add_output(values)
2379     .add_output(indices)
2380     .add_input(self)
2381     .build();
2382 
2383   auto mode_values_stride = values.strides()[dim];
2384   auto mode_indices_stride = indices.strides()[dim];
2385   auto tmp_values_stride = self.strides()[dim];
2386   // If sizes is empty, the tensor is scalar. This prevents accessing an empty array.
2387   auto dim_size = sizes.empty() ? 1 : sizes[dim];
2388 
2389   AT_DISPATCH_QINT_TYPES(self.scalar_type(), "qtopk_cpu", [&] {
2390     auto loop = [&](char** data, const int64_t* strides, int64_t n) {
2391       using underlying_t = typename scalar_t::underlying;
2392       static_assert(sizeof(scalar_t) == sizeof(underlying_t), "");
2393       return topk_impl_loop<underlying_t, underlying_t>(
2394           mode_values_stride, mode_indices_stride, tmp_values_stride,
2395           k, dim_size, largest, sorted, data, strides, n);
2396     };
2397 
2398     int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, dim_size);
2399     iter.for_each(loop, /*grain_size=*/grain_size);
2400   });
2401 }
2402 
2403 template <typename T, bool ReluFused>
2404 inline void do_bn_compute(
2405     typename T::underlying* X_ptr,
2406     typename T::underlying* Y_ptr,
2407     Vectorized<float> & fake_scale,
2408     Vectorized<float> & in_zp_vec,
2409     Vectorized<float> & scale_neg_zp_premul,
2410     int64_t out_zero_point,
2411     Vectorized<T> & out_zero_point_v,
2412     float*  alpha,
2413     float* beta,
2414     int64_t vec_num,
2415     int64_t kVLen
2416 ) {
2417   using Vec = Vectorized<T>;
2418   auto vals_q = Vec::loadu(X_ptr);
2419   // Fake scale of 1.0 here, should not affect performance (FMA in place of sub)
2420   auto vals_dq = vals_q.dequantize(fake_scale, in_zp_vec, scale_neg_zp_premul);
2421   for (const auto idx : c10::irange(vec_num)) {
2422     auto alpha_v = Vectorized<float>::loadu(alpha + idx * kVLen);
2423     auto beta_v = Vectorized<float>::loadu(beta + idx * kVLen);
2424     vals_dq[idx] = vec::fmadd(alpha_v, vals_dq[idx], beta_v);
2425   }
2426   // NOLINTNEXTLINE(bugprone-argument-comment)
2427   auto outputs_q = Vec::quantize(vals_dq, /*output_scale=*/1.0f, out_zero_point, /*inv_output_scale=*/1.0f);
2428   // Fake scale again
2429   if constexpr (ReluFused) {
2430     outputs_q = outputs_q.maximum(out_zero_point_v);
2431   }
2432   outputs_q.store(Y_ptr, vec_num * kVLen);
2433 }
2434 
2435 template <bool ReluFused>
2436 void q_batch_norm_kernel(
2437     int64_t N,
2438     int64_t C,
2439     int64_t HxW,
2440     int64_t in_zero_point,
2441     int64_t out_zero_point,
2442     const Tensor& input,
2443     const Tensor& a,
2444     const Tensor& b,
2445     Tensor& output) {
2446 
2447   AT_DISPATCH_QINT_TYPES(input.scalar_type(), "qbatch_norm", [&]() {
2448     float* alpha = a.data_ptr<float>();
2449     float* beta = b.data_ptr<float>();
2450     auto minimum = std::numeric_limits<scalar_t::underlying>::lowest();
2451     auto maximum = std::numeric_limits<scalar_t::underlying>::max();
2452     scalar_t::underlying* X =
2453         reinterpret_cast<scalar_t::underlying*>(input.data_ptr());
2454     scalar_t::underlying* Y = reinterpret_cast<scalar_t::underlying*>(output.data_ptr());
2455 
2456     constexpr int kVLen = Vectorized<float>::size();
2457     const int64_t outer_size = N * HxW;
2458     using Vec = Vectorized<scalar_t>;
2459     // Hoisted variables
2460     auto in_zp_vec = Vectorized<float>(static_cast<float>(in_zero_point));
2461     auto fake_scale = Vectorized<float>(1.0f);
2462     auto scale_neg_zp_premul = fake_scale * in_zp_vec.neg();
2463     auto out_zero_point_v = Vec(scalar_t(out_zero_point));
2464     const auto lanes = static_cast<int64_t>(Vec::float_num_vecs() * kVLen);
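    // alpha and beta are per-channel floats precomputed by the caller
    // (presumably folding the batch-norm statistics and the input/output
    // scales), so the whole op reduces to one affine transform per channel:
    //   Y_q[c] = clamp(out_zp + round(alpha[c] * (X_q[c] - in_zp) + beta[c]))
    // The "fake" scale of 1.0 lets dequantize()/quantize() be reused to form
    // (X_q - in_zp) and the final rounding without any real rescaling.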
2465     at::parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
2466       for (const auto i : c10::irange(begin, end)) {
2467         auto* X_ptr = reinterpret_cast<typename scalar_t::underlying*>(X + i * C);
2468         auto* Y_ptr = reinterpret_cast<typename scalar_t::underlying*>(Y + i * C);
2469         int64_t ch = 0;
2470 
2471         for(; ch + lanes <= C; ch += lanes) {
2472           do_bn_compute<scalar_t, ReluFused>(
2473             X_ptr + ch,
2474             Y_ptr + ch,
2475             fake_scale,
2476             in_zp_vec,
2477             scale_neg_zp_premul,
2478             out_zero_point,
2479             out_zero_point_v,
2480             alpha + ch,
2481             beta + ch,
2482             Vec::float_num_vecs(),
2483             kVLen
2484           );
2485         }
2486 
2487         // For a channel remainder between 8 and 32, still use the 32-lane path;
2488         // benchmarks show it is faster than processing 8 channels at a time.
2489         int64_t elem_size = C - ch;
2490         if ((lanes == 32) && elem_size >= kVLen) {
2491           int64_t vec_num = elem_size / kVLen;
2492           std::vector<typename scalar_t::underlying> buf_in(lanes);
2493           memcpy(buf_in.data(), X_ptr + ch, vec_num * kVLen); // 3 cycles
2494           do_bn_compute<scalar_t, ReluFused>(
2495             buf_in.data(),
2496             Y_ptr + ch,
2497             fake_scale,
2498             in_zp_vec,
2499             scale_neg_zp_premul,
2500             out_zero_point,
2501             out_zero_point_v,
2502             alpha + ch,
2503             beta + ch,
2504             vec_num,
2505             kVLen
2506           );
2507           ch += vec_num * kVLen;
2508         }
2509         // for channels less than 8
2510         for (; ch < C; ++ch) {
2511           long quantized_down = out_zero_point +
2512               lrintf(alpha[ch] * (X_ptr[ch] - in_zero_point) +
2513                           beta[ch]);
2514           if constexpr (ReluFused) { // static if
2515             quantized_down = std::max<long>(quantized_down, out_zero_point);
2516           }
2517           Y_ptr[ch] = std::min<long>(
2518               std::max<long>(quantized_down, minimum), maximum);
2519         }
2520       }
2521     });
2522   });
2523 }
2524 
2525 void _fake_quantize_tensor_helper(
2526   Tensor& output,
2527   Tensor& mask,
2528   const Tensor& input,
2529   int fake_quant_on,
2530   float sc,
2531   int64_t z_point,
2532   int64_t quant_min,
2533   int64_t quant_max) {
2534 
2535   float inv_scale = 1.0f / sc;
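  // Fake quantization round-trips each value through the quantized grid:
  //   x_fq = (clamp(round(x / sc) + z_point, quant_min, quant_max) - z_point) * sc
  // mask records whether the rounded value landed inside [quant_min, quant_max],
  // so the backward pass can gate gradients straight-through-estimator style.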
2536 
2537   auto iter_combined = TensorIteratorConfig()
2538     .check_all_same_dtype(false)
2539     .add_output(output)
2540     .add_output(mask)
2541     .add_input(input)
2542     .build();
2543 
2544   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fake_quantize_tensor_cachemask_kernel_type_handling", [&] {
2545     iter_combined.for_each([&](char** data, const int64_t* strides, int64_t n) {
2546       for (const auto i : c10::irange(n)) {
2547         scalar_t* output_val = (scalar_t*)(data[0] + i * strides[0]);
2548         bool* mask_val = (bool*)(data[1] + i * strides[1]);
2549         scalar_t* input_val = (scalar_t*)(data[2] + i * strides[2]);
2550 
2551         const auto qval = static_cast<int64_t>(z_point + std::nearbyint(*input_val * inv_scale));
2552         if (fake_quant_on) {
2553           *output_val = (std::fmin(std::fmax(qval, quant_min), quant_max) - z_point) * sc;
2554           *mask_val = ((quant_min <= qval) && (qval <= quant_max));
2555         } else {
2556           *output_val = *input_val;
2557           *mask_val = 1;
2558         }
2559       }
2560     });
2561   });
2562 }
2563 
2564 void fake_quantize_tensor_cachemask_kernel(
2565     Tensor& output,
2566     Tensor& mask,
2567     const Tensor& input,
2568     float sc,
2569     int64_t z_point,
2570     int64_t quant_min,
2571     int64_t quant_max) {
2572   _fake_quantize_tensor_helper(output, mask, input, 1, sc, z_point, quant_min, quant_max);
2573 }
2574 
2575 void fake_quantize_tensor_cachemask_tensor_qparams_kernel(
2576     Tensor& output,
2577     Tensor& mask,
2578     const Tensor& input,
2579     const Tensor& sc,
2580     const Tensor& z_point,
2581     const Tensor& fake_quant_enabled,
2582     int64_t quant_min,
2583     int64_t quant_max) {
2584   _fake_quantize_tensor_helper(output, mask, input, fake_quant_enabled.item().toInt(), sc.item().toFloat(), z_point.item().toInt(), quant_min, quant_max);
2585 }
2586 
2587 void fake_quantize_learnable_tensor_grad_kernel_cpu(
2588     TensorIterator& iter,
2589     float scale,
2590     float inv_scale,
2591     int64_t zero_point,
2592     int64_t quant_min,
2593     int64_t quant_max,
2594     float grad_factor) {
2595   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
2596   float dscale_small = quant_min - zero_point;
2597   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
2598   float dscale_big = quant_max - zero_point;
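  // Gradient formulas implemented below (all additionally scaled by grad_factor),
  // where x_q = round(zp + x / scale) and x_fq is its clamped reconstruction:
  //   dX  = dY                       if quant_min <= x_q <= quant_max, else 0
  //   dS  = dY * (x_fq - x) / scale  inside the range,
  //         dY * (quant_min - zp) if clipped low, dY * (quant_max - zp) if clipped high
  //   dZP = 0 inside the range, -dY * scale when clipped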
2599   iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
2600     /*  When a for_each call is made on a TensorIterator with multiple inputs and outputs,
2601         the order they are accessed follows the order they are built within the iterator.
2602         For example, if an iterator is built in the following order:
2603         auto iter = TensorIteratorConfig().
2604           .add_output(firstOutput)
2605           .add_output(secondOutput)
2606           .add_input(firstInput)
2607           .add_input(secondInput)
2608           .build()
2609         data will contain 4 pointers to pointers to values in the following order:
2610         firstOutput, secondOutput, firstInput, secondInput.
2611         Proper pointer referencing and dereferencing, along with the usage of strides
2612         (to move onto different elements), can allow accessing of the input and assignment
2613         to the right output.
2614     */
2615     for (const auto i : c10::irange(n)) {
2616       float* dXOutput = (float*)(data[0] + i * strides[0]);
2617       float* dScaleOutput = (float*)(data[1] + i * strides[1]);
2618       float* dZeroPointOutput = (float*)(data[2] + i * strides[2]);
2619       float* XInput = (float*)(data[3] + i * strides[3]);
2620       float* dYInput = (float*)(data[4] + i * strides[4]);
2621       // Calculate gradients for X.
2622       int64_t xqi = std::nearbyint(zero_point + (*XInput) * inv_scale);
2623       // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
2624       *dXOutput = (*dYInput) * (xqi >= quant_min && xqi <= quant_max);
2625       // Calculate gradients for scale and zero point.
2626       // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
2627       float xfqi = static_cast<float>((std::max(std::min(xqi, quant_max), quant_min) - zero_point) * scale);
2628       // Calculate gradients according to the gradient of the clamp function.
2629       if (xqi < quant_min || xqi > quant_max) {
2630         *dZeroPointOutput = (*dYInput) * (-1) * scale * grad_factor;
2631         *dScaleOutput = ((xqi < quant_min) ? ((*dYInput) * dscale_small) : ((*dYInput) * dscale_big)) * grad_factor;
2632       } else {
2633         *dZeroPointOutput = 0;
2634         *dScaleOutput = (*dYInput) * (xfqi - (*XInput)) * inv_scale * grad_factor;
2635       }
2636     }
2637   });
2638 }
2639 
2640 template <typename SelfType>
2641 void _fake_quant_per_channel_cachemask_cpu_helper(
2642     TensorIterator& iter,
2643     TensorIterator& iter_mask,
2644     const int64_t quant_min,
2645     const int64_t quant_max) {
2646 
2647   const auto& zero_point_dtype = iter.input_dtype(2);
2648 
2649   if(at::isFloatingType(zero_point_dtype)){
2650     // When zero_point is float, quantize mirroring affine quantizer equation
2651     // Xq = Round(Xf * inv_scale + zero_point)
2652     // where zero_point is in float.
2653     AT_DISPATCH_FLOATING_TYPES_AND_HALF(zero_point_dtype, "fake_quantize_channel_cachemask_cpu_zero_point_handling", [&] {
2654       // write mask
2655       cpu_kernel(iter_mask, [=](SelfType self, float scale, scalar_t zero_point) -> bool {
2656         float inv_scale = 1.0f / scale;
2657         const auto qval = std::lrintf(zero_point + (self * inv_scale));
2658         return ((quant_min <= qval) && (qval <= quant_max));
2659       });
2660 
2661       // write fake_quant
2662       cpu_kernel(iter, [=](SelfType self, float scale, scalar_t zero_point) -> SelfType {
2663         float inv_scale = 1.0f / scale;
2664         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
2665         return (std::fmin(
2666                     std::fmax(
2667                         std::lrintf(zero_point + self * inv_scale),
2668                         quant_min),
2669                     quant_max) -
2670                 zero_point) *
2671             scale;
2672       });
2673     });
2674 
2675   } else {
2676       // write mask
2677       cpu_kernel(iter_mask, [=](SelfType self, float scale, int32_t zero_point) -> bool {
2678         float inv_scale = 1.0f / scale;
2679         const auto qval = static_cast<int64_t>(zero_point + std::nearbyint(self * inv_scale));
2680         return ((quant_min <= qval) && (qval <= quant_max));
2681       });
2682 
2683       // write fake_quant
2684       cpu_kernel(iter, [=](SelfType self, float scale, int32_t zero_point) -> SelfType {
2685         float inv_scale = 1.0f / scale;
2686         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
2687         return (std::fmin(
2688                     std::fmax(
2689                         static_cast<int64_t>(
2690                             zero_point + std::nearbyint(self * inv_scale)),
2691                         quant_min),
2692                     quant_max) -
2693                 zero_point) *
2694             scale;
2695       });
2696   }
2697 
2698 }
2699 
2700 
2701 void fake_quant_per_channel_cachemask_cpu(
2702     TensorIterator& iter,
2703     TensorIterator& iter_mask,
2704     int64_t quant_min,
2705     int64_t quant_max) {
2706   // TODO(future, optional): read once, write twice.  Not done at the moment
2707   //   for simplicity, as we do not expect this to be a bottleneck.
2708 
2709   AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "fake_quantize_channel_cachemask_cpu_type_handling", [&] {
2710     _fake_quant_per_channel_cachemask_cpu_helper<scalar_t>(iter, iter_mask, quant_min, quant_max);
2711   });
2712 }
2713 
2714 
2715 void fake_quantize_learnable_channel_grad_kernel_cpu(
2716     TensorIterator& iter,
2717     int64_t quant_min,
2718     int64_t quant_max,
2719     float grad_factor) {
2720   iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
2721     /*  To see how the input and outputs are referenced and assigned,
2722         please see the implementation of
2723         fake_quantize_learnable_tensor_grad_kernel_cpu.
2724     */
2725     for (const auto i : c10::irange(n)) {
2726       float* dx_output = (float*)(data[0] + i * strides[0]);
2727       float* dscale_output = (float*)(data[1] + i * strides[1]);
2728       float* dzero_point_output = (float*)(data[2] + i * strides[2]);
2729       float* x_input = (float*)(data[3] + i * strides[3]);
2730       float* dy_input = (float*)(data[4] + i * strides[4]);
2731       float* scale_input = (float*)(data[5] + i * strides[5]);
2732       float* zero_point_input = (float*)(data[6] + i * strides[6]);
2733 
2734       float inv_scale = 1.0f / (*scale_input);
2735       float dscale_small = quant_min - (*zero_point_input);
2736       float dscale_big = quant_max - (*zero_point_input);
2737       // Calculate gradients for X.
2738       int64_t xqi = std::nearbyint((*zero_point_input) + (*x_input) * inv_scale);
2739       // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
2740       *dx_output = (*dy_input) * (xqi >= quant_min && xqi <= quant_max);
2741       // Calculate gradients for scale and zero point.
2742       float xfqi = static_cast<float>((std::max(std::min(xqi, quant_max), quant_min) - (*zero_point_input)) * (*scale_input));
2743       if (xqi < quant_min || xqi > quant_max) {
2744         *dzero_point_output = (*dy_input) * (-1) * (*scale_input) * grad_factor;
2745         *dscale_output = ((xqi < quant_min) ? ((*dy_input) * dscale_small) : ((*dy_input) * dscale_big)) * grad_factor;
2746       } else {
2747         *dzero_point_output = 0;
2748         *dscale_output = (*dy_input) * (xfqi - (*x_input)) * inv_scale * grad_factor;
2749       }
2750     }
2751   });
2752 }
2753 
2754 // Assumes X is composed of M groups of N elements. Normalizes each of the
2755 // groups and optionally applies affine scaling. Useful for LayerNorm,
2756 // GroupNorm, InstanceNorm.
2757 void quantized_normalize_kernel(
2758     const Tensor& X, // input tensor
2759     const Tensor& gamma, // weight (optional)
2760     const Tensor& beta, // bias (optional)
2761     bool affine_per_channel, // scaling applied elementwise if false, per channel if true
2762     int num_channels, // only used if affine_per_channel is set
2763     int num_groups, // only used if affine_per_channel is set
2764     int64_t M, // number of groups
2765     int64_t N, // number of elements in each group
2766     double eps,
2767     Tensor* Y) {
2768   AT_DISPATCH_QINT_TYPES(X.scalar_type(), "quantized_layer_norm_kernel_impl_cpu", [&]() {
2769     using qVec = vec::Vectorized<scalar_t>;
2770     using fVec = vec::Vectorized<float>;
2771 
2772     TORCH_INTERNAL_ASSERT(X.numel() == M * N, "Unexpected num elements in X");
2773     TORCH_INTERNAL_ASSERT(
2774         !gamma.defined() ||
2775         (!affine_per_channel && gamma.numel() == N) ||
2776         (affine_per_channel && gamma.numel() == num_channels),
2777         "Unexpected size of gamma");
2778     TORCH_INTERNAL_ASSERT(
2779         !beta.defined() ||
2780         (!affine_per_channel && beta.numel() == N) ||
2781         (affine_per_channel && beta.numel() == num_channels),
2782         "Unexpected size of beta");
2783 
2784     scalar_t* X_data = X.data_ptr<scalar_t>();
2785     const float* gamma_data = gamma.defined() ? gamma.const_data_ptr<float>() : nullptr;
2786     const float* beta_data = beta.defined() ? beta.const_data_ptr<float>() : nullptr;
2787     scalar_t* Y_data = Y->data_ptr<scalar_t>();
2788     const bool gamma_null = gamma_data == nullptr;
2789     const bool beta_null = beta_data == nullptr;
2790     int64_t x_zp = X.q_zero_point();
2791     float x_scale = X.q_scale();
2792     fVec x_zp_vec((float)x_zp);
2793     fVec one_vec(1.0f);
2794     fVec zero_vec(0.0f);
2795     float x_fake_scale = 1.0f;
2796     fVec x_fake_scale_vec(x_fake_scale);
2797     fVec x_fake_scale_zp_neg_premul_vec = x_fake_scale_vec * x_zp_vec.neg();
2798     int64_t y_zp = Y->q_zero_point();
2799     float y_scale = Y->q_scale();
2800     float y_inv_scale = 1.0f / y_scale;
2801 
2802     constexpr int kFloatVLen = fVec::size();
2803     int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs();
2804     int64_t kNumIntVecInLayer = N / kIntVLen;
2805     int64_t kNonVecRemInLayer = N % kIntVLen;
2806     int channels_per_group = num_channels / num_groups;
2807     int64_t NPerChannel = N / channels_per_group;
2808     int64_t kNumIntVecInChannel = NPerChannel / kIntVLen;
2809     int64_t kNonVecRemInChannel = NPerChannel % kIntVLen;
2810 
2811     at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
2812       for (const auto i : c10::irange(start, end)) {
2813 
2814         scalar_t* X_ptr = X_data + i * N;
2815         scalar_t* Y_ptr = Y_data + i * N;
2816 
2817         // First pass: calculate mean and variance.
2818 
2819         scalar_t::underlying* X_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(X_ptr);
2820         auto l_sum_shifted = hsum(X_ptr_underlying, N);
2821         auto l_sum_sq_shifted = hsum_sq(X_ptr_underlying, N);
2822         float l_mean_shifted_div_scale_x = static_cast<float>(l_sum_shifted) / N;
2823         // mean(dqX) / scale_x
2824         float layer_mean_div_scale_x = l_mean_shifted_div_scale_x - x_zp;
2825         // var(dqX) / scale_x^2
2826         float layer_var_div_scale_x_sq =
2827           std::max(static_cast<float>(l_sum_sq_shifted) / N -
2828               l_mean_shifted_div_scale_x * l_mean_shifted_div_scale_x, 0.0f);
2829         // scale_x / sqrt(var(dqX) + eps)
2830         float scale_x_div_layer_std = x_scale /
2831           std::sqrt(layer_var_div_scale_x_sq * x_scale * x_scale + eps);
2832         fVec layer_mean_div_scale_xVec(layer_mean_div_scale_x);
2833         fVec scale_x_div_layer_stdVec(scale_x_div_layer_std);
2834 
2835         // Second pass: normalize
2836 
2837         // TODO replace with TensorIterator implementation once #33166 is fixed.
2838         if (affine_per_channel) {
2839 
2840           // if scaling per channel, scaling parameters can be pre-multiplied
2841           // with normalization parameters
2842           for (const auto chIdx : c10::irange(channels_per_group)) {
2843             int scalingIdx = (i * channels_per_group + chIdx) % (num_channels);
2844             float gamma = gamma_null ? 1.0f : gamma_data[scalingIdx];
2845             // scale_x / layer_std * gamma
2846             float gamma_p = scale_x_div_layer_std * gamma;
2847             float beta = beta_null ? 0.0f : beta_data[scalingIdx];
2848             fVec gamma_p_vec(gamma_p);
2849             fVec beta_vec(beta);
2850 
2851             int64_t chStartIdx = chIdx * NPerChannel;
2852             int64_t chEndIdx = chStartIdx + NPerChannel;
2853 
2854             for (const auto vecIdx : c10::irange(kNumIntVecInChannel)) {
2855               int64_t vecStartIdx = chStartIdx + vecIdx * kIntVLen;
2856               auto qXVec = qVec::loadu(X_ptr + vecStartIdx);
2857               auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec,
2858                   x_fake_scale_zp_neg_premul_vec);
2859               for (auto &dq : dqXVec) {
2860                 dq =
2861                   (dq - layer_mean_div_scale_xVec) *
2862                     gamma_p_vec + beta_vec;
2863               }
2864               qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale)
2865                 .store(Y_ptr + vecStartIdx);
2866             }
2867 
2868             // Remainder
2869             if (kNonVecRemInChannel > 0) {
2870               int64_t remIdx = chEndIdx - kNonVecRemInChannel;
2871               auto qXVec = qVec::loadu(X_ptr + remIdx, kNonVecRemInChannel);
2872               auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec,
2873                     x_fake_scale_zp_neg_premul_vec);
2874               int validDqvecLen = (kNonVecRemInChannel - 1) / fVec::size() + 1;
2875               for (int i = 0; i < validDqvecLen; ++i) {
2876                 auto &dq = dqXVec[i];
2877                 dq =
2878                   (dq - layer_mean_div_scale_xVec) *
2879                     gamma_p_vec + beta_vec;
2880               }
2881               qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale)
2882                 .store(Y_ptr + remIdx, kNonVecRemInChannel);
2883             }
2884           } // chIdx
2885 
2886         } else {
2887 
2888           for (const auto vecIdx : c10::irange(kNumIntVecInLayer)) {
2889             int64_t vecStartIdx = vecIdx * kIntVLen;
2890             auto qXVec = qVec::loadu(X_ptr + vecStartIdx);
2891             auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec,
2892                 x_fake_scale_zp_neg_premul_vec);
2893             for (const auto dqXVecIdx : c10::irange(dqXVec.size())) {
2894               int64_t vecVecStartIdx = vecStartIdx + dqXVecIdx * kFloatVLen;
2895               auto gammaVec = gamma_null
2896                 ? one_vec
2897                 : fVec::loadu(gamma_data + vecVecStartIdx);
2898               auto betaVec = beta_null
2899                 ? zero_vec
2900                 : fVec::loadu(beta_data + vecVecStartIdx);
2901               dqXVec[dqXVecIdx] =
2902                 (dqXVec[dqXVecIdx] - layer_mean_div_scale_xVec) *
2903                   scale_x_div_layer_stdVec * gammaVec + betaVec;
2904               qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale)
2905                 .store(Y_ptr + vecStartIdx);
2906             }
2907           }
2908           for (int64_t remIdx = N - kNonVecRemInLayer; remIdx < N; remIdx++) {
2909             const float gamma_v = gamma_null ? 1.0f : gamma_data[remIdx];
2910             const float beta_v = beta_null ? 0.0f : beta_data[remIdx];
2911             auto qXVal = X_ptr[remIdx];
2912             float dqXVal = at::native::dequantize_val(x_fake_scale, x_zp, qXVal);
2913             float dqY =
2914               ((dqXVal - layer_mean_div_scale_x) * scale_x_div_layer_std) * gamma_v + beta_v;
2915             Y_ptr[remIdx] = at::native::quantize_val<scalar_t>(y_scale, y_zp, dqY);
2916           }
2917         }
2918       }
2919     }); // parallel_for
2920 
2921   });
2922 }
2923 
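// Editor's note: scalar sketch of how quantized_normalize_kernel above recovers
// the statistics of the dequantized input from integer sums over the raw
// quantized values (hsum / hsum_sq), without a separate dequantize pass. With
// q = x / scale + zp:
//   mean(x) = scale * (mean(q) - zp)
//   var(x)  = scale^2 * (mean(q^2) - mean(q)^2)
// The kernel keeps the "divided by scale" form (layer_mean_div_scale_x,
// layer_var_div_scale_x_sq); this hypothetical helper spells out the algebra.
struct DequantizedMomentsSketch {
  float mean;
  float var;
};

[[maybe_unused]] inline DequantizedMomentsSketch moments_from_shifted_sums_sketch(
    int64_t sum_q, int64_t sum_q_sq, int64_t N, float scale, int64_t zero_point) {
  const float mean_q = static_cast<float>(sum_q) / N;
  const float var_q =
      std::max(static_cast<float>(sum_q_sq) / N - mean_q * mean_q, 0.0f);
  return {scale * (mean_q - zero_point), scale * scale * var_q};
}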
2924 void qmean_inner_dim_kernel(
2925     const Tensor& self,
2926     OptionalIntArrayRef opt_dim,
2927     bool keepdim,
2928     std::optional<ScalarType> opt_dtype,
2929     Tensor& result) {
2930   // 'opt_dtype' should be nullopt or equal to the input's dtype
2931   ScalarType dtype = self.scalar_type();
2932   auto in_dims = self.sizes().vec();
2933   auto out_dims = in_dims;
2934   bool is_all_reduce = !opt_dim.has_value() || opt_dim.value().empty();
2935   size_t num_dims_to_squeeze = is_all_reduce ? self.dim() : opt_dim.value().size();
2936   int64_t M = 1; // Num of groups
2937   int64_t N = 1; // Num of elements to take average of in each group
2938   for (size_t i = 0; i < in_dims.size() - num_dims_to_squeeze; ++i) {
2939     M *= in_dims[i];
2940   }
2941   for (size_t i = 0; i < num_dims_to_squeeze; ++i) {
2942     auto idx = out_dims.size() - 1 - i;
2943     N *= out_dims[idx];
2944     out_dims[idx] = 1;
2945   }
2946   if (!keepdim) {
2947     out_dims.erase(out_dims.end() - num_dims_to_squeeze, out_dims.end());
2948   }
2949   result = at::_empty_affine_quantized(
2950       out_dims,
2951       at::device(kCPU).dtype(dtype).memory_format(self.suggest_memory_format()),
2952       self.q_scale(),
2953       self.q_zero_point(),
2954       std::nullopt);
2955 
2956   AT_DISPATCH_QINT_TYPES(self.scalar_type(), "quantized_mean_kernel_impl_cpu", [&]() {
2957     scalar_t* X_data = self.data_ptr<scalar_t>();
2958     scalar_t* Y_data = result.data_ptr<scalar_t>();
2959 
2960     at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
2961       for (const auto i : c10::irange(start, end)) {
2962         scalar_t* X_ptr = X_data + i * N;
2963         scalar_t* Y_ptr = Y_data + i;
2964         scalar_t::underlying* X_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(X_ptr);
2965         scalar_t::underlying* Y_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(Y_ptr);
2966         auto x_sum = hsum(X_ptr_underlying, N);
2967         float y_float = static_cast<float>(x_sum) / N;
2968         *Y_ptr_underlying = std::nearbyint(y_float);
2969       }
2970     });
2971   });
2972 }
2973 
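// Editor's note: sketch of why qmean_inner_dim_kernel above can average the raw
// integer values directly. The output tensor reuses the input's scale and
// zero_point, and the mean is affine-equivariant:
//   mean(scale * (q_i - zp)) = scale * (mean(q_i) - zp)
// so round(mean(q_i)) stored as the output integer is the quantized mean, up to
// rounding. Hypothetical helper; uint8 storage assumed for brevity.
[[maybe_unused]] inline uint8_t quantized_mean_of_group_sketch(
    const uint8_t* q, int64_t N) {
  int64_t sum = 0;
  for (int64_t i = 0; i < N; ++i) {
    sum += q[i];
  }
  return static_cast<uint8_t>(std::nearbyint(static_cast<float>(sum) / N));
}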
2974 void qstd_inner_dim_kernel(
2975     const Tensor& self,
2976     OptionalIntArrayRef dim,
2977     const std::optional<Scalar>& correction_opt,
2978     bool keepdim,
2979     Tensor& result) {
2980   ScalarType dtype = self.scalar_type();
2981   auto in_dims = self.sizes().vec();
2982   auto out_dims = in_dims;
2983   size_t num_dims_to_squeeze = dim.has_value() && !dim.value().empty() ?
2984                                dim.value().size() :
2985                                self.dim();
2986   int64_t M = 1; // Num of groups
2987   int64_t N = 1; // Num of elements to take std of in each group
2988   for (size_t i = 0; i < in_dims.size() - num_dims_to_squeeze; ++i) {
2989     M *= in_dims[i];
2990   }
2991   for (size_t i = 0; i < num_dims_to_squeeze; ++i) {
2992     auto idx = out_dims.size() - 1 - i;
2993     N *= out_dims[idx];
2994     out_dims[idx] = 1;
2995   }
2996   if (!keepdim) {
2997     out_dims.erase(out_dims.end() - num_dims_to_squeeze, out_dims.end());
2998   }
2999   const auto correction = correction_opt.value_or(1).toDouble();
3000   double den = std::max(N - correction, 0.0); // Denominator when computing mean and deviation
3001   auto x_scale = self.q_scale();
3002   auto x_zp = self.q_zero_point();
3003   result = at::_empty_affine_quantized(
3004       out_dims,
3005       at::device(kCPU).dtype(dtype).memory_format(self.suggest_memory_format()),
3006       x_scale,
3007       x_zp,
3008       std::nullopt);
3009 
3010   AT_DISPATCH_QINT_TYPES(self.scalar_type(), "quantized_std_kernel_impl_cpu", [&]() {
3011     scalar_t* X_data = self.data_ptr<scalar_t>();
3012     scalar_t* Y_data = result.data_ptr<scalar_t>();
3013 
3014     at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
3015       for (const auto i : c10::irange(start, end)) {
3016         scalar_t* X_ptr = X_data + i * N;
3017         scalar_t* Y_ptr = Y_data + i;
3018         scalar_t::underlying* X_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(X_ptr);
3019         scalar_t::underlying* Y_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(Y_ptr);
3020         auto x_sum_shifted = hsum(X_ptr_underlying, N);
3021         auto x_sum_sq_shifted = hsum_sq(X_ptr_underlying, N);
3022         // Use double for intermediate variables to avoid accuracy issue
3023         // Mean with zero point
3024         double x_mean_shifted_div_scale_x = static_cast<double>(x_sum_shifted) / N;
3025         double x_mean_unbiased_shifted_div_scale_x = static_cast<double>(x_sum_shifted) / den;
3026         // variance / x_scale^2
3027         double x_var_div_scale_x_sq =
3028             std::max(static_cast<double>(x_sum_sq_shifted) / den -
3029                 2 * x_mean_shifted_div_scale_x * x_mean_unbiased_shifted_div_scale_x +
3030                 x_mean_shifted_div_scale_x * x_mean_shifted_div_scale_x * N / den, (double)0.0);
3031         double y_float = std::sqrt(x_var_div_scale_x_sq) * x_scale;
3032         *Y_ptr_underlying = at::native::quantize_val<scalar_t>(
3033                             x_scale, x_zp, y_float)
3034                             .val_;
3035       }
3036     });
3037   });
3038 }
3039 
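// Editor's note: sketch of the corrected-variance expansion used by
// qstd_inner_dim_kernel above. With mean = sum(q) / N and
// den = max(N - correction, 0):
//   sum((q - mean)^2) / den
//     = sum(q^2)/den - 2 * mean * (sum(q)/den) + mean^2 * N / den
// which is x_var_div_scale_x_sq (in units of scale^2); the standard deviation
// of the dequantized values is then sqrt(var) * scale, which gets re-quantized
// with the input's scale/zero_point. Hypothetical helper restating the algebra.
[[maybe_unused]] inline double corrected_std_from_sums_sketch(
    int64_t sum_q, int64_t sum_q_sq, int64_t N, double correction, double scale) {
  const double den = std::max(static_cast<double>(N) - correction, 0.0);
  const double mean = static_cast<double>(sum_q) / N;
  const double mean_unbiased = static_cast<double>(sum_q) / den;
  const double var = std::max(
      static_cast<double>(sum_q_sq) / den - 2.0 * mean * mean_unbiased +
          mean * mean * N / den,
      0.0);
  return std::sqrt(var) * scale;
}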
3040 // For group norm of channels_last input
3041 void quantized_groupnorm_nhwc_kernel(
3042     const Tensor& X, // input tensor
3043     const Tensor& gamma, // weight (optional)
3044     const Tensor& beta, // bias (optional)
3045     bool affine_per_channel, // must be true for group/instance norm
3046     int num_channels, // only used if affine_per_channel is set
3047     int num_groups, // only used if affine_per_channel is set
3048     int64_t M, // number of groups = Bs * G
3049     int64_t N, // number of elements in each group = C * H * W / G
3050     double eps,
3051     Tensor* Y) {
3052   AT_DISPATCH_QINT_TYPES(X.scalar_type(), "quantized_norm_nhwc_kernel_impl_cpu", [&]() {
3053     using qVec = vec::Vectorized<scalar_t>;
3054     using fVec = vec::Vectorized<float>;
3055 
3056     int64_t G = num_groups;
3057     int64_t Bs = M / G;
3058     int64_t C = num_channels;
3059 
3060     TORCH_INTERNAL_ASSERT(X.numel() == M * N, "Unexpected num elements in X");
3061     TORCH_INTERNAL_ASSERT(
3062         !gamma.defined() ||
3063         (!affine_per_channel && gamma.numel() == N) ||
3064         (affine_per_channel && gamma.numel() == C),
3065         "Unexpected size of gamma");
3066     TORCH_INTERNAL_ASSERT(
3067         !beta.defined() ||
3068         (!affine_per_channel && beta.numel() == N) ||
3069         (affine_per_channel && beta.numel() == C),
3070         "Unexpected size of beta");
3071 
3072     scalar_t* X_data = X.data_ptr<scalar_t>();
3073     const float* gamma_data = gamma.defined() ? gamma.const_data_ptr<float>() : nullptr;
3074     const float* beta_data = beta.defined() ? beta.const_data_ptr<float>() : nullptr;
3075     scalar_t* Y_data = Y->data_ptr<scalar_t>();
3076     const bool gamma_null = gamma_data == nullptr;
3077     const bool beta_null = beta_data == nullptr;
3078     int64_t x_zp = X.q_zero_point();
3079     float x_scale = X.q_scale();
3080     fVec x_zp_vec((float)x_zp);
3081     fVec one_vec(1.0f);
3082     fVec zero_vec(0.0f);
3083     float x_fake_scale = 1.0f;
3084     fVec x_fake_scale_vec(x_fake_scale);
3085     fVec x_fake_scale_zp_neg_premul_vec = x_fake_scale_vec * x_zp_vec.neg();
3086     int64_t y_zp = Y->q_zero_point();
3087     float y_scale = Y->q_scale();
3088     float y_inv_scale = 1.0f / y_scale;
3089 
3090     constexpr int kFloatVLen = fVec::size();
3091     int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs();
3092     int64_t channels_per_group = C / G;
3093     int64_t HxW = N / channels_per_group;
3094     int64_t kNumIntVecInHxW = channels_per_group / kIntVLen;
3095     int64_t kNonVecRemInHxW = channels_per_group % kIntVLen;
3096     int64_t kNumIntVecOnChannel = C / kIntVLen;
3097     int64_t kNonVecRemOnChannel = C % kIntVLen;
3098 
3099     // Buffer holding the fused per-channel scale and bias for each group (impl-1)
3100     Tensor buffer = at::empty({M, 2 * channels_per_group}, X.options().dtype(at::kFloat));
3101     float* buffer_data = buffer.mutable_data_ptr<float>();
3102 
3103     // We can parallelize in either of the following two ways:
3104     //
3105     // impl-1: parallelize over the Bs * G groups. Only one omp session is
3106     //   needed, but memory access per thread is non-contiguous.
3107     //
3108     // impl-2: parallelize over the Bs * HxW rows. Memory access per thread is
3109     //   contiguous, but an extra temp buffer of size {T, Bs, 2C} is needed.
3110     //
3111     // Generally impl-2 performs better when HxW is large enough;
3112     // the threshold below was found empirically.
3113     constexpr int64_t feature_map_threshold = 512;
3114     if (HxW < feature_map_threshold) {
3115       // Impl-1: parallelize over groups.
3116       //
3117       // Each task handles one of the M = Bs * G groups.
3118       at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
3119         int64_t n{0} /* batch index */, g{0} /* group index in each batch */;
3120         data_index_init(begin, n, Bs, g, G);
3121         for (const auto grpIdx : c10::irange(begin, end)) { // For each group
3122 
3123           // Step 1: calculate mean and variance.
3124           int64_t l_sum_shifted = 0;
3125           int64_t l_sum_sq_shifted = 0;
3126           for (const auto hw : c10::irange(HxW)) {
3127             scalar_t* X_ptr = X_data + n * N * G + g * channels_per_group + hw * C;
3128             scalar_t::underlying* X_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(X_ptr);
3129             l_sum_shifted += hsum(X_ptr_underlying, channels_per_group);
3130             l_sum_sq_shifted += hsum_sq(X_ptr_underlying, channels_per_group);
3131           }
3132 
3133           // mean(dqX) / scale_x + x_zp
3134           float l_mean_shifted_div_scale_x = static_cast<float>(l_sum_shifted) / N;
3135           // mean(dqX) / scale_x
3136           float layer_mean_div_scale_x = l_mean_shifted_div_scale_x - x_zp;
3137           // var(dqX) / scale_x^2
3138           float layer_var_div_scale_x_sq =
3139             std::max(static_cast<float>(l_sum_sq_shifted) / N -
3140                 l_mean_shifted_div_scale_x * l_mean_shifted_div_scale_x, 0.0f);
3141           // scale_x / sqrt(var(dqX) + eps)
3142           float scale_x_div_layer_std = x_scale /
3143             std::sqrt(layer_var_div_scale_x_sq * x_scale * x_scale + eps);
3144 
3145           // Step 2: calculate scale and bias
3146           float* scale_ptr = buffer_data + grpIdx * 2 * channels_per_group;
3147           float* bias_ptr = scale_ptr + channels_per_group;
3148           for (const auto d : c10::irange(channels_per_group)) {
3149             const int64_t chIdx = g * channels_per_group + d;
3150             scale_ptr[d] = scale_x_div_layer_std * (gamma_null ? 1.0f : gamma_data[chIdx]);
3151             bias_ptr[d] = -scale_ptr[d] * layer_mean_div_scale_x + (beta_null ? 0.0f : beta_data[chIdx]);
3152           }
3153 
3154           // Step 3: applying scale and bias
3155           for (const auto hwIdx : c10::irange(HxW)) {
3156             const scalar_t* X_ptr = X_data + n * N * G + g * channels_per_group + hwIdx * C;
3157             scalar_t* Y_ptr = Y_data + n * N * G + g * channels_per_group + hwIdx * C;
3158             // vectorized
3159             for (const auto vecIdx : c10::irange(kNumIntVecInHxW)) {
3160               int64_t vecStartIdx = vecIdx * kIntVLen;
3161               auto qXVec = qVec::loadu(X_ptr + vecStartIdx);
3162               auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec,
3163                     x_fake_scale_zp_neg_premul_vec);
3164               for (size_t fvecIdx = 0; fvecIdx < dqXVec.size(); ++fvecIdx) {
3165                 auto scaleVec = fVec::loadu(scale_ptr + vecStartIdx + fvecIdx * kFloatVLen);
3166                 auto biasVec = fVec::loadu(bias_ptr + vecStartIdx + fvecIdx * kFloatVLen);
3167                 dqXVec[fvecIdx] = dqXVec[fvecIdx] * scaleVec + biasVec;
3168               }
3169               qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale)
3170                   .store(Y_ptr + vecStartIdx);
3171             }
3172             // Remaining scalar
3173             for (int64_t remIdx = kNumIntVecInHxW * kIntVLen;
3174                  remIdx < kNonVecRemInHxW + kNumIntVecInHxW * kIntVLen;
3175                  ++remIdx) {
3176               auto qXVal = X_ptr[remIdx];
3177               float dqXVal = at::native::dequantize_val(x_fake_scale, x_zp, qXVal);
3178               float dqY = dqXVal * scale_ptr[remIdx] + bias_ptr[remIdx];
3179               Y_ptr[remIdx] = at::native::quantize_val<scalar_t>(y_scale, y_zp, dqY);
3180             }
3181           } // loop over HxW
3182 
3183           data_index_step(n, Bs, g, G);
3184         } // for each group
3185       }); // parallel_for
3186     } else { // HxW >= feature_map_threshold
3187       // impl-2: parallel on Bs * HxW.
3188       //
3189       // Buffer for x and x^2
3190       // To avoid thread conflict, we use a temp buffer of {T, Bs, 2*C}
3191       int num_threads = at::get_num_threads();
3192       Tensor buffer = at::empty({num_threads, Bs, 2 * C}, X.options().dtype(at::kFloat)).zero_();
3193       float* buffer_data = buffer.mutable_data_ptr<float>();
3194       Tensor mean = at::empty(M, X.options().dtype(at::kFloat));
3195       float* mean_data = mean.mutable_data_ptr<float>();
3196       Tensor rstd = at::empty(M, X.options().dtype(at::kFloat));
3197       float* rstd_data = rstd.mutable_data_ptr<float>();
3198 
3199       // Step 1: Accumulate on C dimension
3200       at::parallel_for(0, Bs * HxW, 1, [&](int64_t begin, int64_t end) {
3201         int tid = at::get_thread_num();
3202         float* buffer_ptr = buffer_data + tid * Bs * 2 * C;
3203 
3204         int64_t n{0} /* batch index */, m{0} /* HxW index */;
3205         data_index_init(begin, n, Bs, m, HxW);
3206         for (const auto nhwIdx : c10::irange(begin, end)) {
3207           float* mean_ptr = buffer_ptr + n * 2 * C;
3208           float* rstd_ptr = mean_ptr + C;
3209           scalar_t* X_ptr = X_data + nhwIdx * C;
3210           scalar_t::underlying* X_ptr_underlying = reinterpret_cast<scalar_t::underlying*>(X_ptr);
3211           for (int chIdx = 0; chIdx < C; ++chIdx) {
3212             auto x = X_ptr_underlying[chIdx];
3213             mean_ptr[chIdx] += x;
3214             rstd_ptr[chIdx] += x * x;
3215           }
3216           data_index_step(n, Bs, m, HxW);
3217         }
3218       });
3219 
3220       // Step 2: Calculate mean and rstd
3221       for (const auto n : c10::irange(Bs)) {
3222         for (const auto g : c10::irange(G)) {
3223           float mean_val{0}, rstd_val{0};
3224           for (const auto t : c10::irange(num_threads)) {
3225             float* buffer_ptr = buffer_data + t * Bs * 2 * C + n * 2 * C;
3226             for (const auto d : c10::irange(channels_per_group)) {
3227               mean_val += buffer_ptr[g * channels_per_group + d];
3228               rstd_val += buffer_ptr[g * channels_per_group + d + C];
3229             } // for d
3230           } // for t
3231 
3232           // mean / scale_x + x_zp
3233           float l_mean_shifted_div_scale_x = mean_val / N;
3234           // mean / scale_x
3235           float layer_mean_div_scale_x = l_mean_shifted_div_scale_x - x_zp;
3236           // var / scale_x^2
3237           float layer_var_div_scale_x_sq =
3238               std::max(rstd_val / N -
3239               l_mean_shifted_div_scale_x * l_mean_shifted_div_scale_x, 0.0f);
3240           // scale_x / sqrt(var + eps)
3241           float scale_x_div_layer_std = x_scale /
3242               std::sqrt(layer_var_div_scale_x_sq * x_scale * x_scale + eps);
3243           mean_data[n * G + g] = layer_mean_div_scale_x;
3244           rstd_data[n * G + g] = scale_x_div_layer_std;
3245 
3246         } // for g
3247       } // for n
3248 
3249       // Step 3: Calculate scale and bias
3250       //
3251       // We could fuse steps 3 and 4 into a single pass, but keeping them separate is better:
3252       //   a. channels_per_group might be too small for vectorization;
3253       //   b. it avoids recomputing scale/bias, since every HxW plane shares the same scale/bias.
3254       //
3255       for (const auto n : c10::irange(Bs)) {
3256         for (const auto g : c10::irange(G)) {
3257           float* scale_ptr = buffer_data + n * 2 * C;
3258           float* bias_ptr = scale_ptr + C;
3259           float mean_val = mean_data[n * G + g];
3260           float rstd_val = rstd_data[n * G + g];
3261           for (const auto d : c10::irange(channels_per_group)) {
3262             const int64_t chIdx = g * channels_per_group + d;
3263             scale_ptr[chIdx] = rstd_val * (gamma_null ? 1.0f : gamma_data[chIdx]);
3264             bias_ptr[chIdx] = -scale_ptr[chIdx] * mean_val + (beta_null ? 0.0f : beta_data[chIdx]);
3265           } // for d
3266         } // for g
3267       } // for n
3268 
3269       // Step 4: Apply scale and bias
3270       //
3271       // Parallelize over the outer dimensions (Bs and HxW)
3272       // and vectorize over C.
3273       //
3274       at::parallel_for(0, Bs * HxW, 1, [&](int64_t begin, int64_t end) {
3275         int64_t n{0}, m{0};
3276         data_index_init(begin, n, Bs, m, HxW);
3277         for (const auto nhwIdx : c10::irange(begin, end)) {
3278           const scalar_t* X_ptr = X_data + nhwIdx * C;
3279           scalar_t* Y_ptr = Y_data + nhwIdx * C;
3280           float* scale_ptr = buffer_data + n * 2 * C;
3281           float* bias_ptr = scale_ptr + C;
3282           // Vectorized
3283           for (const auto vecIdx : c10::irange(kNumIntVecOnChannel)) {
3284             int64_t vecStartIdx = vecIdx * kIntVLen;
3285             auto qXVec = qVec::loadu(X_ptr + vecStartIdx);
3286             auto dqXVec = qXVec.dequantize(x_fake_scale_vec, x_zp_vec,
3287                   x_fake_scale_zp_neg_premul_vec);
3288             for (size_t fvecIdx = 0; fvecIdx < dqXVec.size(); ++fvecIdx) {
3289               auto scaleVec = fVec::loadu(scale_ptr + vecStartIdx + fvecIdx * kFloatVLen);
3290               auto biasVec = fVec::loadu(bias_ptr + vecStartIdx + fvecIdx * kFloatVLen);
3291               dqXVec[fvecIdx] = dqXVec[fvecIdx] * scaleVec + biasVec;
3292             }
3293             qVec::quantize(dqXVec, y_scale, y_zp, y_inv_scale)
3294                 .store(Y_ptr + vecStartIdx);
3295           }
3296           // Remaining scalar
3297           for (int64_t remIdx = kNumIntVecOnChannel * kIntVLen;
3298                remIdx < kNonVecRemOnChannel + kNumIntVecOnChannel * kIntVLen;
3299                ++remIdx) {
3300             auto qXVal = X_ptr[remIdx];
3301             float dqXVal = at::native::dequantize_val(x_fake_scale, x_zp, qXVal);
3302             float dqY = dqXVal * scale_ptr[remIdx] + bias_ptr[remIdx];
3303             Y_ptr[remIdx] = at::native::quantize_val<scalar_t>(y_scale, y_zp, dqY);
3304           }
3305 
3306           data_index_step(n, Bs, m, HxW);
3307         } // for idx on nhw
3308       }); // parallel_for on nhw
3309 
3310     } // else: HxW >= feature_map_threshold
3311 
3312   }); // AT_DISPATCH_QINT_TYPES
3313 }
3314 
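// Editor's note: sketch of the per-channel fusion both impls of
// quantized_groupnorm_nhwc_kernel above rely on. Normalization and the optional
// affine parameters collapse into one multiply-add on the (fake-)dequantized
// input:
//   y = (x - mean) / std * gamma + beta = x * s + b,
//   s = gamma * scale_x / std,  b = beta - s * mean
// where mean and std are kept in "divided by scale_x" units, matching
// layer_mean_div_scale_x and scale_x_div_layer_std above. Hypothetical helper.
struct GroupNormChannelParamsSketch {
  float s;
  float b;
};

[[maybe_unused]] inline GroupNormChannelParamsSketch fused_channel_params_sketch(
    float mean_div_scale_x,  // mean(dqX) / scale_x for this group
    float scale_x_div_std,   // scale_x / sqrt(var(dqX) + eps) for this group
    float gamma,             // per-channel weight (1.0f when gamma is null)
    float beta) {            // per-channel bias   (0.0f when beta is null)
  GroupNormChannelParamsSketch p{};
  p.s = scale_x_div_std * gamma;
  p.b = -p.s * mean_div_scale_x + beta;
  return p;
}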
3315 #ifdef USE_FBGEMM
3316 void quantize_tensor_per_tensor_affine_cpu(
3317     const Tensor& rtensor,
3318     Tensor& qtensor,
3319     double scale,
3320     int64_t zero_point) {
3321   AT_DISPATCH_QINT_TYPES(
3322       qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_cpu", [&]() {
3323         check_tensor_memory_format(rtensor, qtensor);
3324         const float* rd = rtensor.const_data_ptr<float>();
3325         auto qd = reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
3326         // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
3327         fbgemm::TensorQuantizationParams qparams;
3328         qparams.scale = scale;
3329         qparams.zero_point = zero_point;
3330         qparams.precision = CHAR_BIT * sizeof(underlying_t);
3331         int num_tasks = at::get_num_threads();
3332         at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
3333           for (const auto task_id : c10::irange(begin, end)) {
3334             fbgemm::Quantize<underlying_t, false /*LEGACY*/>(
3335                 // NOLINTNEXTLINE(bugprone-argument-comment)
3336                 rd, /*src=*/
3337                 // NOLINTNEXTLINE(bugprone-argument-comment)
3338                 qd, /*dst=*/
3339                 rtensor.numel(), /*len*/
3340                 // NOLINTNEXTLINE(bugprone-argument-comment)
3341                 qparams, /*qparams=*/
3342                 task_id, /*thread_id*/
3343                 num_tasks /*num_threads*/);
3344           }
3345         });
3346       });
3347 }
3348 
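// Editor's note: scalar sketch of the per-tensor affine mapping the FBGEMM
// wrappers above and below delegate to:
//   quantize:   q = clamp(round(x / scale) + zero_point, qmin, qmax)
//   dequantize: x ~= scale * (q - zero_point)
// with qmin/qmax being the bounds of the underlying integer type (0..255 for
// quint8). Hypothetical helpers, shown for uint8 storage only.
[[maybe_unused]] inline uint8_t quantize_one_u8_sketch(
    float x, float scale, int32_t zero_point) {
  const int32_t q = static_cast<int32_t>(std::nearbyint(x / scale)) + zero_point;
  return static_cast<uint8_t>(std::min<int32_t>(255, std::max<int32_t>(0, q)));
}

[[maybe_unused]] inline float dequantize_one_u8_sketch(
    uint8_t q, float scale, int32_t zero_point) {
  return scale * static_cast<float>(static_cast<int32_t>(q) - zero_point);
}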
3349 void dequantize_tensor_per_tensor_affine_cpu(
3350     const Tensor& qtensor,
3351     Tensor& rtensor,
3352     double scale,
3353     int64_t zero_point) {
3354   AT_DISPATCH_QINT_TYPES(
3355       qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_cpu", [&]() {
3356         check_tensor_memory_format(qtensor, rtensor);
3357         const auto* qd =
3358             reinterpret_cast<const underlying_t*>(qtensor.data_ptr<scalar_t>());
3359         // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
3360         fbgemm::TensorQuantizationParams qparams;
3361         qparams.scale = scale;
3362         qparams.zero_point = zero_point;
3363         qparams.precision = CHAR_BIT * sizeof(underlying_t);
3364         float* rd = rtensor.data_ptr<float>();
3365         int num_tasks = at::get_num_threads();
3366         at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
3367           for (const auto task_id : c10::irange(begin, end)) {
3368             fbgemm::Dequantize<underlying_t>(
3369                 // NOLINTNEXTLINE(bugprone-argument-comment)
3370                 qd, /*src=*/
3371                 // NOLINTNEXTLINE(bugprone-argument-comment)
3372                 rd, /*dst=*/
3373                 // NOLINTNEXTLINE(bugprone-argument-comment)
3374                 qtensor.numel(), /*len=*/
3375                 // NOLINTNEXTLINE(bugprone-argument-comment)
3376                 qparams, /*qparams=*/
3377                 task_id, /*thread_id*/
3378                 num_tasks /*num_threads*/);
3379           }
3380         });
3381       });
3382 }
3383 #else // USE_FBGEMM
3384 
3385 #if defined(__ARM_NEON__) || defined(__aarch64__)
3386 
3387 const static int PARALLEL_THRESHOLD = 1 << 20;
3388 
3389 // Generic template defaults to naive quantize implementation
3390 template <typename T>
3391 void quantize_tensor_arm(
3392     const float* __restrict__ in,
3393     T* __restrict__ out,
3394     const int64_t N,
3395     const float scale,
3396     const int32_t zero_point) {
3397   for (const auto i : c10::irange(N)) {
3398     out[i] = at::native::quantize_val<T>(scale, zero_point, in[i]);
3399   }
3400 }
3401 
3402 namespace quantize_tensor_arm_intrinsics {
3403 template <typename Tx8>
3404 C10_ALWAYS_INLINE Tx8 vqmov(int16x8_t vraw);
3405 
3406 template <>
3407 C10_ALWAYS_INLINE uint8x8_t vqmov<uint8x8_t>(int16x8_t vraw) {
3408   return vqmovun_s16(vraw);
3409 }
3410 
3411 template <>
3412 C10_ALWAYS_INLINE int8x8_t vqmov<int8x8_t>(int16x8_t vraw) {
3413   return vqmovn_s16(vraw);
3414 }
3415 
3416 template <typename T, typename Tx8>
3417 C10_ALWAYS_INLINE void vst1(T* out, Tx8 vout);
3418 
3419 template <>
3420 C10_ALWAYS_INLINE void vst1<uint8_t, uint8x8_t>(uint8_t* out, uint8x8_t vout) {
3421   vst1_u8(out, vout);
3422 }
3423 
3424 template <>
3425 C10_ALWAYS_INLINE void vst1<int8_t, int8x8_t>(int8_t* out, int8x8_t vout) {
3426   vst1_s8(out, vout);
3427 }
3428 } // namespace quantize_tensor_arm_intrinsics
3429 
3430 // Specialized implementation from caffe2::Int8Quantize.
3431 // There may be a slight accuracy difference between this and the implementation
3432 // of quantize_val.
3433 // TODO Update quantize_tensor_arm implementation to follow quantize_val,
3434 // i.e. f = Round(value/scale + zero_point)
3435 // TODO Make quantize_tensor_arm work for int32 datatype too.
3436 template <typename scalar_t, typename underlying_t, typename underlying_x8_t>
3437 void quantize_tensor_arm_q8(
3438     const float* __restrict__ in,
3439     scalar_t* __restrict__ out,
3440     const int64_t N,
3441     const float scale,
3442     const int32_t zero_point) {
3443   const float inv_scale = 1.0f / scale;
3444   uint32_t i = 0;
3445   underlying_t* out_underlying = reinterpret_cast<underlying_t*>(out);
3446   const float32x4_t vinv_scale = vdupq_n_f32(inv_scale);
3447 #if defined(__ARM_NEON__)
3448   // magic float and magic int to take care of rounding
3449   // int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000
3450   // Some detail:
3451   // 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you
3452   // add a small number to a large number, the result rounds to the precision of
3453   // the least significant bit of the large number. For an IEEE-754
3454   // single-precision number the mantissa has 23 bits, so adding 2**23 causes the
3455   // value to be rounded to the nearest integer (ties to even). Then we cast to int
3456   // and subtract the same number (0x4B400000 is the integer representation of
3457   // 12582912.0f) to get only the mantissa. This works if -2**22 < x < 2**22, and
3458   // it preserves the sign for negative numbers. (See the scalar sketch after this function.)
3459   const int32x4_t voffset = vdupq_n_s32(zero_point - 0x4B400000);
3460   const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
3461   for (i = 0; i + 8 <= N; i += 8) {
3462     const float32x4_t vin0123 = vld1q_f32(in);
3463     in += 4;
3464     const float32x4_t vin4567 = vld1q_f32(in);
3465     in += 4;
3466     const int32x4_t vraw0123 = vaddq_s32(
3467         voffset,
3468         vreinterpretq_s32_f32(
3469             vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
3470     const int32x4_t vraw4567 = vaddq_s32(
3471         voffset,
3472         vreinterpretq_s32_f32(
3473             vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
3474     const int16x8_t vraw01234567 =
3475         vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
3476     const underlying_x8_t vout01234567 =
3477         quantize_tensor_arm_intrinsics::vqmov<underlying_x8_t>(vraw01234567);
3478     quantize_tensor_arm_intrinsics::vst1<underlying_t, underlying_x8_t>(
3479         out_underlying, vout01234567);
3480     out_underlying += 8;
3481   }
3482   for (; i < N; ++i) {
3483     (*out_underlying++) =
3484         at::native::quantize_val_arm<underlying_t>(scale, zero_point, (*in++));
3485   }
3486 #else
3487   const int16x8_t vzero_point = vdupq_n_s16((int16_t)(uint16_t)zero_point);
3488   for (i = 0; i + 8 <= N; i += 8) {
3489     const float32x4_t vin0123 = vld1q_f32(in);
3490     in += 4;
3491     const float32x4_t vin4567 = vld1q_f32(in);
3492     in += 4;
3493     const int32x4_t v0123_rounded = vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale));
3494     const int32x4_t v4567_rounded = vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale));
3495     const int16x8_t v01234567_packed = vqaddq_s16(
3496         vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded), vzero_point);
3497     const underlying_x8_t vout01234567 =
3498         quantize_tensor_arm_intrinsics::vqmov<underlying_x8_t>(
3499             v01234567_packed);
3500     quantize_tensor_arm_intrinsics::vst1<underlying_t, underlying_x8_t>(
3501         out_underlying, vout01234567);
3502     out_underlying += 8;
3503   }
3504   for (; i < N; ++i) {
3505     (*out_underlying++) =
3506         at::native::quantize_val_arm<underlying_t>(scale, zero_point, (*in++));
3507   }
3508 #endif
3509 }
3510 
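// Editor's note: scalar sketch of the "magic float" rounding used in the
// __ARM_NEON__ branch of quantize_tensor_arm_q8 above. Adding 2**23 + 2**22
// forces the FPU to round the value into the low mantissa bits; reinterpreting
// the bits as an int and subtracting 0x4B400000 recovers the rounded value
// (valid for |x| < 2**22, sign preserved). Illustrative only; assumes
// std::memcpy (<cstring>) is reachable from the existing includes.
[[maybe_unused]] inline int32_t magic_round_sketch(float x) {
  const float shifted = x + 12582912.0f; // 12582912.0f == 2**23 + 2**22
  int32_t bits = 0;
  std::memcpy(&bits, &shifted, sizeof(bits)); // bit-level reinterpretation
  return bits - 0x4B400000;                   // 0x4B400000 == bits of 12582912.0f
}
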
3511 template <>
3512 void quantize_tensor_arm<c10::quint8>(
3513     const float* __restrict__ in,
3514     c10::quint8* __restrict__ out,
3515     const int64_t N,
3516     const float scale,
3517     const int32_t zero_point) {
3518   quantize_tensor_arm_q8<c10::quint8, uint8_t, uint8x8_t>(
3519       in, out, N, scale, zero_point);
3520 }
3521 
3522 template <>
3523 void quantize_tensor_arm<c10::qint8>(
3524     const float* __restrict__ in,
3525     c10::qint8* __restrict__ out,
3526     const int64_t N,
3527     const float scale,
3528     const int32_t zero_point) {
3529   quantize_tensor_arm_q8<c10::qint8, int8_t, int8x8_t>(
3530       in, out, N, scale, zero_point);
3531 }
3532 
3533 #if defined(__aarch64__)
3534 #define VMOVL_HIGH_U8(x) vmovl_high_u8(x)
3535 #define VMOVL_HIGH_S8(x) vmovl_high_s8(x)
3536 #define VMOVL_HIGH_U16(x) vmovl_high_u16(x)
3537 #define VMOVL_HIGH_S16(x) vmovl_high_s16(x)
3538 #else // vmovl_high intrinsic not supported
3539 #define VMOVL_HIGH_U8(x) vmovl_u8(vget_high_u8(x))
3540 #define VMOVL_HIGH_S8(x) vmovl_s8(vget_high_s8(x))
3541 #define VMOVL_HIGH_U16(x) vmovl_u16(vget_high_u16(x))
3542 #define VMOVL_HIGH_S16(x) vmovl_s16(vget_high_s16(x))
3543 #endif
3544 
3545 // Generic template defaults to naive dequantize implementation
3546 template <typename T>
3547 void dequantize_tensor_arm(
3548     const T* __restrict__ in,
3549     float* __restrict__ out,
3550     const int64_t N,
3551     const float scale,
3552     const int32_t zero_point) {
3553   for (int i = 0; i < N; ++i) {
3554     out[i] = dequantize_val<T>(scale, zero_point, in[i]);
3555   }
3556 }
3557 
3558 template <>
3559 void dequantize_tensor_arm<c10::qint8>(
3560     const c10::qint8* __restrict__ in,
3561     float* __restrict__ out,
3562     const int64_t N,
3563     const float scale,
3564     const int32_t zero_point) {
3565   const int8_t* in_underlying = reinterpret_cast<const int8_t*>(in);
3566 
3567   const float32x4_t scale_fp32x4 = vdupq_n_f32(scale);
3568   // Zero point is restricted to be in bounds of a signed 8 bit integer
3569   const int8x8_t zero_point_s8x8 = vget_low_s8(vdupq_n_s8(static_cast<int8_t>(zero_point)));
3570 
3571   int i;
3572   for (i = 0; i + 16 <= N; i += 16) {
3573     const int8x16_t vin_s8 = vld1q_s8(in_underlying);
3574 
3575     // Extract upper or lower values to int16x8 and subtract zero point
3576     // Each input element and the zero point are restricted to be in bounds of
3577     // a signed 8 bit integer, so the difference will fit in a signed 16 bit
3578     // integer
3579     const int16x8_t minus_zp_low_s16 = vsubl_s8(vget_low_s8(vin_s8), zero_point_s8x8); // 0 ... 7
3580     const int16x8_t minus_zp_high_s16 = vsubl_s8(vget_high_s8(vin_s8), zero_point_s8x8); // 8 ... 15
3581 
3582     const int32x4_t minus_zp_low_low = vmovl_s16(vget_low_s16(minus_zp_low_s16)); // 0 ... 3
3583     const int32x4_t minus_zp_low_high = VMOVL_HIGH_S16(minus_zp_low_s16); // 4 ... 7
3584     const int32x4_t minus_zp_high_low = vmovl_s16(vget_low_s16(minus_zp_high_s16)); // 8 ... 11
3585     const int32x4_t minus_zp_high_high = VMOVL_HIGH_S16(minus_zp_high_s16); // 12 ... 15
3586 
3587     // Store            * scale   int32->fp32
3588     vst1q_f32(out,      vmulq_f32(vcvtq_f32_s32(minus_zp_low_low), scale_fp32x4));
3589     vst1q_f32(out + 4,  vmulq_f32(vcvtq_f32_s32(minus_zp_low_high), scale_fp32x4));
3590     vst1q_f32(out + 8,  vmulq_f32(vcvtq_f32_s32(minus_zp_high_low), scale_fp32x4));
3591     vst1q_f32(out + 12, vmulq_f32(vcvtq_f32_s32(minus_zp_high_high), scale_fp32x4));
3592 
3593     out += 16;
3594     in += 16;
3595     in_underlying += 16;
3596   }
3597 
3598   for (; i < N; ++i) { // use default dequantize for remaining vals
3599     (*out++) = dequantize_val<c10::qint8>(scale, zero_point, (*in++));
3600   }
3601 }
3602 
3603 template <>
3604 void dequantize_tensor_arm<c10::quint8>(
3605     const c10::quint8* __restrict__ in,
3606     float* __restrict__ out,
3607     const int64_t N,
3608     const float scale,
3609     const int32_t zero_point) {
3610   const uint8_t* in_underlying = reinterpret_cast<const uint8_t*>(in);
3611 
3612   const float32x4_t scale_fp32x4 = vdupq_n_f32(scale);
3613   // Zero point is restricted to be in bounds of an unsigned 8 bit integer
3614   const uint8x8_t zero_point_u8x8 = vget_low_u8(vdupq_n_u8(static_cast<uint8_t>(zero_point)));
3615 
3616   int i;
3617   for (i = 0; i + 16 <= N; i += 16) {
3618     const uint8x16_t vin_u8 = vld1q_u8(in_underlying);
3619 
3620     // Extract upper or lower values to uint16x8 and subtract zero point
3621     // Each input element and the zero point are restricted to be in bounds of
3622     // an unsigned 8 bit integer, so the difference will fit in a signed 16 bit
3623     // integer
3624     const int16x8_t minus_zp_low_s16 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(vin_u8), zero_point_u8x8)); // 0 ... 7
3625     const int16x8_t minus_zp_high_s16 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(vin_u8), zero_point_u8x8)); // 8 ... 15
3626 
3627     const int32x4_t minus_zp_low_low = vmovl_s16(vget_low_s16(minus_zp_low_s16)); // 0 ... 3
3628     const int32x4_t minus_zp_low_high = VMOVL_HIGH_S16(minus_zp_low_s16); // 4 ... 7
3629     const int32x4_t minus_zp_high_low = vmovl_s16(vget_low_s16(minus_zp_high_s16)); // 8 ... 11
3630     const int32x4_t minus_zp_high_high = VMOVL_HIGH_S16(minus_zp_high_s16); // 12 ... 15
3631 
3632     // Store            * scale   int32->fp32
3633     vst1q_f32(out,      vmulq_f32(vcvtq_f32_s32(minus_zp_low_low), scale_fp32x4));
3634     vst1q_f32(out + 4,  vmulq_f32(vcvtq_f32_s32(minus_zp_low_high), scale_fp32x4));
3635     vst1q_f32(out + 8,  vmulq_f32(vcvtq_f32_s32(minus_zp_high_low), scale_fp32x4));
3636     vst1q_f32(out + 12, vmulq_f32(vcvtq_f32_s32(minus_zp_high_high), scale_fp32x4));
3637 
3638     out += 16;
3639     in += 16;
3640     in_underlying += 16;
3641   }
3642 
3643   for (; i < N; ++i) { // use default dequantize for remaining vals
3644     (*out++) = dequantize_val<c10::quint8>(scale, zero_point, (*in++));
3645   }
3646 }
3647 
3648 #endif // defined(__ARM_NEON__) || defined(__aarch64__)
3649 
3650 void quantize_tensor_per_tensor_affine_cpu(
3651     const Tensor& rtensor,
3652     Tensor& qtensor,
3653     double scale,
3654     int64_t zero_point) {
3655   check_tensor_memory_format(rtensor, qtensor);
3656   const float* rdata = rtensor.const_data_ptr<float>();
3657   int numel = rtensor.numel();
3658 #if defined(__ARM_NEON__) || defined(__aarch64__)
3659   AT_DISPATCH_QINT_TYPES(
3660       qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_cpu", [&]() {
3661         scalar_t* qdata = qtensor.data_ptr<scalar_t>();
3662         auto quantize_range = [&](int64_t begin, int64_t end) {
3663           quantize_tensor_arm<scalar_t>(
3664             rdata + begin, qdata + begin, end - begin, scale, zero_point);
3665         };
3666         if (numel >= PARALLEL_THRESHOLD) {
3667           at::parallel_for(0, numel, 1, quantize_range);
3668         } else {
3669           quantize_range(0, numel);
3670         }
3671       });
3672 #else
3673   // Fallback path
3674   AT_DISPATCH_QINT_TYPES(
3675       qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_cpu", [&]() {
3676         scalar_t* qdata = qtensor.data_ptr<scalar_t>();
3677         for (const auto i : c10::irange(numel)) {
3678           qdata[i] = quantize_val<scalar_t>(scale, zero_point, rdata[i]);
3679         }
3680       });
3681 #endif // defined(__ARM_NEON__) || defined(__aarch64__)
3682 }
3683 
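// Editor's note: sketch of the size-gated parallelism pattern used on the ARM
// path of the per-tensor wrappers above and below: the same half-open range
// functor is either run inline (small tensors, avoids threading overhead) or
// handed to at::parallel_for once numel crosses a threshold such as
// PARALLEL_THRESHOLD (which is only defined under the ARM guard, hence the
// threshold parameter here). Hypothetical helper.
template <typename RangeFn>
[[maybe_unused]] void run_serial_or_parallel_sketch(
    int64_t numel, int64_t threshold, const RangeFn& fn) {
  if (numel >= threshold) {
    at::parallel_for(0, numel, 1, fn);
  } else {
    fn(0, numel);
  }
}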
3684 void dequantize_tensor_per_tensor_affine_cpu(
3685     const Tensor& qtensor,
3686     Tensor& rtensor,
3687     double scale,
3688     int64_t zero_point) {
3689   check_tensor_memory_format(qtensor, rtensor);
3690   float* rdata = rtensor.data_ptr<float>();
3691   int numel = qtensor.numel();
3692 #if defined(__ARM_NEON__) || defined(__aarch64__)
3693   AT_DISPATCH_QINT_TYPES(
3694       qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_cpu", [&]() {
3695         const scalar_t* qdata = qtensor.const_data_ptr<scalar_t>();
3696         auto dequantize_range = [&](int64_t begin, int64_t end) {
3697           dequantize_tensor_arm<scalar_t>(
3698             qdata + begin, rdata + begin, end - begin, scale, zero_point);
3699         };
3700         if (numel >= PARALLEL_THRESHOLD) {
3701           at::parallel_for(0, numel, 1, dequantize_range);
3702         } else {
3703           dequantize_range(0, numel);
3704         }
3705       });
3706 #else
3707   // Fallback path
3708   AT_DISPATCH_QINT_TYPES(
3709       qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_cpu", [&]() {
3710         const scalar_t* qdata = qtensor.const_data_ptr<scalar_t>();
3711         for (const auto i : c10::irange(numel)) {
3712           rdata[i] = dequantize_val<scalar_t>(scale, zero_point, qdata[i]);
3713         }
3714       });
3715 #endif // defined(__ARM_NEON__) || defined(__aarch64__)
3716 }
3717 #endif // USE_FBGEMM
3718 
3719 // TODO: add fbgemm for per channel
3720 // Generic template defaults to naive quantize implementation
3721 template <typename T>
3722 void quantize_tensor_per_channel_impl(
3723     const Tensor& rtensor,
3724     Tensor& qtensor,
3725     const Tensor& scales,
3726     const Tensor& zero_points,
3727     int64_t axis) {
3728   // TODO: channels last kernel can be made faster.
3729   // For contiguous tensors, e.g. NCHW, arbitrary axis can be used.
3730   // For channels_last/3d, however, only axis == 0 or 1 is supported.
3731   // Since the current channels_last implementation does not cover
3732   // per-channel quant with an arbitrary axis value, the caller
3733   // (quantize_tensor_per_channel_affine_cpu) checks and fails instead.
3734   int64_t batches = size_to_dim_(axis, rtensor.sizes());
3735   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
3736   int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
3737   int64_t channels = rtensor.size(axis);
3738   auto scales_data = scales.data_ptr<double>();
3739   auto zero_points_data = zero_points.data_ptr<int64_t>();
3740   const float* in = rtensor.const_data_ptr<float>();
3741   auto out = qtensor.data_ptr<T>();
3742   if (axis == 1 &&
3743       (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
3744        rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
3745     // This code handles per channel quant when axis = 1 and
3746     // channels_last contig.
3747     // If axis = 0 and channels_last contig, implementation for channels
3748     // first (NCHW) works.
3749     for (const auto b : c10::irange(batches)) {
3750       for (const auto e : c10::irange(elements_per_channel)) {
3751         for (const auto c : c10::irange(channels)) {
3752           auto i = b * channels * elements_per_channel + e * channels + c;
3753           out[i] = at::native::quantize_val<T>(
3754               scales_data[c], zero_points_data[c], in[i]);
3755         }
3756       }
3757     }
3758   } else {
3759     for (const auto b : c10::irange(batches)) {
3760       for (const auto c : c10::irange(channels)) {
3761         for (const auto e : c10::irange(elements_per_channel)) {
3762           auto i = b * channels * elements_per_channel +
3763               c * elements_per_channel + e;
3764           out[i] = at::native::quantize_val<T>(
3765               scales_data[c], zero_points_data[c], in[i]);
3766         }
3767       }
3768     }
3769   }
3770 }
3771 
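// Editor's note: sketch of the two flattened index orders used by
// quantize_tensor_per_channel_impl above, with B = batches, C = channels
// (the quantization axis) and E = elements_per_channel.
//   contiguous (NCHW-like):        i = (b * C + c) * E + e   (elements innermost)
//   channels_last with axis == 1:  i = (b * E + e) * C + c   (channels innermost)
// In both orders the per-channel scale/zero_point lookup is indexed by c.
// Hypothetical helpers.
[[maybe_unused]] inline int64_t per_channel_contig_index_sketch(
    int64_t b, int64_t c, int64_t e, int64_t C, int64_t E) {
  return (b * C + c) * E + e;
}

[[maybe_unused]] inline int64_t per_channel_channels_last_index_sketch(
    int64_t b, int64_t c, int64_t e, int64_t C, int64_t E) {
  return (b * E + e) * C + c;
}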
3772 #if defined(__ARM_NEON__) || defined(__aarch64__)
3773 // Specialized implementation from caffe2::Int8Quantize.
3774 // There may be a slight accuracy difference between this and the implementation
3775 // of quantize_val.
3776 // TODO Update quantize_tensor_per_channel_impl implementation to follow
3777 // quantize_val, i.e. f = Round(value/scale + zero_point)
3778 // TODO Make quantize_tensor_per_channel_impl work for other datatypes too
3779 // (int8, int32).
3780 template <>
3781 void quantize_tensor_per_channel_impl<c10::quint8>(
3782     const Tensor& rtensor,
3783     Tensor& qtensor,
3784     const Tensor& scales,
3785     const Tensor& zero_points,
3786     int64_t axis) {
3787   int64_t batches = size_to_dim_(axis, rtensor.sizes());
3788   int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
3789   int64_t channels = rtensor.size(axis);
3790   auto scales_data = scales.data_ptr<double>();
3791   auto zero_points_data = zero_points.data_ptr<int64_t>();
3792   const float* in = rtensor.const_data_ptr<float>();
3793   auto out = (uint8_t*)qtensor.data_ptr<c10::quint8>();
3794 #if defined(__ARM_NEON__)
3795   // magic float and magic int to take care of rounding
3796   // int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000
3797   // Some detail:
3798   // 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you
3799   // add a small number to a large number, the result rounds to the precision of
3800   // the least significant bit of the large number. For an IEEE-754
3801   // single-precision number the mantissa has 23 bits, so adding 2**23 causes the
3802   // value to be rounded to the nearest integer (ties to even). Then we cast to int
3803   // and subtract the same number (0x4B400000 is the integer representation of
3804   // 12582912.0f) to get only the mantissa. This works if -2**22 < x < 2**22, and
3805   // it preserves the sign for negative numbers (see the scalar sketch after quantize_tensor_arm_q8).
3806   const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
3807   // Copy reciprocal of scales (double) into float array
3808   // Copy zero_points with magic int (int64_t) into int32_t array
3809   std::vector<float> inv_scales(channels);
3810   std::vector<int32_t> zero_points_int32t(channels);
3811   for (const auto i : c10::irange(channels)) {
3812     inv_scales[i] = 1.0f / (float)scales_data[i];
3813     zero_points_int32t[i] = (int32_t)(uint32_t)zero_points_data[i] - 0x4B400000;
3814   }
3815   if (axis == 1 &&
3816       (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
3817        rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
3818     // This code handles per channel quant when axis = 1 and
3819     // channels_last contig.
3820     // If axis = 0 and channels_last contig, implementation for channels
3821     // first (NCHW) works.
3822     for (C10_UNUSED const auto b : c10::irange(batches)) {
3823       for (C10_UNUSED const auto e : c10::irange(elements_per_channel)) {
3824         uint32_t c = 0;
3825         while (c + 8 < channels) {
3826           const int32x4_t voffset0123 = vld1q_s32(&zero_points_int32t[c]);
3827           const float32x4_t vinv_scale0123 = vld1q_f32(&inv_scales[c]);
3828           c += 4;
3829           const int32x4_t voffset4567 = vld1q_s32(&zero_points_int32t[c]);
3830           const float32x4_t vinv_scale4567 = vld1q_f32(&inv_scales[c]);
3831           c += 4;
3832           const float32x4_t vin0123 = vld1q_f32(in);
3833           in += 4;
3834           const float32x4_t vin4567 = vld1q_f32(in);
3835           in += 4;
3836           const int32x4_t vraw0123 = vaddq_s32(
3837               voffset0123,
3838               vreinterpretq_s32_f32(
3839                   vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale0123))));
3840           const int32x4_t vraw4567 = vaddq_s32(
3841               voffset4567,
3842               vreinterpretq_s32_f32(
3843                   vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale4567))));
3844           const int16x8_t vraw01234567 =
3845               vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
3846           const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
3847           vst1_u8(out, vout01234567);
3848           out += 8;
3849         }
3850         for (; c < channels; ++c) {
3851           (*out++) = at::native::quantize_val_arm<uint8_t>(
3852               scales_data[c], zero_points_data[c], (*in++));
3853         }
3854       }
3855     }
3856   } else {
3857     for (C10_UNUSED const auto b : c10::irange(batches)) {
3858       for (const auto c : c10::irange(channels)) {
3859         uint32_t e = 0;
3860         const int32x4_t voffset = vdupq_n_s32(zero_points_int32t[c]);
3861         const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]);
3862         for (; e + 8 < elements_per_channel; e += 8) {
3863           const float32x4_t vin0123 = vld1q_f32(in);
3864           in += 4;
3865           const float32x4_t vin4567 = vld1q_f32(in);
3866           in += 4;
3867           const int32x4_t vraw0123 = vaddq_s32(
3868               voffset,
3869               vreinterpretq_s32_f32(
3870                   vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
3871           const int32x4_t vraw4567 = vaddq_s32(
3872               voffset,
3873               vreinterpretq_s32_f32(
3874                   vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
3875           const int16x8_t vraw01234567 =
3876               vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
3877           const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
3878           vst1_u8(out, vout01234567);
3879           out += 8;
3880         }
3881         for (; e < elements_per_channel; ++e) {
3882           (*out++) = at::native::quantize_val_arm<uint8_t>(
3883               scales_data[c], zero_points_data[c], (*in++));
3884         }
3885       }
3886     }
3887   }
3888 #else // defined(__ARM_NEON__)
3889   // Copy reciprocals of scales (double) into a float array
3890   // Copy zero_points (int64_t) into int16_t array
3891   std::vector<float> inv_scales(channels);
3892   std::vector<int16_t> zero_points_int16t(channels);
3893   for (const auto i : c10::irange(channels)) {
3894     inv_scales[i] = 1.0f / (float)scales_data[i];
3895     zero_points_int16t[i] = (int16_t)(uint16_t)zero_points_data[i];
3896   }
3897   if (axis == 1 &&
3898       (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
3899        rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
3900     // This code handles per channel quant when axis = 1 and
3901     // channels_last contig.
3902     // If axis = 0 and channels_last contig, implementation for channels
3903     // first (NCHW) works.
3904     for (const auto b C10_UNUSED : c10::irange(batches)) {
3905       for (const auto e C10_UNUSED : c10::irange(elements_per_channel)) {
3906         uint32_t c = 0;
3907         while (c + 8 < channels) {
3908           const int16x8_t vzero_point = vld1q_s16(&zero_points_int16t[c]);
3909           const float32x4_t vinv_scale0123 = vld1q_f32(&inv_scales[c]);
3910           c += 4;
3911           const float32x4_t vinv_scale4567 = vld1q_f32(&inv_scales[c]);
3912           c += 4;
3913           const float32x4_t vin0123 = vld1q_f32(in);
3914           in += 4;
3915           const float32x4_t vin4567 = vld1q_f32(in);
3916           in += 4;
3917           const int32x4_t v0123_rounded =
3918               vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale0123));
3919           const int32x4_t v4567_rounded =
3920               vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale4567));
3921           const int16x8_t v01234567_packed = vqaddq_s16(
3922               vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded),
3923               vzero_point);
3924           const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
3925           vst1_u8(out, vout01234567);
3926           out += 8;
3927         }
3928         for (; c < channels; ++c) {
3929           (*out++) = at::native::quantize_val_arm<uint8_t>(
3930               scales_data[c], zero_points_data[c], (*in++));
3931         }
3932       }
3933     }
3934   } else {
3935     for (const auto b C10_UNUSED : c10::irange(batches)) {
3936       for (const auto c C10_UNUSED : c10::irange(channels)) {
3937         uint32_t e = 0;
3938         const int16x8_t vzero_point = vdupq_n_s16(zero_points_int16t[c]);
3939         const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]);
3940         for (; e + 8 < elements_per_channel; e += 8) {
3941           const float32x4_t vin0123 = vld1q_f32(in);
3942           in += 4;
3943           const float32x4_t vin4567 = vld1q_f32(in);
3944           in += 4;
3945           const int32x4_t v0123_rounded =
3946               vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale));
3947           const int32x4_t v4567_rounded =
3948               vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale));
3949           const int16x8_t v01234567_packed = vqaddq_s16(
3950               vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded),
3951               vzero_point);
3952           const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
3953           vst1_u8(out, vout01234567);
3954           out += 8;
3955         }
3956         for (; e < elements_per_channel; ++e) {
3957           (*out++) = at::native::quantize_val_arm<uint8_t>(
3958               scales_data[c], zero_points_data[c], (*in++));
3959         }
3960       }
3961     }
3962   }
3963 #endif // defined(__ARM_NEON__)
3964 }
3965 #endif // defined(__ARM_NEON__) || defined(__aarch64__)
3966 
3967 void quantize_tensor_per_channel_affine_cpu(
3968     const Tensor& rtensor,
3969     Tensor& qtensor,
3970     const Tensor& scales,
3971     const Tensor& zero_points,
3972     int64_t axis) {
3973   TORCH_CHECK(
3974       rtensor.is_contiguous() || (axis <= 1),
3975       "If tensor is channels_last contig then per channel quantization "
3976       "is supported only for axis = 0 or 1.");
3977   AT_DISPATCH_QINT_TYPES(
3978       qtensor.scalar_type(), "quantize_tensor_per_channel_affine_cpu", [&]() {
3979         check_tensor_memory_format(rtensor, qtensor);
3980         quantize_tensor_per_channel_impl<scalar_t>(
3981             rtensor, qtensor, scales, zero_points, axis);
3982       });
3983 }
3984 
3985 template<typename T, typename N, typename Q>
3986 void dequantize_per_channel_affine_kernel(
3987       const Tensor& qtensor,
3988       Tensor& rtensor,
3989       const Tensor& scales,
3990       const Tensor& zero_points,
3991       int64_t axis,
3992       int bit_width=8) {
3993 
3994   // For contiguous tensors, e.g. NCHW, an arbitrary axis can be used.
3995   // For channels_last/3d tensors, however, only axis == 0 or 1 is
3996   // supported. Since the current channels_last implementation does not
3997   // cover per-channel quantization with an arbitrary axis, it is better
3998   // to check and fail here.
3999   TORCH_CHECK(rtensor.is_contiguous() || (axis <= 1),
4000       "If tensor is channels_last contig then per channel quantization "
4001       "is supported only for axis = 0 or 1.");
4002   int64_t batches = size_to_dim_(axis, rtensor.sizes());
4003   int64_t elements_per_channel =
4004       // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
4005       size_from_dim_(axis + 1, rtensor.sizes());
4006   int64_t channel = rtensor.size(axis);
4007   auto scales_data = scales.data_ptr<T>();
4008   auto zero_points_data = zero_points.data_ptr<N>();
4009   check_tensor_memory_format(qtensor, rtensor);
4010   const auto* qd = qtensor.const_data_ptr<Q>();
4011   float* rd = rtensor.data_ptr<float>();
4012   const auto elem_per_byte = 8 / bit_width;
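       // For sub-byte dtypes, elem_per_byte quantized values share one byte:
       // value i lives in byte i / elem_per_byte at bit offset
       // (i % elem_per_byte) * bit_width, which the loops below shift and
       // mask out before dequantizing.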
4013   if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
4014       rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
4015     for (const auto b : c10::irange(batches)) {
4016       for (const auto e : c10::irange(elements_per_channel)) {
4017         for (const auto c : c10::irange(channel)) {
4018           auto i = b * channel * elements_per_channel + e * channel + c;
4019           // Convert the quantized value to float so that the subtraction
4020           // below is performed in floating point
4021           auto qvalue = qd[i / elem_per_byte].val_;
4022           if (bit_width < 8) {
4023             qvalue >>= (i % elem_per_byte) * bit_width;
4024             qvalue &= (1 << bit_width) - 1;
4025           }
4026           rd[i] = (static_cast<float>(qvalue) - zero_points_data[c]) * scales_data[c];
4027         }
4028       }
4029     }
4030   } else {
4031     for (const auto b : c10::irange(batches)) {
4032       for (const auto c : c10::irange(channel)) {
4033         for (const auto e : c10::irange(elements_per_channel)) {
4034           auto i = b * channel * elements_per_channel +
4035               c * elements_per_channel + e;
4036           // Convert the quantized value to float so that the subtraction
4037           // below is performed in floating point
4038           auto qvalue = qd[i / elem_per_byte].val_;
4039           if (bit_width < 8) {
4040             qvalue >>= (i % elem_per_byte) * bit_width;
4041             qvalue &= (1 << bit_width) - 1;
4042           }
4043           rd[i] = (static_cast<float>(qvalue) - zero_points_data[c]) * scales_data[c];
4044         }
4045       }
4046     }
4047   }
4048 }
4049 
4050 void dequantize_tensor_per_channel_affine_cpu(
4051     const Tensor& qtensor,
4052     Tensor& rtensor,
4053     const Tensor& scales,
4054     const Tensor& zero_points,
4055     int64_t axis) {
4056   AT_DISPATCH_QINT_TYPES(
4057       qtensor.scalar_type(), "dequantize_tensor_per_channel_affine_cpu", [&]() {
4058         dequantize_per_channel_affine_kernel<double, int64_t, scalar_t>(qtensor, rtensor, scales, zero_points, axis);
4059       });
4060 }
4061 
4062 // Quantize stubs for floating-point scale and zero_point.
4063 void quantize_tensor_per_channel_float_qparams_cpu(
4064     const Tensor& rtensor,
4065     Tensor& qtensor,
4066     const Tensor& scales,
4067     const Tensor& zero_points,
4068     int64_t axis) {
4069   // For contiguous tensors, e.g. NCHW, an arbitrary axis can be used.
4070   // For channels_last/3d tensors, however, only axis == 0 or 1 is
4071   // supported. Since the current channels_last implementation does not
4072   // cover per-channel quantization with an arbitrary axis, it is better
4073   // to check and fail here.
4074   TORCH_CHECK(rtensor.is_contiguous() || (axis <= 1),
4075       "If tensor is channels_last contig then per channel quantization "
4076       "is supported only for axis = 0 or 1.");
4077   AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
4078       qtensor.scalar_type(), "quantize_tensor_per_channel_float_qparams_cpu", [&]() {
4079         int64_t batches = size_to_dim_(axis, rtensor.sizes());
4080         int64_t elements_per_channel =
4081             size_from_dim_(axis + 1, rtensor.sizes());
4082         int64_t channel = rtensor.size(axis);
4083         auto scales_data = scales.data_ptr<float>();
4084         auto zero_points_data = zero_points.data_ptr<float>();
4085         check_tensor_memory_format(rtensor, qtensor);
4086         const float* rdata = rtensor.const_data_ptr<float>();
4087         auto qdata = reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
4088         const auto elem_per_byte = CHAR_BIT / bit_width;
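             // Sub-byte values are packed starting at the low bits of each
             // byte: the first value written to a byte overwrites it, and the
             // following values are OR-ed in at offset
             // (i % elem_per_byte) * bit_width.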
4089         int qvalue = 0;
4090         if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
4091             rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
4092           for (const auto b : c10::irange(batches)) {
4093             for (const auto e : c10::irange(elements_per_channel)) {
4094               for (const auto c : c10::irange(channel)) {
4095                 auto i = b * channel * elements_per_channel + e * channel + c;
4096                 qvalue = quantize_val_float_qparams(
4097                     scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max);
4098                 if (i % elem_per_byte == 0) {
4099                   qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
4100                 } else {
4101                   qdata[i / elem_per_byte] |= static_cast<underlying_t>(qvalue << ((i % elem_per_byte) * bit_width));
4102                 }
4103               }
4104             }
4105           }
4106         } else {
4107           for (const auto b : c10::irange(batches)) {
4108             for (const auto c : c10::irange(channel)) {
4109               for (const auto e : c10::irange(elements_per_channel)) {
4110                 auto i = b * channel * elements_per_channel +
4111                     c * elements_per_channel + e;
4112                 qvalue = quantize_val_float_qparams(
4113                     scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max);
4114                 if (i % elem_per_byte == 0) {
4115                   qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
4116                 } else {
4117                   qdata[i / elem_per_byte] |= static_cast<underlying_t>(qvalue << ((i % elem_per_byte) * bit_width));
4118                 }
4119               }
4120             }
4121           }
4122         }
4123       });
4124 }
4125 
4126 void dequantize_tensor_per_channel_float_qparams_cpu(
4127     const Tensor& qtensor,
4128     Tensor& rtensor,
4129     const Tensor& scales,
4130     const Tensor& zero_points,
4131     int64_t axis) {
4132   AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
4133       qtensor.scalar_type(), "dequantize_tensor_per_channel_float_qparams_cpu", [&]() {
4134         dequantize_per_channel_affine_kernel<float, float, scalar_t>(qtensor, rtensor, scales, zero_points, axis, bit_width);
4135       });
4136 }
4137 
4138 void quantize_tensor_per_tensor_affine_sub_byte_cpu(
4139     const Tensor& rtensor,
4140     Tensor& qtensor,
4141     float scale,
4142     float zero_point) {
4143   // TODO Use fbgemm kernel to pack values
4144   AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
4145     qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_sub_byte_cpu", [&]() {
4146       check_tensor_memory_format(rtensor, qtensor);
4147       const float* const rdata = rtensor.const_data_ptr<float>();
4148       auto qdata = reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
4149       auto numel = rtensor.numel();
4150       const auto elem_per_byte = CHAR_BIT / bit_width;
4151       for (const auto i : c10::irange(numel)) {
4152         float inv_scale = scale == 0 ? 1.0f : 1.0f / scale;
4153         int64_t qvalue = lrintf(std::nearbyint(rdata[i] * inv_scale) + zero_point);
4154         qvalue = std::max(quant_min, std::min(qvalue, quant_max));
4155 
4156         // We pack sub_byte values and align them to a byte.
4157         // E.g. for 4 bits, index 0 is packed into the lower 4 bits
4158         // and index 1 into the upper 4 bits.
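             // For example, with bit_width = 4, quantized values 3 and 12 at
             // indices 0 and 1 pack into the single byte 0xC3.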
4159         if (i % elem_per_byte == 0) {
4160           qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
4161         } else {
4162           qdata[i / elem_per_byte] |= static_cast<underlying_t>(qvalue << ((i % elem_per_byte) * bit_width));
4163         }
4164       } // for numel
4165     });
4166 }
4167 
4168 void dequantize_tensor_per_tensor_affine_sub_byte_cpu(
4169     const Tensor& qtensor,
4170     Tensor& rtensor,
4171     float scale,
4172     float zero_point) {
4173   // TODO Use fbgemm kernel to pack values
4174   AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
4175     qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_sub_byte_cpu", [&]() {
4176       check_tensor_memory_format(rtensor, qtensor);
4177       auto rdata = rtensor.data_ptr<float>();
4178       const underlying_t* qdata = reinterpret_cast<const underlying_t*>(qtensor.const_data_ptr<scalar_t>());
4179       auto numel = rtensor.numel();
4180       const auto elem_per_byte = CHAR_BIT / bit_width;
4181 
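           // Unpack each value by shifting its containing byte right by the
           // value's bit offset and masking off bit_width bits, then apply the
           // affine dequantization.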
4182       for (const auto i : c10::irange(numel)) {
4183         underlying_t qvalue = qdata[i / elem_per_byte];
4184         qvalue >>= (i % elem_per_byte) * bit_width;
4185         qvalue &= (1 << bit_width) - 1;
4186         rdata[i] = (static_cast<float>(qvalue) - zero_point) * scale;
4187       }
4188   });
4189 }
4190 
4191 // This function expects quantized_val to already be quantized to the destination's dtype
4192 template <typename scalar_t>
4193 void cpu_masked_fill_kernel_quantized_cpu(TensorIterator& iter, scalar_t quantized_val) {
4194   auto loop = [&](char** data, const int64_t* strides, int64_t n) {
4195     char* dst = data[0];
4196     char* mask = data[1];
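         // data[0] is the quantized destination, data[1] is the boolean mask;
         // quantized_val is written wherever the mask is true.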
4197     for (const auto i : c10::irange(n)) {
4198       bool mask_value = *reinterpret_cast<bool*>(mask + strides[1] * i);
4199 
4200       if (mask_value) {
4201         *(scalar_t*)(dst + strides[0] * i) = quantized_val;
4202       }
4203     }
4204   };
4205   iter.for_each(loop);
4206 }
4207 
4208 void masked_fill_kernel_quantized_cpu(TensorIterator& iter, const Scalar& value, double scale, int zero_point) {
4209   AT_DISPATCH_QINT_TYPES(iter.dtype(), "masked_fill", [&] {
4210     float float_val = value.to<float>();
4211     auto quantized_val = quantize_val<scalar_t>(scale, zero_point, float_val);
4212     auto mask_dtype = iter.input_dtype(0);
4213     TORCH_CHECK(mask_dtype == ScalarType::Bool, "masked_fill only supports boolean masks, "
4214       "but got mask with dtype ", mask_dtype);
4215     cpu_masked_fill_kernel_quantized_cpu<scalar_t>(iter, quantized_val);
4216   });
4217 }
4218 
4219 // Currently, we do not support accumulate=True for quantized tensors; an exception is thrown in _index_put_impl_quantized_cpu_.
4220 void index_put_kernel_quantized_cpu(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate, double scale, int zero_point) {
4221   // NOTE: duplicate indices are only supported if accumulate is true.
4222   AT_DISPATCH_QINT_TYPES(iter.dtype(), "index_put", [&] {
4223     // See Note [Enabling Deterministic Operations]
4224     // Parallel cpu_index_kernel with accumulation is nondeterministic, so we
4225     // must enable serial execution if deterministic algorithms are enabled.
4226     const bool is_deterministic = at::globalContext().deterministicAlgorithms();
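         // The per-element callback quantizes the float source value with the
         // destination's scale and zero point before storing it.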
4227     at::native::cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [scale, zero_point](char* dst, char* src, int64_t offset) {
4228       *(scalar_t*)(dst + offset) = quantize_val<scalar_t>(scale, zero_point, *(float*)src);
4229     }, /*serial_execution=*/is_deterministic);
4230   });
4231 }
4232 } // anonymous namespace
4233 
4234 // Some quantization tests are flaky on Windows with AVX512. If --continue-through-error
4235 // is used, only one fails. But if the failing test is skipped, another one fails.
4236 // If the second test is also skipped, a third one fails.
4237 // So, until quantization support for AVX512 on Windows is fixed,
4238 // AVX2 kernels are used instead. Ref: GH 56992.
4239 #if defined(_WIN32)
4240 REGISTER_DISPATCH(dequantize_tensor_per_channel_affine_stub,
4241                   &dequantize_tensor_per_channel_affine_cpu);
4242 REGISTER_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub,
4243                   &dequantize_tensor_per_channel_float_qparams_cpu);
4244 REGISTER_DISPATCH(fake_quant_per_channel_cachemask_stub,
4245                   &fake_quant_per_channel_cachemask_cpu);
4246 REGISTER_DISPATCH(qavg_pool2d_nhwc_stub, &qavg_pool2d_nhwc_kernel);
4247 REGISTER_DISPATCH(qavg_pool3d_nhwc_stub, &qavg_pool3d_nhwc_kernel);
4248 #else
4249 // These kernels are dispatched to AVX512
4250 ALSO_REGISTER_AVX512_DISPATCH(dequantize_tensor_per_channel_affine_stub,
4251                   &dequantize_tensor_per_channel_affine_cpu);
4252 ALSO_REGISTER_AVX512_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub,
4253                   &dequantize_tensor_per_channel_float_qparams_cpu);
4254 ALSO_REGISTER_AVX512_DISPATCH(fake_quant_per_channel_cachemask_stub,
4255                   &fake_quant_per_channel_cachemask_cpu);
4256 ALSO_REGISTER_AVX512_DISPATCH(qavg_pool2d_nhwc_stub, &qavg_pool2d_nhwc_kernel);
4257 ALSO_REGISTER_AVX512_DISPATCH(qavg_pool3d_nhwc_stub, &qavg_pool3d_nhwc_kernel);
4258 #endif // defined(_WIN32)
4259 
4260 // The kernels below are dispatched to AVX2 because they don't perform as well
4261 // with AVX512. We might revisit this decision in the near future.
4262 REGISTER_DISPATCH(dequantize_tensor_per_tensor_affine_stub,
4263                   &dequantize_tensor_per_tensor_affine_cpu);
4264 REGISTER_DISPATCH(fake_quant_grad_learnable_tensor_stub,
4265                   &fake_quantize_learnable_tensor_grad_kernel_cpu);
4266 REGISTER_DISPATCH(fake_quant_tensor_cachemask_stub,
4267                   &fake_quantize_tensor_cachemask_kernel);
4268 REGISTER_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_stub,
4269                   &fake_quantize_tensor_cachemask_tensor_qparams_kernel);
4270 REGISTER_DISPATCH(qadaptive_avg_pool2d_nhwc_stub,
4271                   &qadaptive_avg_pool2d_nhwc_kernel);
4272 REGISTER_DISPATCH(qadaptive_avg_pool3d_ndhwc_stub,
4273                   &qadaptive_avg_pool3d_ndhwc_kernel);
4274 REGISTER_DISPATCH(qadd_relu_stub, &qadd_kernel<true>);
4275 REGISTER_DISPATCH(qadd_scalar_relu_stub, &qadd_scalar_kernel<true>);
4276 REGISTER_DISPATCH(qadd_scalar_stub, &qadd_scalar_kernel<false>);
4277 REGISTER_DISPATCH(qadd_stub, &qadd_kernel<false>);
4278 
4279 REGISTER_DISPATCH(qbatch_norm_relu_stub, &q_batch_norm_kernel<true>);
4280 REGISTER_DISPATCH(qbatch_norm_stub, &q_batch_norm_kernel<false>);
4281 REGISTER_DISPATCH(qcat_nhwc_stub, &qcat_nhwc_kernel<false>);
4282 REGISTER_DISPATCH(qcat_relu_nhwc_stub, &qcat_nhwc_kernel<true>);
4283 REGISTER_DISPATCH(qclamp_stub, &qclamp_kernel);
4284 REGISTER_DISPATCH(qclamp_min_stub, &qclamp_min_kernel);
4285 REGISTER_DISPATCH(qclamp_max_stub, &qclamp_max_kernel);
4286 REGISTER_DISPATCH(qelu_stub, &qelu_kernel);
4287 REGISTER_DISPATCH(qhardsigmoid_stub, &qhardsigmoid_kernel);
4288 REGISTER_DISPATCH(qhardswish_stub, &qhardswish_kernel);
4289 REGISTER_DISPATCH(qmaxpool_2d_nhwc_stub, &qmaxpool_2d_nhwc_kernel);
4290 REGISTER_DISPATCH(qmaxpool_3d_nthwc_stub, &qmaxpool_3d_nthwc_kernel);
4291 REGISTER_DISPATCH(qmul_relu_stub, &qmul_kernel<true>);
4292 REGISTER_DISPATCH(qmul_stub, &qmul_kernel<false>);
4293 REGISTER_DISPATCH(qrelu_leaky_stub, &leaky_qrelu_out_kernel);
4294 REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel);
4295 REGISTER_DISPATCH(qprelu_stub, &qprelu_out_kernel);
4296 REGISTER_DISPATCH(qgelu_stub, &qgelu_kernel);
4297 REGISTER_DISPATCH(qsigmoid_stub, &qsigmoid_kernel);
4298 REGISTER_DISPATCH(qtanh_stub, &qtanh_kernel);
4299 REGISTER_DISPATCH(qthreshold_stub, &qthreshold_kernel);
4300 REGISTER_DISPATCH(qtopk_stub, &qtopk_kernel);
4301 REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub,
4302                   &fake_quantize_learnable_channel_grad_kernel_cpu);
4303 REGISTER_DISPATCH(
4304     quantize_tensor_per_tensor_affine_stub,
4305     &quantize_tensor_per_tensor_affine_cpu);
4306 REGISTER_DISPATCH(
4307     quantize_tensor_per_channel_affine_stub,
4308     &quantize_tensor_per_channel_affine_cpu);
4309 REGISTER_DISPATCH(
4310     quantize_tensor_per_channel_float_qparams_stub,
4311     &quantize_tensor_per_channel_float_qparams_cpu);
4312 REGISTER_DISPATCH(quantized_normalize_stub, &quantized_normalize_kernel);
4313 REGISTER_DISPATCH(quantized_groupnorm_nhwc_stub, &quantized_groupnorm_nhwc_kernel);
4314 REGISTER_DISPATCH(qupsample_bilinear2d_nhwc_stub,
4315                   &qupsample_bilinear2d_nhwc_kernel);
4316 REGISTER_DISPATCH(
4317     quantize_tensor_per_tensor_affine_sub_byte_stub,
4318     &quantize_tensor_per_tensor_affine_sub_byte_cpu);
4319 REGISTER_DISPATCH(
4320     dequantize_tensor_per_tensor_affine_sub_byte_stub,
4321     &dequantize_tensor_per_tensor_affine_sub_byte_cpu);
4322 REGISTER_DISPATCH(
4323     masked_fill_kernel_quantized_stub,
4324     &masked_fill_kernel_quantized_cpu);
4325 REGISTER_DISPATCH(
4326     index_put_kernel_quantized_stub,
4327     &index_put_kernel_quantized_cpu);
4328 REGISTER_DISPATCH(qmean_inner_dim_stub, &qmean_inner_dim_kernel);
4329 REGISTER_DISPATCH(qstd_inner_dim_stub, &qstd_inner_dim_kernel);
4330 } // namespace at::native
4331 // NOLINTEND(*-c-arrays)
4332