/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <immintrin.h>

#include <qnnpack/q8dwconv.h>
#include <requantization/runtime-sse2.h>

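/*
 * Single-pass ("up") quantized depthwise convolution micro-kernel for SSE2:
 * processes channels in groups of 8 and consumes 9 kernel taps per output
 * pixel (e.g. a 3x3 window).
 */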
void pytorch_q8dwconv_ukernel_up8x9__sse2(
    size_t channels,
    size_t output_width,
    const uint8_t** input,
    const void* weights,
    uint8_t* output,
    size_t input_stride,
    size_t output_increment,
    const union pytorch_qnnp_conv_quantization_params
        quantization_params[RESTRICT_STATIC 1]) {
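  /* Broadcast the input zero point and the shared kernel zero point
   * (kernel_zero_points[0]) so they can be subtracted lane-wise after the
   * uint8 -> int16 widening below. */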
  const __m128i va_zero_point = _mm_load_si128(
      (const __m128i*)quantization_params->sse2.input_zero_point);
  const __m128i vkernel_zero_point = _mm_set1_epi16(
      quantization_params->sse2.kernel_zero_points[0]);
  const __m128i vzero = _mm_setzero_si128();

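  /* Each iteration of this loop computes one output pixel across all
   * channels. */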
  do {
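    /* input[] supplies 9 row pointers, one per kernel tap of this output
     * pixel's receptive field. */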
    const uint8_t* i0 = input[0];
    const uint8_t* i1 = input[1];
    const uint8_t* i2 = input[2];
    const uint8_t* i3 = input[3];
    const uint8_t* i4 = input[4];
    const uint8_t* i5 = input[5];
    const uint8_t* i6 = input[6];
    const uint8_t* i7 = input[7];
    const uint8_t* i8 = input[8];

    input = (const uint8_t**)((uintptr_t)input + input_stride);

    size_t c = channels;
    const void* w = weights;
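    /* Main loop: process the channels in groups of 8. */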
    for (; c >= 8; c -= 8) {
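      /* Seed the accumulators with the 8 per-channel int32 biases that lead
       * each channel group in the packed weights. */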
      __m128i vacc_lo = _mm_loadu_si128((const __m128i*)w);
      __m128i vacc_hi = _mm_loadu_si128((const __m128i*)((uintptr_t)w + 16));

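      /* For each of the 9 taps: widen input and kernel bytes to int16,
       * subtract the zero points, then form 16x16 -> 32-bit products by
       * interleaving the low halves (_mm_mullo_epi16) with the high halves
       * (_mm_mulhi_epi16) via _mm_unpack{lo,hi}_epi16. */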
      const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0);
      i0 += 8;
      const __m128i vxi0 =
          sub_zero_point(_mm_unpacklo_epi8(vi0, vzero), va_zero_point);
      const __m128i vk0 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32));
      const __m128i vxk0 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk0, vzero), vkernel_zero_point);
      const __m128i vprod0_odd = _mm_mullo_epi16(vxi0, vxk0);
      const __m128i vprod0_even = _mm_mulhi_epi16(vxi0, vxk0);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod0_odd, vprod0_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod0_odd, vprod0_even));

      const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1);
      i1 += 8;
      const __m128i vxi1 =
          sub_zero_point(_mm_unpacklo_epi8(vi1, vzero), va_zero_point);
      const __m128i vk1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 40));
      const __m128i vxk1 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk1, vzero), vkernel_zero_point);
      const __m128i vprod1_odd = _mm_mullo_epi16(vxi1, vxk1);
      const __m128i vprod1_even = _mm_mulhi_epi16(vxi1, vxk1);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod1_odd, vprod1_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod1_odd, vprod1_even));

      const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2);
      i2 += 8;
      const __m128i vxi2 =
          sub_zero_point(_mm_unpacklo_epi8(vi2, vzero), va_zero_point);
      const __m128i vk2 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 48));
      const __m128i vxk2 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk2, vzero), vkernel_zero_point);
      const __m128i vprod2_odd = _mm_mullo_epi16(vxi2, vxk2);
      const __m128i vprod2_even = _mm_mulhi_epi16(vxi2, vxk2);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod2_odd, vprod2_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod2_odd, vprod2_even));

      const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3);
      i3 += 8;
      const __m128i vxi3 =
          sub_zero_point(_mm_unpacklo_epi8(vi3, vzero), va_zero_point);
      const __m128i vk3 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 56));
      const __m128i vxk3 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk3, vzero), vkernel_zero_point);
      const __m128i vprod3_odd = _mm_mullo_epi16(vxi3, vxk3);
      const __m128i vprod3_even = _mm_mulhi_epi16(vxi3, vxk3);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod3_odd, vprod3_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod3_odd, vprod3_even));

      const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4);
      i4 += 8;
      const __m128i vxi4 =
          sub_zero_point(_mm_unpacklo_epi8(vi4, vzero), va_zero_point);
      const __m128i vk4 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 64));
      const __m128i vxk4 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk4, vzero), vkernel_zero_point);
      const __m128i vprod4_odd = _mm_mullo_epi16(vxi4, vxk4);
      const __m128i vprod4_even = _mm_mulhi_epi16(vxi4, vxk4);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod4_odd, vprod4_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod4_odd, vprod4_even));

      const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5);
      i5 += 8;
      const __m128i vxi5 =
          sub_zero_point(_mm_unpacklo_epi8(vi5, vzero), va_zero_point);
      const __m128i vk5 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 72));
      const __m128i vxk5 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk5, vzero), vkernel_zero_point);
      const __m128i vprod5_odd = _mm_mullo_epi16(vxi5, vxk5);
      const __m128i vprod5_even = _mm_mulhi_epi16(vxi5, vxk5);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod5_odd, vprod5_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod5_odd, vprod5_even));

      const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6);
      i6 += 8;
      const __m128i vxi6 =
          sub_zero_point(_mm_unpacklo_epi8(vi6, vzero), va_zero_point);
      const __m128i vk6 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 80));
      const __m128i vxk6 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk6, vzero), vkernel_zero_point);
      const __m128i vprod6_odd = _mm_mullo_epi16(vxi6, vxk6);
      const __m128i vprod6_even = _mm_mulhi_epi16(vxi6, vxk6);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod6_odd, vprod6_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod6_odd, vprod6_even));

      const __m128i vi7 = _mm_loadl_epi64((const __m128i*)i7);
      i7 += 8;
      const __m128i vxi7 =
          sub_zero_point(_mm_unpacklo_epi8(vi7, vzero), va_zero_point);
      const __m128i vk7 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 88));
      const __m128i vxk7 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk7, vzero), vkernel_zero_point);
      const __m128i vprod7_odd = _mm_mullo_epi16(vxi7, vxk7);
      const __m128i vprod7_even = _mm_mulhi_epi16(vxi7, vxk7);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod7_odd, vprod7_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod7_odd, vprod7_even));

      const __m128i vi8 = _mm_loadl_epi64((const __m128i*)i8);
      i8 += 8;
      const __m128i vxi8 =
          sub_zero_point(_mm_unpacklo_epi8(vi8, vzero), va_zero_point);
      const __m128i vk8 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 96));
      const __m128i vxk8 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk8, vzero), vkernel_zero_point);
      const __m128i vprod8_odd = _mm_mullo_epi16(vxi8, vxk8);
      const __m128i vprod8_even = _mm_mulhi_epi16(vxi8, vxk8);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod8_odd, vprod8_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod8_odd, vprod8_even));

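      /* Advance to the next channel group: the packed weights hold 32 bytes
       * of int32 bias followed by 9 taps x 8 kernel bytes = 104 bytes per
       * group of 8 channels. */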
      w = (void*)((uintptr_t)w + 104);

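      /* Requantize: convert the int32 accumulators to float, scale by the
       * requantization scale, and round back to int32. */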
      const __m128 vmultiplier =
          _mm_set1_ps(quantization_params->sse2.requantization_scales[0]);

      vacc_lo =
          _mm_cvtps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vmultiplier));
      vacc_hi =
          _mm_cvtps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vmultiplier));

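      /* Pack to int16 with signed saturation, add the output zero point,
       * pack to uint8 with unsigned saturation, and clamp to the activation
       * range [output_min, output_max]. */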
      const __m128i voutput_zero_point = _mm_load_si128(
          (const __m128i*)quantization_params->sse2.output_zero_point);
      __m128i vout =
          _mm_adds_epi16(_mm_packs_epi32(vacc_lo, vacc_hi), voutput_zero_point);
      vout = _mm_packus_epi16(vout, vout);
      vout = _mm_min_epu8(
          vout,
          _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
      vout = _mm_max_epu8(
          vout,
          _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));

      _mm_storel_epi64((__m128i*)output, vout);
      output += 8;
    }
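    /* Remainder: handle the final 1-7 channels. */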
    if (c != 0) {
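      /* Back up each input pointer so that an 8-byte load ends exactly at
       * the last valid byte, then shift the loaded lane right to discard
       * the low bytes belonging to channels already processed above. */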
      const size_t i_predecrement = 8 - c;
      const __m128i vi_shift = _mm_cvtsi32_si128(8 * i_predecrement);
      i0 -= i_predecrement;
      i1 -= i_predecrement;
      i2 -= i_predecrement;
      i3 -= i_predecrement;
      i4 -= i_predecrement;
      i5 -= i_predecrement;
      i6 -= i_predecrement;
      i7 -= i_predecrement;
      i8 -= i_predecrement;

      __m128i vacc_lo = _mm_loadu_si128((const __m128i*)w);
      __m128i vacc_hi = _mm_loadu_si128((const __m128i*)((uintptr_t)w + 16));

      const __m128i vi0 =
          _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i0), vi_shift);
      const __m128i vxi0 =
          sub_zero_point(_mm_unpacklo_epi8(vi0, vzero), va_zero_point);
      const __m128i vk0 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32));
      const __m128i vxk0 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk0, vzero), vkernel_zero_point);
      const __m128i vprod0_odd = _mm_mullo_epi16(vxi0, vxk0);
      const __m128i vprod0_even = _mm_mulhi_epi16(vxi0, vxk0);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod0_odd, vprod0_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod0_odd, vprod0_even));

      const __m128i vi1 =
          _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i1), vi_shift);
      const __m128i vxi1 =
          sub_zero_point(_mm_unpacklo_epi8(vi1, vzero), va_zero_point);
      const __m128i vk1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 40));
      const __m128i vxk1 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk1, vzero), vkernel_zero_point);
      const __m128i vprod1_odd = _mm_mullo_epi16(vxi1, vxk1);
      const __m128i vprod1_even = _mm_mulhi_epi16(vxi1, vxk1);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod1_odd, vprod1_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod1_odd, vprod1_even));

      const __m128i vi2 =
          _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i2), vi_shift);
      const __m128i vxi2 =
          sub_zero_point(_mm_unpacklo_epi8(vi2, vzero), va_zero_point);
      const __m128i vk2 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 48));
      const __m128i vxk2 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk2, vzero), vkernel_zero_point);
      const __m128i vprod2_odd = _mm_mullo_epi16(vxi2, vxk2);
      const __m128i vprod2_even = _mm_mulhi_epi16(vxi2, vxk2);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod2_odd, vprod2_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod2_odd, vprod2_even));

      const __m128i vi3 =
          _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i3), vi_shift);
      const __m128i vxi3 =
          sub_zero_point(_mm_unpacklo_epi8(vi3, vzero), va_zero_point);
      const __m128i vk3 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 56));
      const __m128i vxk3 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk3, vzero), vkernel_zero_point);
      const __m128i vprod3_odd = _mm_mullo_epi16(vxi3, vxk3);
      const __m128i vprod3_even = _mm_mulhi_epi16(vxi3, vxk3);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod3_odd, vprod3_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod3_odd, vprod3_even));

      const __m128i vi4 =
          _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i4), vi_shift);
      const __m128i vxi4 =
          sub_zero_point(_mm_unpacklo_epi8(vi4, vzero), va_zero_point);
      const __m128i vk4 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 64));
      const __m128i vxk4 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk4, vzero), vkernel_zero_point);
      const __m128i vprod4_odd = _mm_mullo_epi16(vxi4, vxk4);
      const __m128i vprod4_even = _mm_mulhi_epi16(vxi4, vxk4);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod4_odd, vprod4_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod4_odd, vprod4_even));

      const __m128i vi5 =
          _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i5), vi_shift);
      const __m128i vxi5 =
          sub_zero_point(_mm_unpacklo_epi8(vi5, vzero), va_zero_point);
      const __m128i vk5 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 72));
      const __m128i vxk5 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk5, vzero), vkernel_zero_point);
      const __m128i vprod5_odd = _mm_mullo_epi16(vxi5, vxk5);
      const __m128i vprod5_even = _mm_mulhi_epi16(vxi5, vxk5);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod5_odd, vprod5_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod5_odd, vprod5_even));

      const __m128i vi6 =
          _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i6), vi_shift);
      const __m128i vxi6 =
          sub_zero_point(_mm_unpacklo_epi8(vi6, vzero), va_zero_point);
      const __m128i vk6 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 80));
      const __m128i vxk6 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk6, vzero), vkernel_zero_point);
      const __m128i vprod6_odd = _mm_mullo_epi16(vxi6, vxk6);
      const __m128i vprod6_even = _mm_mulhi_epi16(vxi6, vxk6);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod6_odd, vprod6_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod6_odd, vprod6_even));

      const __m128i vi7 =
          _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i7), vi_shift);
      const __m128i vxi7 =
          sub_zero_point(_mm_unpacklo_epi8(vi7, vzero), va_zero_point);
      const __m128i vk7 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 88));
      const __m128i vxk7 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk7, vzero), vkernel_zero_point);
      const __m128i vprod7_odd = _mm_mullo_epi16(vxi7, vxk7);
      const __m128i vprod7_even = _mm_mulhi_epi16(vxi7, vxk7);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod7_odd, vprod7_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod7_odd, vprod7_even));

      const __m128i vi8 =
          _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i8), vi_shift);
      const __m128i vxi8 =
          sub_zero_point(_mm_unpacklo_epi8(vi8, vzero), va_zero_point);
      const __m128i vk8 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 96));
      const __m128i vxk8 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vk8, vzero), vkernel_zero_point);
      const __m128i vprod8_odd = _mm_mullo_epi16(vxi8, vxk8);
      const __m128i vprod8_even = _mm_mulhi_epi16(vxi8, vxk8);
      vacc_lo =
          _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod8_odd, vprod8_even));
      vacc_hi =
          _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod8_odd, vprod8_even));

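      /* Requantize and clamp exactly as in the main loop. */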
      const __m128 vmultiplier =
          _mm_set1_ps(quantization_params->sse2.requantization_scales[0]);

      vacc_lo =
          _mm_cvtps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(vacc_lo), vmultiplier));
      vacc_hi =
          _mm_cvtps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(vacc_hi), vmultiplier));

      const __m128i voutput_zero_point = _mm_load_si128(
          (const __m128i*)quantization_params->sse2.output_zero_point);
      __m128i vout =
          _mm_adds_epi16(_mm_packs_epi32(vacc_lo, vacc_hi), voutput_zero_point);
      vout = _mm_packus_epi16(vout, vout);
      vout = _mm_min_epu8(
          vout,
          _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
      vout = _mm_max_epu8(
          vout,
          _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));

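      /* Store the remaining channels 4, 2, and 1 byte(s) at a time, shifting
       * the already-written lanes out of vout after each partial store. */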
      if (c & 4) {
        *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout);
        output += 4;
        vout = _mm_srli_epi64(vout, 32);
      }
      if (c & 2) {
        *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0);
        output += 2;
        vout = _mm_srli_epi32(vout, 16);
      }
      if (c & 1) {
        *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout);
        output += 1;
      }
    }

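    /* Apply the caller-supplied per-pixel output pointer adjustment. */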
    output = (uint8_t*)((uintptr_t)output + output_increment);
  } while (--output_width != 0);
}