/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <assert.h>

#include <emmintrin.h>

#include <qnnpack/q8gavgpool.h>

void pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2(
    size_t m,
    size_t n,
    const uint8_t* input,
    size_t input_stride,
    const uint8_t* zero,
    int32_t* buffer,
    uint8_t* output,
    const union pytorch_qnnp_avgpool_quantization_params
        quantization_params[RESTRICT_STATIC 1]) {
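  /*
   * Multi-pass global average pooling microkernel: reduces an m x n matrix
   * of uint8 values column-wise, 8 channels per SIMD iteration and 7 rows
   * per pass, carrying 32-bit partial sums in `buffer` between passes.
   */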
  assert(m > 7);
  assert(n >= 8);

  const uint8_t* i0 = input;
  const uint8_t* i1 = i0 + input_stride;
  const uint8_t* i2 = i1 + input_stride;
  const uint8_t* i3 = i2 + input_stride;
  const uint8_t* i4 = i3 + input_stride;
  const uint8_t* i5 = i4 + input_stride;
  const uint8_t* i6 = i5 + input_stride;
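  /*
   * Each pass advances the row pointers by 8 bytes per iteration, i.e. by
   * packed_n (n rounded up to a multiple of 8) in total; input_increment
   * then moves them to the start of the next group of 7 rows.
   */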
  const size_t packed_n = (n + 7) & -8;
  const size_t input_increment = 7 * input_stride - packed_n;
  const __m128i vbias =
      _mm_load_si128((const __m128i*)&quantization_params->sse2.bias);
  const __m128i vzero = _mm_setzero_si128();

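  /* First pass: seed the accumulators in `buffer` with the bias plus the
   * sum of the first 7 rows. */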
  /* note: the 8-byte loads below may read up to 7 bytes past the end of
   * each row when n is not a multiple of 8 */
  int32_t* acc = buffer;
  for (size_t k = 0; k < n; k += 8) {
    const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0);
    i0 += 8;
    const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1);
    i1 += 8;
    const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2);
    i2 += 8;
    const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3);
    i3 += 8;
    const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4);
    i4 += 8;
    const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5);
    i5 += 8;
    const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6);
    i6 += 8;

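    /* Zero-extend the 8 uint8 values from each row to 16-bit lanes. */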
    const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
    const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
    const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
    const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
    const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
    const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
    const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);

    __m128i vacc_lo = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vxi0, vzero));
    __m128i vacc_hi = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vxi0, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero));

    _mm_store_si128((__m128i*)acc, vacc_lo);
    _mm_store_si128((__m128i*)acc + 1, vacc_hi);
    acc += 8;
  }
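  /* Intermediate passes: accumulate 7 more rows per pass until at most 7
   * rows remain. */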
  for (m -= 7; m > 7; m -= 7) {
    acc = buffer;
    i0 = (const uint8_t*)((uintptr_t)i0 + input_increment);
    i1 = (const uint8_t*)((uintptr_t)i1 + input_increment);
    i2 = (const uint8_t*)((uintptr_t)i2 + input_increment);
    i3 = (const uint8_t*)((uintptr_t)i3 + input_increment);
    i4 = (const uint8_t*)((uintptr_t)i4 + input_increment);
    i5 = (const uint8_t*)((uintptr_t)i5 + input_increment);
    i6 = (const uint8_t*)((uintptr_t)i6 + input_increment);

    /* note: the 8-byte loads below may read up to 7 bytes past the end of
     * each row when n is not a multiple of 8 */
    for (size_t k = 0; k < n; k += 8) {
      const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0);
      i0 += 8;
      const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1);
      i1 += 8;
      const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2);
      i2 += 8;
      const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3);
      i3 += 8;
      const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4);
      i4 += 8;
      const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5);
      i5 += 8;
      const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6);
      i6 += 8;
      __m128i vacc_lo = _mm_load_si128((const __m128i*)acc);
      __m128i vacc_hi = _mm_load_si128((const __m128i*)acc + 1);

      const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
      const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
      const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
      const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
      const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
      const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
      const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);

      vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi0, vzero));
      vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi0, vzero));
      vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero));
      vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero));
      vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero));
      vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero));
      vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero));
      vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero));
      vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero));
      vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero));
      vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero));
      vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero));
      vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero));
      vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero));

      _mm_store_si128((__m128i*)acc, vacc_lo);
      _mm_store_si128((__m128i*)acc + 1, vacc_hi);
      acc += 8;
    }
  }

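  /* Final pass: add the last 1-7 rows, then scale, round, and requantize
   * the totals to uint8. */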
  const __m128 vscale = _mm_loadu_ps(quantization_params->sse2.scale);

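  /* Fewer than 7 rows may remain; redirect exhausted row pointers at the
   * zero vector so they contribute nothing to the sums. */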
  i0 = (const uint8_t*)((uintptr_t)i0 + input_increment);
  i1 = (const uint8_t*)((uintptr_t)i1 + input_increment);
  if (m < 2) {
    i1 = zero;
  }
  i2 = (const uint8_t*)((uintptr_t)i2 + input_increment);
  if (m <= 2) {
    i2 = zero;
  }
  i3 = (const uint8_t*)((uintptr_t)i3 + input_increment);
  if (m < 4) {
    i3 = zero;
  }
  i4 = (const uint8_t*)((uintptr_t)i4 + input_increment);
  if (m <= 4) {
    i4 = zero;
  }
  i5 = (const uint8_t*)((uintptr_t)i5 + input_increment);
  if (m < 6) {
    i5 = zero;
  }
  i6 = (const uint8_t*)((uintptr_t)i6 + input_increment);
  if (m <= 6) {
    i6 = zero;
  }

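  /* Produce 8 output channels per iteration while at least 8 remain. */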
  acc = buffer;
  do {
    const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0);
    i0 += 8;
    const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1);
    i1 += 8;
    const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2);
    i2 += 8;
    const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3);
    i3 += 8;
    const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4);
    i4 += 8;
    const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5);
    i5 += 8;
    const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6);
    i6 += 8;
    __m128i vacc_lo = _mm_load_si128((const __m128i*)acc);
    __m128i vacc_hi = _mm_load_si128((const __m128i*)acc + 1);
    acc += 8;

    const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
    const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
    const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
    const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
    const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
    const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
    const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);

    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi0, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi0, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero));

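    /* Requantize: scale in fp32, round, pack to int16 with signed
     * saturation, add the output zero point, pack to uint8, and clamp to
     * [output_min, output_max]. */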
    const __m128 vacc_lo_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vscale);
    const __m128 vacc_hi_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vscale);

    const __m128i vscaled_lo = _mm_cvtps_epi32(vacc_lo_f);
    const __m128i vscaled_hi = _mm_cvtps_epi32(vacc_hi_f);

    __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
    vout = _mm_adds_epi16(
        vout,
        _mm_load_si128(
            (const __m128i*)quantization_params->sse2.output_zero_point));
    vout = _mm_packus_epi16(vout, vout);
    vout = _mm_min_epu8(
        vout,
        _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
    vout = _mm_max_epu8(
        vout,
        _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));

    _mm_storel_epi64((__m128i*)output, vout);
    output += 8;

    n -= 8;
  } while (n >= 8);
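  /* Handle the final 1-7 channels, if any. */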
  if (n != 0) {
    const size_t address_decrement = 8 - n;
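    /* Back up each row pointer so the final 8-byte load stays in bounds,
     * then shift the loaded vector right to discard the stale low bytes. */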
    i0 = (const uint8_t*)((uintptr_t)i0 - address_decrement);
    i1 = (const uint8_t*)((uintptr_t)i1 - address_decrement);
    i2 = (const uint8_t*)((uintptr_t)i2 - address_decrement);
    i3 = (const uint8_t*)((uintptr_t)i3 - address_decrement);
    i4 = (const uint8_t*)((uintptr_t)i4 - address_decrement);
    i5 = (const uint8_t*)((uintptr_t)i5 - address_decrement);
    i6 = (const uint8_t*)((uintptr_t)i6 - address_decrement);
    const __m128i vi_shift = _mm_cvtsi32_si128(8 * address_decrement);

    const __m128i vi0 =
        _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i0), vi_shift);
    const __m128i vi1 =
        _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i1), vi_shift);
    const __m128i vi2 =
        _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i2), vi_shift);
    const __m128i vi3 =
        _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i3), vi_shift);
    const __m128i vi4 =
        _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i4), vi_shift);
    const __m128i vi5 =
        _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i5), vi_shift);
    const __m128i vi6 =
        _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i6), vi_shift);
    __m128i vacc_lo = _mm_load_si128((const __m128i*)acc);
    __m128i vacc_hi = _mm_load_si128((const __m128i*)acc + 1);

    const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero);
    const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero);
    const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero);
    const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero);
    const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero);
    const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero);
    const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero);

    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi0, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi0, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero));
    vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero));
    vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero));

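    /* Same scale / round / pack / clamp sequence as the main output loop. */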
    const __m128 vacc_lo_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vscale);
    const __m128 vacc_hi_f = _mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vscale);

    const __m128i vscaled_lo = _mm_cvtps_epi32(vacc_lo_f);
    const __m128i vscaled_hi = _mm_cvtps_epi32(vacc_hi_f);

    __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi);
    vout = _mm_adds_epi16(
        vout,
        _mm_load_si128(
            (const __m128i*)quantization_params->sse2.output_zero_point));
    vout = _mm_packus_epi16(vout, vout);
    vout = _mm_min_epu8(
        vout,
        _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
    vout = _mm_max_epu8(
        vout,
        _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));

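    /* Store the remaining 1-7 bytes: 4, then 2, then 1. */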
    if (n & 4) {
      *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout);
      output += 4;
      vout = _mm_srli_epi64(vout, 32);
    }
    if (n & 2) {
      *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0);
      output += 2;
      vout = _mm_srli_epi32(vout, 16);
    }
    if (n & 1) {
      *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout);
    }
  }
}