xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8xm-sse2.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <assert.h>
10 
11 #include <emmintrin.h>
12 
13 #include <qnnpack/q8gavgpool.h>
14 
pytorch_q8gavgpool_ukernel_up8xm__sse2(size_t m,size_t n,const uint8_t * input,size_t input_stride,const uint8_t * zero,uint8_t * output,const union pytorch_qnnp_avgpool_quantization_params quantization_params[RESTRICT_STATIC1])15 void pytorch_q8gavgpool_ukernel_up8xm__sse2(
16     size_t m,
17     size_t n,
18     const uint8_t* input,
19     size_t input_stride,
20     const uint8_t* zero,
21     uint8_t* output,
22     const union pytorch_qnnp_avgpool_quantization_params
23         quantization_params[RESTRICT_STATIC 1]) {
24   assert(m >= 1);
25   assert(n < 8);
26 
27   const __m128i vbias =
28       _mm_loadu_si128((const __m128i*)&quantization_params->sse2.bias);
29   __m128i vacc_lo = vbias;
30   __m128i vacc_hi = vbias;
31   __m128i vzero = _mm_setzero_si128();
32   while (m >= 8) {
33     const __m128i vinput = _mm_loadl_epi64((const __m128i*)input);
34     const __m128i vxinput = _mm_unpacklo_epi8(vinput, vzero);
35     vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi8(vxinput, vzero));
36     vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi8(vxinput, vzero));
37 
38     input += input_stride;
39     m--;
40   }
41   while (m-- != 0) {
42     input += n;
43     __m128i vinput = _mm_setzero_si128();
44     if (n & 1) {
45       input -= 1;
46       vinput = _mm_cvtsi32_si128((int)(uint32_t)*input);
47     }
48     if (n & 2) {
49       vinput = _mm_slli_epi32(vinput, 16);
50       input -= 2;
51       vinput = _mm_insert_epi16(vinput, *((const uint16_t*)input), 0);
52     }
53     if (n & 4) {
54       input -= 4;
55       vinput = _mm_unpacklo_epi32(
56           _mm_cvtsi32_si128((int)*((const uint32_t*)input)), vinput);
57     }
58     input += input_stride;
59 
60     const __m128i vxinput = _mm_unpacklo_epi8(vinput, vzero);
61     vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi8(vxinput, vzero));
62     vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi8(vxinput, vzero));
63   }
64 
65   const __m128 vscale = _mm_loadu_ps(quantization_params->sse2.scale);
66 
67   const __m128 vacc_lo_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vscale);
68   const __m128 vacc_hi_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vscale);
69 
70   const __m128i vscaled_lo = _mm_cvtps_epi32(vacc_lo_f);
71   const __m128i vscaled_hi = _mm_cvtps_epi32(vacc_hi_f);
72 
73   __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
74   vout = _mm_adds_epi16(
75       vout,
76       _mm_load_si128(
77           (const __m128i*)quantization_params->sse2.output_zero_point));
78   vout = _mm_packus_epi16(vout, vout);
79   vout = _mm_min_epu8(
80       vout,
81       _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
82   vout = _mm_max_epu8(
83       vout,
84       _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));
85 
86   if (n & 4) {
87     *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout);
88     output += 4;
89     vout = _mm_srli_epi64(vout, 32);
90   }
91   if (n & 2) {
92     *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0);
93     output += 2;
94     vout = _mm_srli_epi32(vout, 16);
95   }
96   if (n & 1) {
97     *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout);
98   }
99 }
100