/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <immintrin.h>

#include <qnnpack/q8conv.h>
#include <requantization/runtime-sse2.h>

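/*
 * Q8 (quantized uint8) convolution micro-kernel for SSE2.
 *
 * The 4x4c2 name encodes the tile shape: up to 4 output rows (mr), 4 output
 * channels (nr), and input-channel elements consumed 2 at a time, so that
 * each _mm_madd_epi16 (PMADDWD) produces four int32 dot products over a pair
 * of k elements.
 *
 * Inferred from the loads below (a sketch, not a contract): the packed
 * weights at `w` start with 4 int32 bias values, which seed the accumulators,
 * followed by interleaved uint8 kernel data; `a` is an indirection buffer of
 * row pointers, one group of four per kernel spatial position.
 */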
void pytorch_q8conv_ukernel_4x4c2__sse2(
    size_t mr,
    size_t nr,
    size_t kc,
    size_t ks,
    const uint8_t** restrict a,
    const void* restrict w,
    uint8_t* restrict c,
    size_t c_stride,
    size_t output_channel_index,
    const union pytorch_qnnp_conv_quantization_params
        quantization_params[RESTRICT_STATIC 1]) {
  __m128i vacc0x0123 = _mm_loadu_si128((const __m128i*)w);
  __m128i vacc1x0123 = vacc0x0123;
  __m128i vacc2x0123 = vacc0x0123;
  __m128i vacc3x0123 = vacc0x0123;
  w = (const void*)((uintptr_t)w + 16);

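  /*
   * The input zero point comes pre-broadcast across all int16 lanes. The
   * per-output-channel kernel zero points are duplicated in adjacent lanes
   * (lanes 0..7 = z0,z0,z1,z1,z2,z2,z3,z3) so that vb_zero_point lines up
   * with the c2 weight packing, where each channel occupies two consecutive
   * int16 lanes per PMADDWD.
   */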
  const __m128i va_zero_point = _mm_load_si128(
      (const __m128i*)quantization_params->sse2.input_zero_point);
  const int16_t vb_zero_point_0 =
      quantization_params->sse2.kernel_zero_points[output_channel_index];
  const int16_t vb_zero_point_1 =
      quantization_params->sse2.kernel_zero_points[output_channel_index + 1];
  const int16_t vb_zero_point_2 =
      quantization_params->sse2.kernel_zero_points[output_channel_index + 2];
  const int16_t vb_zero_point_3 =
      quantization_params->sse2.kernel_zero_points[output_channel_index + 3];

  const __m128i vb_zero_point = _mm_set_epi16(
      vb_zero_point_3,
      vb_zero_point_3,
      vb_zero_point_2,
      vb_zero_point_2,
      vb_zero_point_1,
      vb_zero_point_1,
      vb_zero_point_0,
      vb_zero_point_0);
  const __m128i vzero = _mm_setzero_si128();
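  /* Outer loop: one iteration per kernel spatial position (ks total). Each
   * iteration pulls the next group of four row pointers from the indirection
   * buffer and accumulates a full pass over the kc input channels. */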
  do {
    const uint8_t* restrict a0 = *a++;
    const uint8_t* restrict a1 = *a++;
    const uint8_t* restrict a2 = *a++;
    const uint8_t* restrict a3 = *a++;

    size_t k = kc;
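    /* Main loop: 8 input channels per iteration. Each of the four steps
     * below broadcasts one pair of activations per row (_mm_shuffle_epi32
     * replicates a 32-bit lane, i.e. two int16 values) and multiply-
     * accumulates it against 8 packed weights (4 channels x 2 k) with
     * _mm_madd_epi16. */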
    for (; k >= 8; k -= 8) {
      const __m128i va0 = _mm_loadl_epi64((const __m128i*)a0);
      const __m128i vxa0 =
          sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point);
      a0 += 8;
      const __m128i va1 = _mm_loadl_epi64((const __m128i*)a1);
      const __m128i vxa1 =
          sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point);
      a1 += 8;
      const __m128i va2 = _mm_loadl_epi64((const __m128i*)a2);
      const __m128i vxa2 =
          sub_zero_point(_mm_unpacklo_epi8(va2, vzero), va_zero_point);
      a2 += 8;
      const __m128i va3 = _mm_loadl_epi64((const __m128i*)a3);
      const __m128i vxa3 =
          sub_zero_point(_mm_unpacklo_epi8(va3, vzero), va_zero_point);
      a3 += 8;

      const __m128i vb0 = _mm_loadl_epi64((const __m128i*)w);
      const __m128i vxb0 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point);
      vacc0x0123 = _mm_add_epi32(
          vacc0x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
      vacc1x0123 = _mm_add_epi32(
          vacc1x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
      vacc2x0123 = _mm_add_epi32(
          vacc2x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
      vacc3x0123 = _mm_add_epi32(
          vacc3x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));

      const __m128i vb1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8));
      const __m128i vxb1 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point);
      vacc0x0123 = _mm_add_epi32(
          vacc0x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
      vacc1x0123 = _mm_add_epi32(
          vacc1x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
      vacc2x0123 = _mm_add_epi32(
          vacc2x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
      vacc3x0123 = _mm_add_epi32(
          vacc3x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));

      const __m128i vb2 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16));
      const __m128i vxb2 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point);
      vacc0x0123 = _mm_add_epi32(
          vacc0x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
      vacc1x0123 = _mm_add_epi32(
          vacc1x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
      vacc2x0123 = _mm_add_epi32(
          vacc2x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
      vacc3x0123 = _mm_add_epi32(
          vacc3x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));

      const __m128i vb3 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24));
      const __m128i vxb3 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);
      vacc0x0123 = _mm_add_epi32(
          vacc0x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
      vacc1x0123 = _mm_add_epi32(
          vacc1x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
      vacc2x0123 = _mm_add_epi32(
          vacc2x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
      vacc3x0123 = _mm_add_epi32(
          vacc3x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));

      w = (void*)((uintptr_t)w + 32);
    }
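    /* Remainder: 1-7 input channels left. Re-read the last 8 bytes of each
     * row and shift out the bytes already consumed; the vacated high lanes
     * are zero-filled. Correctness for those lanes presumably relies on the
     * weight packing padding unused entries with the kernel zero point, so
     * they subtract to zero in vxb and contribute nothing (an assumption
     * about the packing, not established by this file). */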
    if (k != 0) {
      const size_t a_predecrement = 8 - k;
      const __m128i va_shift = _mm_cvtsi32_si128(8 * a_predecrement);

      const __m128i va0 = _mm_srl_epi64(
          _mm_loadl_epi64((const __m128i*)(a0 - a_predecrement)), va_shift);
      const __m128i vxa0 =
          sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point);
      const __m128i va1 = _mm_srl_epi64(
          _mm_loadl_epi64((const __m128i*)(a1 - a_predecrement)), va_shift);
      const __m128i vxa1 =
          sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point);
      const __m128i va2 = _mm_srl_epi64(
          _mm_loadl_epi64((const __m128i*)(a2 - a_predecrement)), va_shift);
      const __m128i vxa2 =
          sub_zero_point(_mm_unpacklo_epi8(va2, vzero), va_zero_point);
      const __m128i va3 = _mm_srl_epi64(
          _mm_loadl_epi64((const __m128i*)(a3 - a_predecrement)), va_shift);
      const __m128i vxa3 =
          sub_zero_point(_mm_unpacklo_epi8(va3, vzero), va_zero_point);

      const __m128i vb0 = _mm_loadl_epi64((const __m128i*)w);
      const __m128i vxb0 =
          _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point);
      w = (void*)((uintptr_t)w + 8);

      vacc0x0123 = _mm_add_epi32(
          vacc0x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
      vacc1x0123 = _mm_add_epi32(
          vacc1x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
      vacc2x0123 = _mm_add_epi32(
          vacc2x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));
      vacc3x0123 = _mm_add_epi32(
          vacc3x0123,
          _mm_madd_epi16(
              _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));

      if (k > 2) {
        const __m128i vb1 = _mm_loadl_epi64((const __m128i*)w);
        const __m128i vxb1 =
            _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point);
        w = (void*)((uintptr_t)w + 8);

        vacc0x0123 = _mm_add_epi32(
            vacc0x0123,
            _mm_madd_epi16(
                _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
        vacc1x0123 = _mm_add_epi32(
            vacc1x0123,
            _mm_madd_epi16(
                _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
        vacc2x0123 = _mm_add_epi32(
            vacc2x0123,
            _mm_madd_epi16(
                _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));
        vacc3x0123 = _mm_add_epi32(
            vacc3x0123,
            _mm_madd_epi16(
                _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));

        if (k > 4) {
          const __m128i vb2 = _mm_loadl_epi64((const __m128i*)w);
          const __m128i vxb2 =
              _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point);
          w = (void*)((uintptr_t)w + 8);

          vacc0x0123 = _mm_add_epi32(
              vacc0x0123,
              _mm_madd_epi16(
                  _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
          vacc1x0123 = _mm_add_epi32(
              vacc1x0123,
              _mm_madd_epi16(
                  _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
          vacc2x0123 = _mm_add_epi32(
              vacc2x0123,
              _mm_madd_epi16(
                  _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
          vacc3x0123 = _mm_add_epi32(
              vacc3x0123,
              _mm_madd_epi16(
                  _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));

          if (k > 6) {
            const __m128i vb3 = _mm_loadl_epi64((const __m128i*)w);
            const __m128i vxb3 =
                _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point);
            w = (void*)((uintptr_t)w + 8);

            vacc0x0123 = _mm_add_epi32(
                vacc0x0123,
                _mm_madd_epi16(
                    _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
            vacc1x0123 = _mm_add_epi32(
                vacc1x0123,
                _mm_madd_epi16(
                    _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
            vacc2x0123 = _mm_add_epi32(
                vacc2x0123,
                _mm_madd_epi16(
                    _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
            vacc3x0123 = _mm_add_epi32(
                vacc3x0123,
                _mm_madd_epi16(
                    _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3));
          }
        }
      }
    }
  } while (--ks != 0);

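  /* Requantization: scale each int32 accumulator by its per-channel float
   * multiplier and convert back to int32. _mm_cvtps_epi32 (CVTPS2DQ) rounds
   * to nearest-even under the default MXCSR rounding mode. */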
  const __m128 vmultiplier = _mm_loadu_ps(
      &quantization_params->sse2.requantization_scales[output_channel_index]);

  vacc0x0123 =
      _mm_cvtps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(vacc0x0123), vmultiplier));
  vacc1x0123 =
      _mm_cvtps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(vacc1x0123), vmultiplier));
  vacc2x0123 =
      _mm_cvtps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(vacc2x0123), vmultiplier));
  vacc3x0123 =
      _mm_cvtps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(vacc3x0123), vmultiplier));

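  /* Pack down to uint8: narrow int32 -> int16 with signed saturation, add
   * the output zero point (saturating), narrow int16 -> uint8 with unsigned
   * saturation, then clamp to [output_min, output_max]. */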
  const __m128i voutput_zero_point = _mm_load_si128(
      (const __m128i*)quantization_params->sse2.output_zero_point);
  const __m128i vacc01x0123 = _mm_adds_epi16(
      _mm_packs_epi32(vacc0x0123, vacc1x0123), voutput_zero_point);
  const __m128i vacc23x0123 = _mm_adds_epi16(
      _mm_packs_epi32(vacc2x0123, vacc3x0123), voutput_zero_point);
  __m128i vout = _mm_packus_epi16(vacc01x0123, vacc23x0123);
  vout = _mm_min_epu8(
      vout,
      _mm_load_si128((const __m128i*)quantization_params->sse2.output_max));
  vout = _mm_max_epu8(
      vout,
      _mm_load_si128((const __m128i*)quantization_params->sse2.output_min));

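  /* One output pointer per row. For mr < 4 the surplus pointers alias the
   * last valid row, so the extra stores below land on writable memory; the
   * caller's setup presumably makes the unused rows duplicates of a valid
   * row, leaving the final contents correct. */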
  uint8_t* c0 = c;
  uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride);
  if (mr < 2) {
    c1 = c0;
  }
  uint8_t* c2 = (uint8_t*)((uintptr_t)c1 + c_stride);
  if (mr <= 2) {
    c2 = c1;
  }
  uint8_t* c3 = (uint8_t*)((uintptr_t)c2 + c_stride);
  if (mr != 4) {
    c3 = c2;
  }
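  /* Store nr <= 4 bytes per row: a single 32-bit store when nr == 4;
   * otherwise a 16-bit store for the first pair of columns (if nr >= 2)
   * followed by a byte store for an odd trailing column. */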
  if (nr == 4) {
    *((uint32_t*)c0) = (uint32_t)_mm_cvtsi128_si32(vout);
    *((uint32_t*)c1) = (uint32_t)_mm_cvtsi128_si32(_mm_srli_epi64(vout, 32));
    *((uint32_t*)c2) =
        (uint32_t)_mm_cvtsi128_si32(_mm_unpackhi_epi32(vout, vout));
    *((uint32_t*)c3) = (uint32_t)_mm_cvtsi128_si32(_mm_srli_si128(vout, 12));
  } else {
    typedef PYTORCH_QNNP_UNALIGNED uint16_t unaligned_uint16_t;
    if (nr >= 2) {
      *((unaligned_uint16_t*)c0) = (uint16_t)_mm_extract_epi16(vout, 0);
      c0 += 2;
      *((unaligned_uint16_t*)c1) = (uint16_t)_mm_extract_epi16(vout, 2);
      c1 += 2;
      *((unaligned_uint16_t*)c2) = (uint16_t)_mm_extract_epi16(vout, 4);
      c2 += 2;
      *((unaligned_uint16_t*)c3) = (uint16_t)_mm_extract_epi16(vout, 6);
      c3 += 2;
      vout = _mm_srli_epi32(vout, 16);
      nr -= 2;
    }
    if (nr != 0) {
      *((uint8_t*)c0) = (uint8_t)_mm_cvtsi128_si32(vout);
      *((uint8_t*)c1) = (uint8_t)_mm_extract_epi16(vout, 2);
      *((uint8_t*)c2) = (uint8_t)_mm_extract_epi16(vout, 4);
      *((uint8_t*)c3) = (uint8_t)_mm_extract_epi16(vout, 6);
    }
  }
}