1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <assert.h>
10
11 #include <emmintrin.h>
12
13 #include <qnnpack/q8gavgpool.h>
14
pytorch_q8gavgpool_ukernel_up8x7__sse2(size_t m,size_t n,const uint8_t * input,size_t input_stride,const uint8_t * zero,uint8_t * output,const union pytorch_qnnp_avgpool_quantization_params quantization_params[RESTRICT_STATIC1])15 void pytorch_q8gavgpool_ukernel_up8x7__sse2(
16 size_t m,
17 size_t n,
18 const uint8_t* input,
19 size_t input_stride,
20 const uint8_t* zero,
21 uint8_t* output,
22 const union pytorch_qnnp_avgpool_quantization_params
23 quantization_params[RESTRICT_STATIC 1]) {
24 assert(m >= 1);
25 assert(m <= 7);
26 assert(n >= 8);
27
28 const uint8_t* i0 = input;
29 const uint8_t* i1 = i0 + input_stride;
30 if (m < 2) {
31 i1 = zero;
32 }
33 const uint8_t* i2 = i1 + input_stride;
34 if (m <= 2) {
35 i2 = zero;
36 }
37 const uint8_t* i3 = i2 + input_stride;
38 if (m < 4) {
39 i3 = zero;
40 }
41 const uint8_t* i4 = i3 + input_stride;
42 if (m <= 4) {
43 i4 = zero;
44 }
45 const uint8_t* i5 = i4 + input_stride;
46 if (m < 6) {
47 i5 = zero;
48 }
49 const uint8_t* i6 = i5 + input_stride;
50 if (m <= 6) {
51 i6 = zero;
52 }
53 const __m128i vbias =
54 _mm_load_si128((const __m128i*)&quantization_params->sse2.bias);
55 const __m128i vzero = _mm_setzero_si128();
56
57 const __m128 vscale = _mm_loadu_ps(quantization_params->sse2.scale);
58
59 do {
60 const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0);
61 i0 += 8;
62 const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1);
63 i1 += 8;
64 const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2);
65 i2 += 8;
66 const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3);
67 i3 += 8;
68 const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4);
69 i4 += 8;
70 const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5);
71 i5 += 8;
72 const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6);
73 i6 += 8;
74
75 const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
76 const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
77 const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
78 const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
79 const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
80 const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
81 const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
82
83 __m128i vacc_lo = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vxi0, vzero));
84 __m128i vacc_hi = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vxi0, vzero));
85 vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero));
86 vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero));
87 vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero));
88 vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero));
89 vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero));
90 vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero));
91 vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero));
92 vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero));
93 vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero));
94 vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero));
95 vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero));
96 vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero));
97
98 const __m128 vacc_lo_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vscale);
99 const __m128 vacc_hi_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vscale);
100
101 const __m128i vscaled_lo = _mm_cvtps_epi32(vacc_lo_f);
102 const __m128i vscaled_hi = _mm_cvtps_epi32(vacc_hi_f);
103
104 __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
105 vout = _mm_adds_epi16(
106 vout,
107 _mm_load_si128(
108 (const __m128i*)quantization_params->sse2.output_zero_point));
109 vout = _mm_packus_epi16(vout, vout);
110 vout = _mm_min_epu8(
111 vout,
112 _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
113 vout = _mm_max_epu8(
114 vout,
115 _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));
116
117 _mm_storel_epi64((__m128i*)output, vout);
118 output += 8;
119
120 n -= 8;
121 } while (n >= 8);
122 if (n != 0) {
123 const size_t address_decrement = 8 - n;
124 i0 = (const uint8_t*)((uintptr_t)i0 - address_decrement);
125 i1 = (const uint8_t*)((uintptr_t)i1 - address_decrement);
126 i2 = (const uint8_t*)((uintptr_t)i2 - address_decrement);
127 i3 = (const uint8_t*)((uintptr_t)i3 - address_decrement);
128 i4 = (const uint8_t*)((uintptr_t)i4 - address_decrement);
129 i5 = (const uint8_t*)((uintptr_t)i5 - address_decrement);
130 i6 = (const uint8_t*)((uintptr_t)i6 - address_decrement);
131 const __m128i vi_shift = _mm_cvtsi32_si128(8 * address_decrement);
132
133 const __m128i vi0 =
134 _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i0), vi_shift);
135 const __m128i vi1 =
136 _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i1), vi_shift);
137 const __m128i vi2 =
138 _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i2), vi_shift);
139 const __m128i vi3 =
140 _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i3), vi_shift);
141 const __m128i vi4 =
142 _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i4), vi_shift);
143 const __m128i vi5 =
144 _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i5), vi_shift);
145 const __m128i vi6 =
146 _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i6), vi_shift);
147
148 const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
149 const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
150 const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
151 const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
152 const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
153 const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
154 const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);
155
156 __m128i vacc_lo = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vxi0, vzero));
157 __m128i vacc_hi = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vxi0, vzero));
158 vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero));
159 vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero));
160 vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero));
161 vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero));
162 vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero));
163 vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero));
164 vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero));
165 vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero));
166 vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero));
167 vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero));
168 vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero));
169 vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero));
170
171 const __m128 vacc_lo_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vscale);
172 const __m128 vacc_hi_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vscale);
173
174 const __m128i vscaled_lo = _mm_cvtps_epi32(vacc_lo_f);
175 const __m128i vscaled_hi = _mm_cvtps_epi32(vacc_hi_f);
176
177 __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
178 vout = _mm_adds_epi16(
179 vout,
180 _mm_load_si128(
181 (const __m128i*)quantization_params->sse2.output_zero_point));
182 vout = _mm_packus_epi16(vout, vout);
183 vout = _mm_min_epu8(
184 vout,
185 _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
186 vout = _mm_max_epu8(
187 vout,
188 _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));
189
190 if (n & 4) {
191 *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout);
192 output += 4;
193 vout = _mm_srli_epi64(vout, 32);
194 }
195 if (n & 2) {
196 *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0);
197 output += 2;
198 vout = _mm_srli_epi32(vout, 16);
199 }
200 if (n & 1) {
201 *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout);
202 }
203 }
204 }
205