1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <assert.h>
#include <string.h>

#include <emmintrin.h>

#include <qnnpack/q8gavgpool.h>
14
pytorch_q8gavgpool_ukernel_up8xm__sse2(size_t m,size_t n,const uint8_t * input,size_t input_stride,const uint8_t * zero,uint8_t * output,const union pytorch_qnnp_avgpool_quantization_params quantization_params[RESTRICT_STATIC1])15 void pytorch_q8gavgpool_ukernel_up8xm__sse2(
16 size_t m,
17 size_t n,
18 const uint8_t* input,
19 size_t input_stride,
20 const uint8_t* zero,
21 uint8_t* output,
22 const union pytorch_qnnp_avgpool_quantization_params
23 quantization_params[RESTRICT_STATIC 1]) {
24 assert(m >= 1);
25 assert(n < 8);
26
27 const __m128i vbias =
28 _mm_loadu_si128((const __m128i*)&quantization_params->sse2.bias);
29 __m128i vacc_lo = vbias;
30 __m128i vacc_hi = vbias;
31 __m128i vzero = _mm_setzero_si128();
32 while (m >= 8) {
33 const __m128i vinput = _mm_loadl_epi64((const __m128i*)input);
34 const __m128i vxinput = _mm_unpacklo_epi8(vinput, vzero);
35 vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi8(vxinput, vzero));
36 vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi8(vxinput, vzero));
37
38 input += input_stride;
39 m--;
40 }
41 while (m-- != 0) {
42 input += n;
43 __m128i vinput = _mm_setzero_si128();
44 if (n & 1) {
45 input -= 1;
46 vinput = _mm_cvtsi32_si128((int)(uint32_t)*input);
47 }
48 if (n & 2) {
49 vinput = _mm_slli_epi32(vinput, 16);
50 input -= 2;
51 vinput = _mm_insert_epi16(vinput, *((const uint16_t*)input), 0);
52 }
53 if (n & 4) {
54 input -= 4;
55 vinput = _mm_unpacklo_epi32(
56 _mm_cvtsi32_si128((int)*((const uint32_t*)input)), vinput);
57 }
58 input += input_stride;
59
60 const __m128i vxinput = _mm_unpacklo_epi8(vinput, vzero);
61 vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi8(vxinput, vzero));
62 vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi8(vxinput, vzero));
63 }
64
65 const __m128 vscale = _mm_loadu_ps(quantization_params->sse2.scale);
66
67 const __m128 vacc_lo_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vscale);
68 const __m128 vacc_hi_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vscale);
69
70 const __m128i vscaled_lo = _mm_cvtps_epi32(vacc_lo_f);
71 const __m128i vscaled_hi = _mm_cvtps_epi32(vacc_hi_f);
72
73 __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
74 vout = _mm_adds_epi16(
75 vout,
76 _mm_load_si128(
77 (const __m128i*)quantization_params->sse2.output_zero_point));
78 vout = _mm_packus_epi16(vout, vout);
79 vout = _mm_min_epu8(
80 vout,
81 _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
82 vout = _mm_max_epu8(
83 vout,
84 _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));
85
86 if (n & 4) {
87 *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout);
88 output += 4;
89 vout = _mm_srli_epi64(vout, 32);
90 }
91 if (n & 2) {
92 *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0);
93 output += 2;
94 vout = _mm_srli_epi32(vout, 16);
95 }
96 if (n & 1) {
97 *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout);
98 }
99 }
100