1 // Copyright 2021 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7
8 #include <immintrin.h>
9
10 #include <xnnpack/common.h>
11 #include <xnnpack/dwconv.h>
12 #include <xnnpack/gemm.h>
13 #include <xnnpack/igemm.h>
14 #include <xnnpack/intrinsics-polyfill.h>
15 #include <xnnpack/lut.h>
16 #include <xnnpack/math.h>
17 #include <xnnpack/pavgpool.h>
18 #include <xnnpack/raddstoreexpminusmax.h>
19 #include <xnnpack/unaligned.h>
20 #include <xnnpack/vadd.h>
21 #include <xnnpack/vcvt.h>
22 #include <xnnpack/vlrelu.h>
23 #include <xnnpack/vunary.h>
24
25
xnn_f16_gemm_minmax_ukernel_1x16__avx2_broadcast(size_t mr,size_t nc,size_t kc,const void * restrict a,size_t a_stride,const void * restrict w,void * restrict c,size_t cm_stride,size_t cn_stride,const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])26 void xnn_f16_gemm_minmax_ukernel_1x16__avx2_broadcast(
27 size_t mr,
28 size_t nc,
29 size_t kc,
30 const void*restrict a,
31 size_t a_stride,
32 const void*restrict w,
33 void*restrict c,
34 size_t cm_stride,
35 size_t cn_stride,
36 const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
37 {
38 assert(mr != 0);
39 assert(mr <= 1);
40 assert(nc != 0);
41 assert(kc != 0);
42 assert(kc % sizeof(uint16_t) == 0);
43 assert(a != NULL);
44 assert(w != NULL);
45 assert(c != NULL);
46
47 const uint16_t* a0 = a;
48 uint16_t* c0 = c;
49
50 do {
51 __m256 vacc0x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
52 __m256 vacc0x89ABCDEF = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) ((const uint16_t*) w + 8)));
53 w = (const uint16_t*) w + 16;
54
55 size_t k = kc;
56 do {
57 const __m256 va0 = _mm256_cvtph_ps(_mm_set1_epi16((short) *a0));
58 a0 += 1;
59
60 const __m256 vb01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
61 const __m256 vb89ABCDEF = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) ((const uint16_t*) w + 8)));
62 w = (const uint16_t*) w + 16;
63
64 vacc0x01234567 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va0, vb01234567, vacc0x01234567), _MM_FROUND_NO_EXC));
65 vacc0x89ABCDEF = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va0, vb89ABCDEF, vacc0x89ABCDEF), _MM_FROUND_NO_EXC));
66
67 k -= sizeof(uint16_t);
68 } while (k != 0);
69
70 const __m256 vmin = _mm256_load_ps(params->avx.min);
71 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
72 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
73
74 const __m256 vmax = _mm256_load_ps(params->avx.max);
75 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
76 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
77
78 if XNN_LIKELY(nc >= 16) {
79 _mm_storeu_si128((__m128i*) c0, _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC));
80 _mm_storeu_si128((__m128i*) (c0 + 8), _mm256_cvtps_ph(vacc0x89ABCDEF, _MM_FROUND_NO_EXC));
81 c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride);
82
83 a0 = (const uint16_t*) ((uintptr_t) a0 - kc);
84
85 nc -= 16;
86 } else {
87 __m128i vh0x01234567 = _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC);
88 if (nc & 8) {
89 _mm_storeu_si128((__m128i*) c0, vh0x01234567);
90
91 vh0x01234567 = _mm256_cvtps_ph(vacc0x89ABCDEF, _MM_FROUND_NO_EXC);
92
93 c0 += 8;
94 }
95 if (nc & 4) {
96 _mm_storel_epi64((__m128i*) c0, vh0x01234567);
97
98 vh0x01234567 = _mm_unpackhi_epi64(vh0x01234567, vh0x01234567);
99
100 c0 += 4;
101 }
102 if (nc & 2) {
103 _mm_storeu_si32(c0, vh0x01234567);
104
105 vh0x01234567 = _mm_srli_epi64(vh0x01234567, 32);
106
107 c0 += 2;
108 }
109 if (nc & 1) {
110 *c0 = (uint16_t) _mm_extract_epi16(vh0x01234567, 0);
111 }
112
113 nc = 0;
114 }
115 } while (nc != 0);
116 }
117
xnn_f16_gemm_minmax_ukernel_4x16__avx2_broadcast(size_t mr,size_t nc,size_t kc,const void * restrict a,size_t a_stride,const void * restrict w,void * restrict c,size_t cm_stride,size_t cn_stride,const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])118 void xnn_f16_gemm_minmax_ukernel_4x16__avx2_broadcast(
119 size_t mr,
120 size_t nc,
121 size_t kc,
122 const void*restrict a,
123 size_t a_stride,
124 const void*restrict w,
125 void*restrict c,
126 size_t cm_stride,
127 size_t cn_stride,
128 const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
129 {
130 assert(mr != 0);
131 assert(mr <= 4);
132 assert(nc != 0);
133 assert(kc != 0);
134 assert(kc % sizeof(uint16_t) == 0);
135 assert(a != NULL);
136 assert(w != NULL);
137 assert(c != NULL);
138
139 const uint16_t* a0 = a;
140 uint16_t* c0 = c;
141 const uint16_t* a1 = (const uint16_t*) ((uintptr_t) a0 + a_stride);
142 uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride);
143 if XNN_UNPREDICTABLE(mr < 2) {
144 a1 = a0;
145 c1 = c0;
146 }
147 const uint16_t* a2 = (const uint16_t*) ((uintptr_t) a1 + a_stride);
148 uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride);
149 if XNN_UNPREDICTABLE(mr <= 2) {
150 a2 = a1;
151 c2 = c1;
152 }
153 const uint16_t* a3 = (const uint16_t*) ((uintptr_t) a2 + a_stride);
154 uint16_t* c3 = (uint16_t*) ((uintptr_t) c2 + cm_stride);
155 if XNN_UNPREDICTABLE(mr != 4) {
156 a3 = a2;
157 c3 = c2;
158 }
159
160 do {
161 __m256 vacc0x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
162 __m256 vacc0x89ABCDEF = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) ((const uint16_t*) w + 8)));
163 __m256 vacc1x01234567 = vacc0x01234567;
164 __m256 vacc1x89ABCDEF = vacc0x89ABCDEF;
165 __m256 vacc2x01234567 = vacc0x01234567;
166 __m256 vacc2x89ABCDEF = vacc0x89ABCDEF;
167 __m256 vacc3x01234567 = vacc0x01234567;
168 __m256 vacc3x89ABCDEF = vacc0x89ABCDEF;
169 w = (const uint16_t*) w + 16;
170
171 size_t k = kc;
172 do {
173 const __m256 va0 = _mm256_cvtph_ps(_mm_set1_epi16((short) *a0));
174 a0 += 1;
175 const __m256 va1 = _mm256_cvtph_ps(_mm_set1_epi16((short) *a1));
176 a1 += 1;
177 const __m256 va2 = _mm256_cvtph_ps(_mm_set1_epi16((short) *a2));
178 a2 += 1;
179 const __m256 va3 = _mm256_cvtph_ps(_mm_set1_epi16((short) *a3));
180 a3 += 1;
181
182 const __m256 vb01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
183 const __m256 vb89ABCDEF = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) ((const uint16_t*) w + 8)));
184 w = (const uint16_t*) w + 16;
185
186 vacc0x01234567 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va0, vb01234567, vacc0x01234567), _MM_FROUND_NO_EXC));
187 vacc1x01234567 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va1, vb01234567, vacc1x01234567), _MM_FROUND_NO_EXC));
188 vacc2x01234567 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va2, vb01234567, vacc2x01234567), _MM_FROUND_NO_EXC));
189 vacc3x01234567 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va3, vb01234567, vacc3x01234567), _MM_FROUND_NO_EXC));
190 vacc0x89ABCDEF = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va0, vb89ABCDEF, vacc0x89ABCDEF), _MM_FROUND_NO_EXC));
191 vacc1x89ABCDEF = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va1, vb89ABCDEF, vacc1x89ABCDEF), _MM_FROUND_NO_EXC));
192 vacc2x89ABCDEF = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va2, vb89ABCDEF, vacc2x89ABCDEF), _MM_FROUND_NO_EXC));
193 vacc3x89ABCDEF = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va3, vb89ABCDEF, vacc3x89ABCDEF), _MM_FROUND_NO_EXC));
194
195 k -= sizeof(uint16_t);
196 } while (k != 0);
197
198 const __m256 vmin = _mm256_load_ps(params->avx.min);
199 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
200 vacc1x01234567 = _mm256_max_ps(vacc1x01234567, vmin);
201 vacc2x01234567 = _mm256_max_ps(vacc2x01234567, vmin);
202 vacc3x01234567 = _mm256_max_ps(vacc3x01234567, vmin);
203 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
204 vacc1x89ABCDEF = _mm256_max_ps(vacc1x89ABCDEF, vmin);
205 vacc2x89ABCDEF = _mm256_max_ps(vacc2x89ABCDEF, vmin);
206 vacc3x89ABCDEF = _mm256_max_ps(vacc3x89ABCDEF, vmin);
207
208 const __m256 vmax = _mm256_load_ps(params->avx.max);
209 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
210 vacc1x01234567 = _mm256_min_ps(vacc1x01234567, vmax);
211 vacc2x01234567 = _mm256_min_ps(vacc2x01234567, vmax);
212 vacc3x01234567 = _mm256_min_ps(vacc3x01234567, vmax);
213 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
214 vacc1x89ABCDEF = _mm256_min_ps(vacc1x89ABCDEF, vmax);
215 vacc2x89ABCDEF = _mm256_min_ps(vacc2x89ABCDEF, vmax);
216 vacc3x89ABCDEF = _mm256_min_ps(vacc3x89ABCDEF, vmax);
217
218 if XNN_LIKELY(nc >= 16) {
219 _mm_storeu_si128((__m128i*) c0, _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC));
220 _mm_storeu_si128((__m128i*) (c0 + 8), _mm256_cvtps_ph(vacc0x89ABCDEF, _MM_FROUND_NO_EXC));
221 c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride);
222 _mm_storeu_si128((__m128i*) c1, _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC));
223 _mm_storeu_si128((__m128i*) (c1 + 8), _mm256_cvtps_ph(vacc1x89ABCDEF, _MM_FROUND_NO_EXC));
224 c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride);
225 _mm_storeu_si128((__m128i*) c2, _mm256_cvtps_ph(vacc2x01234567, _MM_FROUND_NO_EXC));
226 _mm_storeu_si128((__m128i*) (c2 + 8), _mm256_cvtps_ph(vacc2x89ABCDEF, _MM_FROUND_NO_EXC));
227 c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride);
228 _mm_storeu_si128((__m128i*) c3, _mm256_cvtps_ph(vacc3x01234567, _MM_FROUND_NO_EXC));
229 _mm_storeu_si128((__m128i*) (c3 + 8), _mm256_cvtps_ph(vacc3x89ABCDEF, _MM_FROUND_NO_EXC));
230 c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride);
231
232 a0 = (const uint16_t*) ((uintptr_t) a0 - kc);
233 a1 = (const uint16_t*) ((uintptr_t) a1 - kc);
234 a2 = (const uint16_t*) ((uintptr_t) a2 - kc);
235 a3 = (const uint16_t*) ((uintptr_t) a3 - kc);
236
237 nc -= 16;
238 } else {
239 __m128i vh0x01234567 = _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC);
240 __m128i vh1x01234567 = _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC);
241 __m128i vh2x01234567 = _mm256_cvtps_ph(vacc2x01234567, _MM_FROUND_NO_EXC);
242 __m128i vh3x01234567 = _mm256_cvtps_ph(vacc3x01234567, _MM_FROUND_NO_EXC);
243 if (nc & 8) {
244 _mm_storeu_si128((__m128i*) c0, vh0x01234567);
245 _mm_storeu_si128((__m128i*) c1, vh1x01234567);
246 _mm_storeu_si128((__m128i*) c2, vh2x01234567);
247 _mm_storeu_si128((__m128i*) c3, vh3x01234567);
248
249 vh0x01234567 = _mm256_cvtps_ph(vacc0x89ABCDEF, _MM_FROUND_NO_EXC);
250 vh1x01234567 = _mm256_cvtps_ph(vacc1x89ABCDEF, _MM_FROUND_NO_EXC);
251 vh2x01234567 = _mm256_cvtps_ph(vacc2x89ABCDEF, _MM_FROUND_NO_EXC);
252 vh3x01234567 = _mm256_cvtps_ph(vacc3x89ABCDEF, _MM_FROUND_NO_EXC);
253
254 c0 += 8;
255 c1 += 8;
256 c2 += 8;
257 c3 += 8;
258 }
259 if (nc & 4) {
260 _mm_storel_epi64((__m128i*) c0, vh0x01234567);
261 _mm_storel_epi64((__m128i*) c1, vh1x01234567);
262 _mm_storel_epi64((__m128i*) c2, vh2x01234567);
263 _mm_storel_epi64((__m128i*) c3, vh3x01234567);
264
265 vh0x01234567 = _mm_unpackhi_epi64(vh0x01234567, vh0x01234567);
266 vh1x01234567 = _mm_unpackhi_epi64(vh1x01234567, vh1x01234567);
267 vh2x01234567 = _mm_unpackhi_epi64(vh2x01234567, vh2x01234567);
268 vh3x01234567 = _mm_unpackhi_epi64(vh3x01234567, vh3x01234567);
269
270 c0 += 4;
271 c1 += 4;
272 c2 += 4;
273 c3 += 4;
274 }
275 if (nc & 2) {
276 _mm_storeu_si32(c0, vh0x01234567);
277 _mm_storeu_si32(c1, vh1x01234567);
278 _mm_storeu_si32(c2, vh2x01234567);
279 _mm_storeu_si32(c3, vh3x01234567);
280
281 vh0x01234567 = _mm_srli_epi64(vh0x01234567, 32);
282 vh1x01234567 = _mm_srli_epi64(vh1x01234567, 32);
283 vh2x01234567 = _mm_srli_epi64(vh2x01234567, 32);
284 vh3x01234567 = _mm_srli_epi64(vh3x01234567, 32);
285
286 c0 += 2;
287 c1 += 2;
288 c2 += 2;
289 c3 += 2;
290 }
291 if (nc & 1) {
292 *c0 = (uint16_t) _mm_extract_epi16(vh0x01234567, 0);
293 *c1 = (uint16_t) _mm_extract_epi16(vh1x01234567, 0);
294 *c2 = (uint16_t) _mm_extract_epi16(vh2x01234567, 0);
295 *c3 = (uint16_t) _mm_extract_epi16(vh3x01234567, 0);
296 }
297
298 nc = 0;
299 }
300 } while (nc != 0);
301 }
302
xnn_f16_igemm_minmax_ukernel_1x16__avx2_broadcast(size_t mr,size_t nc,size_t kc,size_t ks,const void ** restrict a,const void * restrict w,void * restrict c,size_t cm_stride,size_t cn_stride,size_t a_offset,const void * zero,const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])303 void xnn_f16_igemm_minmax_ukernel_1x16__avx2_broadcast(
304 size_t mr,
305 size_t nc,
306 size_t kc,
307 size_t ks,
308 const void**restrict a,
309 const void*restrict w,
310 void*restrict c,
311 size_t cm_stride,
312 size_t cn_stride,
313 size_t a_offset,
314 const void* zero,
315 const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
316 {
317 assert(mr != 0);
318 assert(mr <= 1);
319 assert(nc != 0);
320 assert(kc != 0);
321 assert(kc % sizeof(uint16_t) == 0);
322 assert(ks != 0);
323 assert(ks % (1 * sizeof(void*)) == 0);
324 assert(a_offset % sizeof(uint16_t) == 0);
325 assert(a != NULL);
326 assert(w != NULL);
327 assert(c != NULL);
328
329 uint16_t* c0 = c;
330
331 do {
332 __m256 vacc0x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
333 __m256 vacc0x89ABCDEF = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) ((const uint16_t*) w + 8)));
334 w = (const uint16_t*) w + 16;
335
336 size_t p = ks;
337 do {
338 const uint16_t* restrict a0 = (const uint16_t*) a[0];
339 assert(a0 != NULL);
340 if XNN_UNPREDICTABLE(a0 != zero) {
341 a0 = (const uint16_t*) ((uintptr_t) a0 + a_offset);
342 }
343 a += 1;
344
345 size_t k = kc;
346 do {
347 const __m256 vb01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
348 const __m256 vb89ABCDEF = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) ((const uint16_t*) w + 8)));
349 w = (const uint16_t*) w + 16;
350
351 const __m256 va0 = _mm256_cvtph_ps(_mm_set1_epi16((short) *a0));
352 a0 += 1;
353
354 vacc0x01234567 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va0, vb01234567, vacc0x01234567), _MM_FROUND_NO_EXC));
355 vacc0x89ABCDEF = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va0, vb89ABCDEF, vacc0x89ABCDEF), _MM_FROUND_NO_EXC));
356
357 k -= sizeof(uint16_t);
358 } while (k != 0);
359 p -= 1 * sizeof(void*);
360 } while (p != 0);
361
362 const __m256 vmin = _mm256_load_ps(params->avx.min);
363 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
364 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
365
366 const __m256 vmax = _mm256_load_ps(params->avx.max);
367 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
368 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
369
370 if XNN_LIKELY(nc >= 16) {
371 _mm_storeu_si128((__m128i*) c0, _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC));
372 _mm_storeu_si128((__m128i*) (c0 + 8), _mm256_cvtps_ph(vacc0x89ABCDEF, _MM_FROUND_NO_EXC));
373 c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride);
374
375 a = (const void**restrict) ((uintptr_t) a - ks);
376 nc -= 16;
377 } else {
378 __m128i vh0x01234567 = _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC);
379 if (nc & 8) {
380 _mm_storeu_si128((__m128i*) c0, vh0x01234567);
381
382 vh0x01234567 = _mm256_cvtps_ph(vacc0x89ABCDEF, _MM_FROUND_NO_EXC);
383
384 c0 += 8;
385 }
386 if (nc & 4) {
387 _mm_storel_epi64((__m128i*) c0, vh0x01234567);
388
389 vh0x01234567 = _mm_unpackhi_epi64(vh0x01234567, vh0x01234567);
390
391 c0 += 4;
392 }
393 if (nc & 2) {
394 _mm_storeu_si32(c0, vh0x01234567);
395
396 vh0x01234567 = _mm_srli_epi64(vh0x01234567, 32);
397
398 c0 += 2;
399 }
400 if (nc & 1) {
401 *c0 = _mm_extract_epi16(vh0x01234567, 0);
402 }
403
404 nc = 0;
405 }
406 } while (nc != 0);
407 }
408
xnn_f16_igemm_minmax_ukernel_4x16__avx2_broadcast(size_t mr,size_t nc,size_t kc,size_t ks,const void ** restrict a,const void * restrict w,void * restrict c,size_t cm_stride,size_t cn_stride,size_t a_offset,const void * zero,const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])409 void xnn_f16_igemm_minmax_ukernel_4x16__avx2_broadcast(
410 size_t mr,
411 size_t nc,
412 size_t kc,
413 size_t ks,
414 const void**restrict a,
415 const void*restrict w,
416 void*restrict c,
417 size_t cm_stride,
418 size_t cn_stride,
419 size_t a_offset,
420 const void* zero,
421 const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
422 {
423 assert(mr != 0);
424 assert(mr <= 4);
425 assert(nc != 0);
426 assert(kc != 0);
427 assert(kc % sizeof(uint16_t) == 0);
428 assert(ks != 0);
429 assert(ks % (4 * sizeof(void*)) == 0);
430 assert(a_offset % sizeof(uint16_t) == 0);
431 assert(a != NULL);
432 assert(w != NULL);
433 assert(c != NULL);
434
435 uint16_t* c0 = c;
436 uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride);
437 if XNN_UNPREDICTABLE(mr < 2) {
438 c1 = c0;
439 }
440 uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride);
441 if XNN_UNPREDICTABLE(mr <= 2) {
442 c2 = c1;
443 }
444 uint16_t* c3 = (uint16_t*) ((uintptr_t) c2 + cm_stride);
445 if XNN_UNPREDICTABLE(mr != 4) {
446 c3 = c2;
447 }
448
449 do {
450 __m256 vacc0x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
451 __m256 vacc0x89ABCDEF = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) ((const uint16_t*) w + 8)));
452 __m256 vacc1x01234567 = vacc0x01234567;
453 __m256 vacc1x89ABCDEF = vacc0x89ABCDEF;
454 __m256 vacc2x01234567 = vacc0x01234567;
455 __m256 vacc2x89ABCDEF = vacc0x89ABCDEF;
456 __m256 vacc3x01234567 = vacc0x01234567;
457 __m256 vacc3x89ABCDEF = vacc0x89ABCDEF;
458 w = (const uint16_t*) w + 16;
459
460 size_t p = ks;
461 do {
462 const uint16_t* restrict a0 = (const uint16_t*) a[0];
463 assert(a0 != NULL);
464 if XNN_UNPREDICTABLE(a0 != zero) {
465 a0 = (const uint16_t*) ((uintptr_t) a0 + a_offset);
466 }
467 const uint16_t* restrict a1 = (const uint16_t*) a[1];
468 assert(a1 != NULL);
469 if XNN_UNPREDICTABLE(a1 != zero) {
470 a1 = (const uint16_t*) ((uintptr_t) a1 + a_offset);
471 }
472 const uint16_t* restrict a2 = (const uint16_t*) a[2];
473 assert(a2 != NULL);
474 if XNN_UNPREDICTABLE(a2 != zero) {
475 a2 = (const uint16_t*) ((uintptr_t) a2 + a_offset);
476 }
477 const uint16_t* restrict a3 = (const uint16_t*) a[3];
478 assert(a3 != NULL);
479 if XNN_UNPREDICTABLE(a3 != zero) {
480 a3 = (const uint16_t*) ((uintptr_t) a3 + a_offset);
481 }
482 a += 4;
483
484 size_t k = kc;
485 do {
486 const __m256 vb01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
487 const __m256 vb89ABCDEF = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) ((const uint16_t*) w + 8)));
488 w = (const uint16_t*) w + 16;
489
490 const __m256 va0 = _mm256_cvtph_ps(_mm_set1_epi16((short) *a0));
491 a0 += 1;
492 const __m256 va1 = _mm256_cvtph_ps(_mm_set1_epi16((short) *a1));
493 a1 += 1;
494 const __m256 va2 = _mm256_cvtph_ps(_mm_set1_epi16((short) *a2));
495 a2 += 1;
496 const __m256 va3 = _mm256_cvtph_ps(_mm_set1_epi16((short) *a3));
497 a3 += 1;
498
499 vacc0x01234567 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va0, vb01234567, vacc0x01234567), _MM_FROUND_NO_EXC));
500 vacc0x89ABCDEF = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va0, vb89ABCDEF, vacc0x89ABCDEF), _MM_FROUND_NO_EXC));
501 vacc1x01234567 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va1, vb01234567, vacc1x01234567), _MM_FROUND_NO_EXC));
502 vacc1x89ABCDEF = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va1, vb89ABCDEF, vacc1x89ABCDEF), _MM_FROUND_NO_EXC));
503 vacc2x01234567 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va2, vb01234567, vacc2x01234567), _MM_FROUND_NO_EXC));
504 vacc2x89ABCDEF = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va2, vb89ABCDEF, vacc2x89ABCDEF), _MM_FROUND_NO_EXC));
505 vacc3x01234567 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va3, vb01234567, vacc3x01234567), _MM_FROUND_NO_EXC));
506 vacc3x89ABCDEF = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(va3, vb89ABCDEF, vacc3x89ABCDEF), _MM_FROUND_NO_EXC));
507
508 k -= sizeof(uint16_t);
509 } while (k != 0);
510 p -= 4 * sizeof(void*);
511 } while (p != 0);
512
513 const __m256 vmin = _mm256_load_ps(params->avx.min);
514 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
515 vacc1x01234567 = _mm256_max_ps(vacc1x01234567, vmin);
516 vacc2x01234567 = _mm256_max_ps(vacc2x01234567, vmin);
517 vacc3x01234567 = _mm256_max_ps(vacc3x01234567, vmin);
518 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
519 vacc1x89ABCDEF = _mm256_max_ps(vacc1x89ABCDEF, vmin);
520 vacc2x89ABCDEF = _mm256_max_ps(vacc2x89ABCDEF, vmin);
521 vacc3x89ABCDEF = _mm256_max_ps(vacc3x89ABCDEF, vmin);
522
523 const __m256 vmax = _mm256_load_ps(params->avx.max);
524 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
525 vacc1x01234567 = _mm256_min_ps(vacc1x01234567, vmax);
526 vacc2x01234567 = _mm256_min_ps(vacc2x01234567, vmax);
527 vacc3x01234567 = _mm256_min_ps(vacc3x01234567, vmax);
528 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
529 vacc1x89ABCDEF = _mm256_min_ps(vacc1x89ABCDEF, vmax);
530 vacc2x89ABCDEF = _mm256_min_ps(vacc2x89ABCDEF, vmax);
531 vacc3x89ABCDEF = _mm256_min_ps(vacc3x89ABCDEF, vmax);
532
533 if XNN_LIKELY(nc >= 16) {
534 _mm_storeu_si128((__m128i*) c3, _mm256_cvtps_ph(vacc3x01234567, _MM_FROUND_NO_EXC));
535 _mm_storeu_si128((__m128i*) (c3 + 8), _mm256_cvtps_ph(vacc3x89ABCDEF, _MM_FROUND_NO_EXC));
536 c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride);
537 _mm_storeu_si128((__m128i*) c2, _mm256_cvtps_ph(vacc2x01234567, _MM_FROUND_NO_EXC));
538 _mm_storeu_si128((__m128i*) (c2 + 8), _mm256_cvtps_ph(vacc2x89ABCDEF, _MM_FROUND_NO_EXC));
539 c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride);
540 _mm_storeu_si128((__m128i*) c1, _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC));
541 _mm_storeu_si128((__m128i*) (c1 + 8), _mm256_cvtps_ph(vacc1x89ABCDEF, _MM_FROUND_NO_EXC));
542 c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride);
543 _mm_storeu_si128((__m128i*) c0, _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC));
544 _mm_storeu_si128((__m128i*) (c0 + 8), _mm256_cvtps_ph(vacc0x89ABCDEF, _MM_FROUND_NO_EXC));
545 c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride);
546
547 a = (const void**restrict) ((uintptr_t) a - ks);
548 nc -= 16;
549 } else {
550 __m128i vh3x01234567 = _mm256_cvtps_ph(vacc3x01234567, _MM_FROUND_NO_EXC);
551 __m128i vh2x01234567 = _mm256_cvtps_ph(vacc2x01234567, _MM_FROUND_NO_EXC);
552 __m128i vh1x01234567 = _mm256_cvtps_ph(vacc1x01234567, _MM_FROUND_NO_EXC);
553 __m128i vh0x01234567 = _mm256_cvtps_ph(vacc0x01234567, _MM_FROUND_NO_EXC);
554 if (nc & 8) {
555 _mm_storeu_si128((__m128i*) c3, vh3x01234567);
556 _mm_storeu_si128((__m128i*) c2, vh2x01234567);
557 _mm_storeu_si128((__m128i*) c1, vh1x01234567);
558 _mm_storeu_si128((__m128i*) c0, vh0x01234567);
559
560 vh3x01234567 = _mm256_cvtps_ph(vacc3x89ABCDEF, _MM_FROUND_NO_EXC);
561 vh2x01234567 = _mm256_cvtps_ph(vacc2x89ABCDEF, _MM_FROUND_NO_EXC);
562 vh1x01234567 = _mm256_cvtps_ph(vacc1x89ABCDEF, _MM_FROUND_NO_EXC);
563 vh0x01234567 = _mm256_cvtps_ph(vacc0x89ABCDEF, _MM_FROUND_NO_EXC);
564
565 c3 += 8;
566 c2 += 8;
567 c1 += 8;
568 c0 += 8;
569 }
570 if (nc & 4) {
571 _mm_storel_epi64((__m128i*) c3, vh3x01234567);
572 _mm_storel_epi64((__m128i*) c2, vh2x01234567);
573 _mm_storel_epi64((__m128i*) c1, vh1x01234567);
574 _mm_storel_epi64((__m128i*) c0, vh0x01234567);
575
576 vh3x01234567 = _mm_unpackhi_epi64(vh3x01234567, vh3x01234567);
577 vh2x01234567 = _mm_unpackhi_epi64(vh2x01234567, vh2x01234567);
578 vh1x01234567 = _mm_unpackhi_epi64(vh1x01234567, vh1x01234567);
579 vh0x01234567 = _mm_unpackhi_epi64(vh0x01234567, vh0x01234567);
580
581 c3 += 4;
582 c2 += 4;
583 c1 += 4;
584 c0 += 4;
585 }
586 if (nc & 2) {
587 _mm_storeu_si32(c3, vh3x01234567);
588 _mm_storeu_si32(c2, vh2x01234567);
589 _mm_storeu_si32(c1, vh1x01234567);
590 _mm_storeu_si32(c0, vh0x01234567);
591
592 vh3x01234567 = _mm_srli_epi64(vh3x01234567, 32);
593 vh2x01234567 = _mm_srli_epi64(vh2x01234567, 32);
594 vh1x01234567 = _mm_srli_epi64(vh1x01234567, 32);
595 vh0x01234567 = _mm_srli_epi64(vh0x01234567, 32);
596
597 c3 += 2;
598 c2 += 2;
599 c1 += 2;
600 c0 += 2;
601 }
602 if (nc & 1) {
603 *c3 = _mm_extract_epi16(vh3x01234567, 0);
604 *c2 = _mm_extract_epi16(vh2x01234567, 0);
605 *c1 = _mm_extract_epi16(vh1x01234567, 0);
606 *c0 = _mm_extract_epi16(vh0x01234567, 0);
607 }
608
609 nc = 0;
610 }
611 } while (nc != 0);
612 }
613
xnn_f16_pavgpool_minmax_ukernel_9p8x__avx2_c8(size_t output_pixels,size_t kernel_elements,size_t channels,const void ** input,size_t input_offset,const void * zero,const void * multiplier,void * buffer,void * output,size_t input_increment,size_t output_increment,const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])614 void xnn_f16_pavgpool_minmax_ukernel_9p8x__avx2_c8(
615 size_t output_pixels,
616 size_t kernel_elements,
617 size_t channels,
618 const void** input,
619 size_t input_offset,
620 const void* zero,
621 const void* multiplier,
622 void* buffer,
623 void* output,
624 size_t input_increment,
625 size_t output_increment,
626 const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
627 {
628 assert(output_pixels != 0);
629 assert(kernel_elements > 9);
630 assert(channels != 0);
631
632 const __m256 voutput_min = _mm256_load_ps(params->avx.min);
633 const __m256 voutput_max = _mm256_load_ps(params->avx.max);
634
635 uint16_t* o = (uint16_t*) output;
636 do {
637 {
638 const uint16_t* i0 = (const uint16_t*) *input++;
639 assert(i0 != NULL);
640 if XNN_UNPREDICTABLE(i0 != zero) {
641 i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
642 }
643 const uint16_t* i1 = (const uint16_t*) *input++;
644 assert(i1 != NULL);
645 if XNN_UNPREDICTABLE(i1 != zero) {
646 i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
647 }
648 const uint16_t* i2 = (const uint16_t*) *input++;
649 assert(i2 != NULL);
650 if XNN_UNPREDICTABLE(i2 != zero) {
651 i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
652 }
653 const uint16_t* i3 = (const uint16_t*) *input++;
654 assert(i3 != NULL);
655 if XNN_UNPREDICTABLE(i3 != zero) {
656 i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
657 }
658 const uint16_t* i4 = (const uint16_t*) *input++;
659 assert(i4 != NULL);
660 if XNN_UNPREDICTABLE(i4 != zero) {
661 i4 = (const uint16_t*) ((uintptr_t) i4 + input_offset);
662 }
663 const uint16_t* i5 = (const uint16_t*) *input++;
664 assert(i5 != NULL);
665 if XNN_UNPREDICTABLE(i5 != zero) {
666 i5 = (const uint16_t*) ((uintptr_t) i5 + input_offset);
667 }
668 const uint16_t* i6 = (const uint16_t*) *input++;
669 assert(i6 != NULL);
670 if XNN_UNPREDICTABLE(i6 != zero) {
671 i6 = (const uint16_t*) ((uintptr_t) i6 + input_offset);
672 }
673 const uint16_t* i7 = (const uint16_t*) *input++;
674 assert(i7 != NULL);
675 if XNN_UNPREDICTABLE(i7 != zero) {
676 i7 = (const uint16_t*) ((uintptr_t) i7 + input_offset);
677 }
678 const uint16_t* i8 = (const uint16_t*) *input++;
679 assert(i8 != NULL);
680 if XNN_UNPREDICTABLE(i8 != zero) {
681 i8 = (const uint16_t*) ((uintptr_t) i8 + input_offset);
682 }
683
684 uint16_t* b = (uint16_t*) buffer;
685 for (size_t c = 0; c < channels; c += 8) {
686 const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
687 i0 += 8;
688 const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
689 i1 += 8;
690 const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
691 i2 += 8;
692 const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
693 i3 += 8;
694 const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
695 i4 += 8;
696 const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
697 i5 += 8;
698 const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
699 i6 += 8;
700 const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
701 i7 += 8;
702 const __m256 vi8 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
703 i8 += 8;
704
705 const __m256 vsum01 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi0, vi1), _MM_FROUND_NO_EXC));
706 const __m256 vsum23 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi2, vi3), _MM_FROUND_NO_EXC));
707 const __m256 vsum45 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi4, vi5), _MM_FROUND_NO_EXC));
708 const __m256 vsum67 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi6, vi7), _MM_FROUND_NO_EXC));
709 const __m256 vsum018 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01, vi8), _MM_FROUND_NO_EXC));
710 const __m256 vsum2345 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum23, vsum45), _MM_FROUND_NO_EXC));
711 const __m256 vsum01678 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum018, vsum67), _MM_FROUND_NO_EXC));
712 const __m256 vsum = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum2345, vsum01678), _MM_FROUND_NO_EXC));
713
714 _mm_storeu_si128((__m128i*) b, _mm256_cvtps_ph(vsum, _MM_FROUND_NO_EXC));
715 b += 8;
716 }
717 }
718
719 size_t k = kernel_elements;
720 for (k -= 9; k > 8; k -= 8) {
721 const uint16_t* i0 = (const uint16_t*) *input++;
722 assert(i0 != NULL);
723 if XNN_UNPREDICTABLE(i0 != zero) {
724 i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
725 }
726 const uint16_t* i1 = (const uint16_t*) *input++;
727 assert(i1 != NULL);
728 if XNN_UNPREDICTABLE(i1 != zero) {
729 i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
730 }
731 const uint16_t* i2 = (const uint16_t*) *input++;
732 assert(i2 != NULL);
733 if XNN_UNPREDICTABLE(i2 != zero) {
734 i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
735 }
736 const uint16_t* i3 = (const uint16_t*) *input++;
737 assert(i3 != NULL);
738 if XNN_UNPREDICTABLE(i3 != zero) {
739 i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
740 }
741 const uint16_t* i4 = (const uint16_t*) *input++;
742 assert(i4 != NULL);
743 if XNN_UNPREDICTABLE(i4 != zero) {
744 i4 = (const uint16_t*) ((uintptr_t) i4 + input_offset);
745 }
746 const uint16_t* i5 = (const uint16_t*) *input++;
747 assert(i5 != NULL);
748 if XNN_UNPREDICTABLE(i5 != zero) {
749 i5 = (const uint16_t*) ((uintptr_t) i5 + input_offset);
750 }
751 const uint16_t* i6 = (const uint16_t*) *input++;
752 assert(i6 != NULL);
753 if XNN_UNPREDICTABLE(i6 != zero) {
754 i6 = (const uint16_t*) ((uintptr_t) i6 + input_offset);
755 }
756 const uint16_t* i7 = (const uint16_t*) *input++;
757 assert(i7 != NULL);
758 if XNN_UNPREDICTABLE(i7 != zero) {
759 i7 = (const uint16_t*) ((uintptr_t) i7 + input_offset);
760 }
761
762 uint16_t* b = (uint16_t*) buffer;
763 for (size_t c = 0; c < channels; c += 8) {
764 const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
765 i0 += 8;
766 const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
767 i1 += 8;
768 const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
769 i2 += 8;
770 const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
771 i3 += 8;
772 const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
773 i4 += 8;
774 const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
775 i5 += 8;
776 const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
777 i6 += 8;
778 const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
779 i7 += 8;
780 const __m256 vacc = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) b));
781
782 const __m256 vsum01 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi0, vi1), _MM_FROUND_NO_EXC));
783 const __m256 vsum23 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi2, vi3), _MM_FROUND_NO_EXC));
784 const __m256 vsum45 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi4, vi5), _MM_FROUND_NO_EXC));
785 const __m256 vsum67 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi6, vi7), _MM_FROUND_NO_EXC));
786 const __m256 vsum01a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01, vacc), _MM_FROUND_NO_EXC));
787 const __m256 vsum2345 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum23, vsum45), _MM_FROUND_NO_EXC));
788 const __m256 vsum0167a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01a, vsum67), _MM_FROUND_NO_EXC));
789 const __m256 vsum = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum2345, vsum0167a), _MM_FROUND_NO_EXC));
790
791 _mm_storeu_si128((__m128i*) b, _mm256_cvtps_ph(vsum, _MM_FROUND_NO_EXC));
792 b += 8;
793 }
794 }
795
796 {
797 const uint16_t* i0 = (const uint16_t*) input[0];
798 assert(i0 != NULL);
799 const uint16_t* i1 = (const uint16_t*) input[1];
800 const uint16_t* i2 = (const uint16_t*) input[2];
801 const uint16_t* i3 = (const uint16_t*) input[3];
802 const uint16_t* i4 = (const uint16_t*) input[4];
803 const uint16_t* i5 = (const uint16_t*) input[5];
804 const uint16_t* i6 = (const uint16_t*) input[6];
805 const uint16_t* i7 = (const uint16_t*) input[7];
806 input = (const void**) ((uintptr_t) input + input_increment);
807 if (k < 2) {
808 i1 = (const uint16_t*) zero;
809 }
810 assert(i1 != NULL);
811 if (k <= 2) {
812 i2 = (const uint16_t*) zero;
813 }
814 assert(i2 != NULL);
815 if (k < 4) {
816 i3 = (const uint16_t*) zero;
817 }
818 assert(i3 != NULL);
819 if (k <= 4) {
820 i4 = (const uint16_t*) zero;
821 }
822 assert(i4 != NULL);
823 if (k < 6) {
824 i5 = (const uint16_t*) zero;
825 }
826 assert(i5 != NULL);
827 if (k <= 6) {
828 i6 = (const uint16_t*) zero;
829 }
830 assert(i6 != NULL);
831 if (k < 8) {
832 i7 = (const uint16_t*) zero;
833 }
834 assert(i7 != NULL);
835 if XNN_UNPREDICTABLE(i0 != zero) {
836 i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
837 }
838 if XNN_UNPREDICTABLE(i1 != zero) {
839 i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
840 }
841 if XNN_UNPREDICTABLE(i2 != zero) {
842 i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
843 }
844 if XNN_UNPREDICTABLE(i3 != zero) {
845 i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
846 }
847 if XNN_UNPREDICTABLE(i4 != zero) {
848 i4 = (const uint16_t*) ((uintptr_t) i4 + input_offset);
849 }
850 if XNN_UNPREDICTABLE(i5 != zero) {
851 i5 = (const uint16_t*) ((uintptr_t) i5 + input_offset);
852 }
853 if XNN_UNPREDICTABLE(i6 != zero) {
854 i6 = (const uint16_t*) ((uintptr_t) i6 + input_offset);
855 }
856 if XNN_UNPREDICTABLE(i7 != zero) {
857 i7 = (const uint16_t*) ((uintptr_t) i7 + input_offset);
858 }
859
860 const __m256 vmultiplier = _mm256_cvtph_ps(_mm_set1_epi16((short) *((const uint16_t*) multiplier)));
861 multiplier = (const uint16_t*) multiplier + 1;
862
863 size_t c = channels;
864 const uint16_t* b = (const uint16_t*) buffer;
865 while (c >= 8) {
866 const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
867 i0 += 8;
868 const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
869 i1 += 8;
870 const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
871 i2 += 8;
872 const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
873 i3 += 8;
874 const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
875 i4 += 8;
876 const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
877 i5 += 8;
878 const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
879 i6 += 8;
880 const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
881 i7 += 8;
882 const __m256 vacc = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) b));
883 b += 8;
884
885 const __m256 vsum01 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi0, vi1), _MM_FROUND_NO_EXC));
886 const __m256 vsum23 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi2, vi3), _MM_FROUND_NO_EXC));
887 const __m256 vsum45 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi4, vi5), _MM_FROUND_NO_EXC));
888 const __m256 vsum67 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi6, vi7), _MM_FROUND_NO_EXC));
889 const __m256 vsum01a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01, vacc), _MM_FROUND_NO_EXC));
890 const __m256 vsum2345 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum23, vsum45), _MM_FROUND_NO_EXC));
891 const __m256 vsum0167a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01a, vsum67), _MM_FROUND_NO_EXC));
892 const __m256 vsum = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum2345, vsum0167a), _MM_FROUND_NO_EXC));
893
894 __m256 vout = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vsum, vmultiplier), _MM_FROUND_NO_EXC));
895 vout = _mm256_max_ps(vout, voutput_min);
896 vout = _mm256_min_ps(vout, voutput_max);
897
898 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vout, _MM_FROUND_NO_EXC));
899 o += 8;
900
901 c -= 8;
902 }
903 if (c != 0) {
904 const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
905 const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
906 const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
907 const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
908 const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
909 const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
910 const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
911 const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
912 const __m256 vacc = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) b));
913
914 const __m256 vsum01 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi0, vi1), _MM_FROUND_NO_EXC));
915 const __m256 vsum23 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi2, vi3), _MM_FROUND_NO_EXC));
916 const __m256 vsum45 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi4, vi5), _MM_FROUND_NO_EXC));
917 const __m256 vsum67 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi6, vi7), _MM_FROUND_NO_EXC));
918 const __m256 vsum01a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01, vacc), _MM_FROUND_NO_EXC));
919 const __m256 vsum2345 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum23, vsum45), _MM_FROUND_NO_EXC));
920 const __m256 vsum0167a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01a, vsum67), _MM_FROUND_NO_EXC));
921 const __m256 vsum = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum2345, vsum0167a), _MM_FROUND_NO_EXC));
922
923 __m256 vout = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vsum, vmultiplier), _MM_FROUND_NO_EXC));
924 vout = _mm256_max_ps(vout, voutput_min);
925 vout = _mm256_min_ps(vout, voutput_max);
926
927 __m128i vh = _mm256_cvtps_ph(vout, _MM_FROUND_NO_EXC);
928 if (c & 4) {
929 _mm_storel_epi64((__m128i*) o, vh);
930 vh = _mm_unpackhi_epi64(vh, vh);
931 o += 4;
932 }
933 if (c & 2) {
934 _mm_storeu_si32(o, vh);
935 vh = _mm_srli_epi64(vh, 32);
936 o += 2;
937 }
938 if (c & 1) {
939 *o = (uint16_t) _mm_extract_epi16(vh, 0);
940 o += 1;
941 }
942 }
943 }
944 o = (uint16_t*) ((uintptr_t) o + output_increment);
945 } while (--output_pixels != 0);
946 }
947
xnn_f16_pavgpool_minmax_ukernel_9x__avx2_c8(size_t output_pixels,size_t kernel_elements,size_t channels,const void ** input,size_t input_offset,const void * zero,const void * multiplier,void * output,size_t input_increment,size_t output_increment,const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])948 void xnn_f16_pavgpool_minmax_ukernel_9x__avx2_c8(
949 size_t output_pixels,
950 size_t kernel_elements,
951 size_t channels,
952 const void** input,
953 size_t input_offset,
954 const void* zero,
955 const void* multiplier,
956 void* output,
957 size_t input_increment,
958 size_t output_increment,
959 const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
960 {
961 assert(output_pixels != 0);
962 assert(kernel_elements != 0);
963 assert(kernel_elements <= 9);
964 assert(channels != 0);
965
966 const __m256 voutput_min = _mm256_load_ps(params->avx.min);
967 const __m256 voutput_max = _mm256_load_ps(params->avx.max);
968
969 uint16_t* o = (uint16_t*) output;
970 do {
971 const uint16_t* i0 = (const uint16_t*) input[0];
972 assert(i0 != NULL);
973 const uint16_t* i1 = (const uint16_t*) input[1];
974 const uint16_t* i2 = (const uint16_t*) input[2];
975 const uint16_t* i3 = (const uint16_t*) input[3];
976 const uint16_t* i4 = (const uint16_t*) input[4];
977 const uint16_t* i5 = (const uint16_t*) input[5];
978 const uint16_t* i6 = (const uint16_t*) input[6];
979 const uint16_t* i7 = (const uint16_t*) input[7];
980 const uint16_t* i8 = (const uint16_t*) input[8];
981 input = (const void**) ((uintptr_t) input + input_increment);
982 if (kernel_elements < 2) {
983 i1 = (const uint16_t*) zero;
984 }
985 assert(i1 != NULL);
986 if (kernel_elements <= 2) {
987 i2 = (const uint16_t*) zero;
988 }
989 assert(i2 != NULL);
990 if (kernel_elements < 4) {
991 i3 = (const uint16_t*) zero;
992 }
993 assert(i3 != NULL);
994 if (kernel_elements <= 4) {
995 i4 = (const uint16_t*) zero;
996 }
997 assert(i4 != NULL);
998 if (kernel_elements < 6) {
999 i5 = (const uint16_t*) zero;
1000 }
1001 assert(i5 != NULL);
1002 if (kernel_elements <= 6) {
1003 i6 = (const uint16_t*) zero;
1004 }
1005 assert(i6 != NULL);
1006 if (kernel_elements < 8) {
1007 i7 = (const uint16_t*) zero;
1008 }
1009 assert(i7 != NULL);
1010 if (kernel_elements <= 8) {
1011 i8 = (const uint16_t*) zero;
1012 }
1013 assert(i8 != NULL);
1014 if XNN_UNPREDICTABLE(i0 != zero) {
1015 i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
1016 }
1017 if XNN_UNPREDICTABLE(i1 != zero) {
1018 i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
1019 }
1020 if XNN_UNPREDICTABLE(i2 != zero) {
1021 i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
1022 }
1023 if XNN_UNPREDICTABLE(i3 != zero) {
1024 i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
1025 }
1026 if XNN_UNPREDICTABLE(i4 != zero) {
1027 i4 = (const uint16_t*) ((uintptr_t) i4 + input_offset);
1028 }
1029 if XNN_UNPREDICTABLE(i5 != zero) {
1030 i5 = (const uint16_t*) ((uintptr_t) i5 + input_offset);
1031 }
1032 if XNN_UNPREDICTABLE(i6 != zero) {
1033 i6 = (const uint16_t*) ((uintptr_t) i6 + input_offset);
1034 }
1035 if XNN_UNPREDICTABLE(i7 != zero) {
1036 i7 = (const uint16_t*) ((uintptr_t) i7 + input_offset);
1037 }
1038 if XNN_UNPREDICTABLE(i8 != zero) {
1039 i8 = (const uint16_t*) ((uintptr_t) i8 + input_offset);
1040 }
1041
1042 const __m256 vmultiplier = _mm256_cvtph_ps(_mm_set1_epi16((short) *((const uint16_t*) multiplier)));
1043 multiplier = (const uint16_t*) multiplier + 1;
1044
1045 size_t c = channels;
1046 while (c >= 8) {
1047 const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
1048 i0 += 8;
1049 const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
1050 i1 += 8;
1051 const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
1052 i2 += 8;
1053 const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
1054 i3 += 8;
1055 const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
1056 i4 += 8;
1057 const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
1058 i5 += 8;
1059 const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
1060 i6 += 8;
1061 const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
1062 i7 += 8;
1063 const __m256 vi8 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
1064 i8 += 8;
1065
1066 const __m256 vsum01 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi0, vi1), _MM_FROUND_NO_EXC));
1067 const __m256 vsum23 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi2, vi3), _MM_FROUND_NO_EXC));
1068 const __m256 vsum45 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi4, vi5), _MM_FROUND_NO_EXC));
1069 const __m256 vsum67 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi6, vi7), _MM_FROUND_NO_EXC));
1070 const __m256 vsum018 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01, vi8), _MM_FROUND_NO_EXC));
1071 const __m256 vsum2345 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum23, vsum45), _MM_FROUND_NO_EXC));
1072 const __m256 vsum01678 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum018, vsum67), _MM_FROUND_NO_EXC));
1073 const __m256 vsum = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum2345, vsum01678), _MM_FROUND_NO_EXC));
1074
1075 __m256 vout = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vsum, vmultiplier), _MM_FROUND_NO_EXC));
1076 vout = _mm256_max_ps(vout, voutput_min);
1077 vout = _mm256_min_ps(vout, voutput_max);
1078
1079 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vout, _MM_FROUND_NO_EXC));
1080 o += 8;
1081
1082 c -= 8;
1083 }
1084 if (c != 0) {
1085 const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
1086 const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
1087 const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
1088 const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
1089 const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
1090 const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
1091 const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
1092 const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
1093 const __m256 vi8 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
1094
1095 const __m256 vsum01 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi0, vi1), _MM_FROUND_NO_EXC));
1096 const __m256 vsum23 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi2, vi3), _MM_FROUND_NO_EXC));
1097 const __m256 vsum45 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi4, vi5), _MM_FROUND_NO_EXC));
1098 const __m256 vsum67 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi6, vi7), _MM_FROUND_NO_EXC));
1099 const __m256 vsum018 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01, vi8), _MM_FROUND_NO_EXC));
1100 const __m256 vsum2345 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum23, vsum45), _MM_FROUND_NO_EXC));
1101 const __m256 vsum01678 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum018, vsum67), _MM_FROUND_NO_EXC));
1102 const __m256 vsum = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum2345, vsum01678), _MM_FROUND_NO_EXC));
1103
1104 __m256 vout = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vsum, vmultiplier), _MM_FROUND_NO_EXC));
1105 vout = _mm256_max_ps(vout, voutput_min);
1106 vout = _mm256_min_ps(vout, voutput_max);
1107
1108 __m128i vh = _mm256_cvtps_ph(vout, _MM_FROUND_NO_EXC);
1109 if (c & 4) {
1110 _mm_storel_epi64((__m128i*) o, vh);
1111 vh = _mm_unpackhi_epi64(vh, vh);
1112 o += 4;
1113 }
1114 if (c & 2) {
1115 _mm_storeu_si32(o, vh);
1116 vh = _mm_srli_epi64(vh, 32);
1117 o += 2;
1118 }
1119 if (c & 1) {
1120 *o = (uint16_t) _mm_extract_epi16(vh, 0);
1121 o += 1;
1122 }
1123 }
1124 o = (uint16_t*) ((uintptr_t) o + output_increment);
1125 } while (--output_pixels != 0);
1126 }
1127
xnn_f16_raddstoreexpminusmax_ukernel__avx2_rr1_p2_x40(size_t batch,const void * input,const void * max,void * output,void * sum,const union xnn_f16_expminus_params params[restrict XNN_MIN_ELEMENTS (1)])1128 void xnn_f16_raddstoreexpminusmax_ukernel__avx2_rr1_p2_x40(
1129 size_t batch,
1130 const void* input,
1131 const void* max,
1132 void* output,
1133 void* sum,
1134 const union xnn_f16_expminus_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
1135 {
1136 assert(batch % sizeof(uint16_t) == 0);
1137
1138 const __m256 vi_max = _mm256_cvtph_ps(_mm_set1_epi16((short) *((const uint16_t*) max)));
1139 const __m256 vlog2e = _mm256_load_ps(params->avx2_rr1_p2.log2e);
1140 const __m256 vmagic_bias = _mm256_load_ps(params->avx2_rr1_p2.magic_bias);
1141 const __m256 vminus_ln2 = _mm256_load_ps(params->avx2_rr1_p2.minus_ln2);
1142 const __m256 vc2 = _mm256_load_ps(params->avx2_rr1_p2.c2);
1143 const __m256 vc1 = _mm256_load_ps(params->avx2_rr1_p2.c1);
1144 const __m256 vdenorm_cutoff = _mm256_load_ps(params->avx2_rr1_p2.denorm_cutoff);
1145
1146 const uint16_t* i = (const uint16_t*) input;
1147 uint16_t* o = (uint16_t*) output;
1148 __m256 vacc0 = _mm256_setzero_ps();
1149 for (; batch >= 40 * sizeof(uint16_t); batch -= 40 * sizeof(uint16_t)) {
1150 const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
1151 const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 8)));
1152 const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 16)));
1153 const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 24)));
1154 const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 32)));
1155 i += 40;
1156
1157 const __m256 vx0 = _mm256_sub_ps(vi0, vi_max);
1158 const __m256 vx1 = _mm256_sub_ps(vi1, vi_max);
1159 const __m256 vx2 = _mm256_sub_ps(vi2, vi_max);
1160 const __m256 vx3 = _mm256_sub_ps(vi3, vi_max);
1161 const __m256 vx4 = _mm256_sub_ps(vi4, vi_max);
1162
1163 __m256 vn0 = _mm256_fmadd_ps(vx0, vlog2e, vmagic_bias);
1164 __m256 vn1 = _mm256_fmadd_ps(vx1, vlog2e, vmagic_bias);
1165 __m256 vn2 = _mm256_fmadd_ps(vx2, vlog2e, vmagic_bias);
1166 __m256 vn3 = _mm256_fmadd_ps(vx3, vlog2e, vmagic_bias);
1167 __m256 vn4 = _mm256_fmadd_ps(vx4, vlog2e, vmagic_bias);
1168
1169 const __m256 vs0 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn0), 23));
1170 const __m256 vs1 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn1), 23));
1171 const __m256 vs2 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn2), 23));
1172 const __m256 vs3 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn3), 23));
1173 const __m256 vs4 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn4), 23));
1174
1175 vn0 = _mm256_sub_ps(vn0, vmagic_bias);
1176 vn1 = _mm256_sub_ps(vn1, vmagic_bias);
1177 vn2 = _mm256_sub_ps(vn2, vmagic_bias);
1178 vn3 = _mm256_sub_ps(vn3, vmagic_bias);
1179 vn4 = _mm256_sub_ps(vn4, vmagic_bias);
1180
1181 __m256 vt0 = _mm256_fmadd_ps(vn0, vminus_ln2, vx0);
1182 __m256 vt1 = _mm256_fmadd_ps(vn1, vminus_ln2, vx1);
1183 __m256 vt2 = _mm256_fmadd_ps(vn2, vminus_ln2, vx2);
1184 __m256 vt3 = _mm256_fmadd_ps(vn3, vminus_ln2, vx3);
1185 __m256 vt4 = _mm256_fmadd_ps(vn4, vminus_ln2, vx4);
1186
1187 const __m256 vp0 = _mm256_fmadd_ps(vc2, vt0, vc1);
1188 const __m256 vp1 = _mm256_fmadd_ps(vc2, vt1, vc1);
1189 const __m256 vp2 = _mm256_fmadd_ps(vc2, vt2, vc1);
1190 const __m256 vp3 = _mm256_fmadd_ps(vc2, vt3, vc1);
1191 const __m256 vp4 = _mm256_fmadd_ps(vc2, vt4, vc1);
1192
1193 vt0 = _mm256_mul_ps(vt0, vs0);
1194 vt1 = _mm256_mul_ps(vt1, vs1);
1195 vt2 = _mm256_mul_ps(vt2, vs2);
1196 vt3 = _mm256_mul_ps(vt3, vs3);
1197 vt4 = _mm256_mul_ps(vt4, vs4);
1198
1199 __m256 vf0 = _mm256_fmadd_ps(vt0, vp0, vs0);
1200 __m256 vf1 = _mm256_fmadd_ps(vt1, vp1, vs1);
1201 __m256 vf2 = _mm256_fmadd_ps(vt2, vp2, vs2);
1202 __m256 vf3 = _mm256_fmadd_ps(vt3, vp3, vs3);
1203 __m256 vf4 = _mm256_fmadd_ps(vt4, vp4, vs4);
1204
1205 vf0 = _mm256_andnot_ps(_mm256_cmp_ps(vx0, vdenorm_cutoff, _CMP_LT_OS), vf0);
1206 vf1 = _mm256_andnot_ps(_mm256_cmp_ps(vx1, vdenorm_cutoff, _CMP_LT_OS), vf1);
1207 vf2 = _mm256_andnot_ps(_mm256_cmp_ps(vx2, vdenorm_cutoff, _CMP_LT_OS), vf2);
1208 vf3 = _mm256_andnot_ps(_mm256_cmp_ps(vx3, vdenorm_cutoff, _CMP_LT_OS), vf3);
1209 vf4 = _mm256_andnot_ps(_mm256_cmp_ps(vx4, vdenorm_cutoff, _CMP_LT_OS), vf4);
1210
1211 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vf0, _MM_FROUND_NO_EXC));
1212 _mm_storeu_si128((__m128i*) (o + 8), _mm256_cvtps_ph(vf1, _MM_FROUND_NO_EXC));
1213 _mm_storeu_si128((__m128i*) (o + 16), _mm256_cvtps_ph(vf2, _MM_FROUND_NO_EXC));
1214 _mm_storeu_si128((__m128i*) (o + 24), _mm256_cvtps_ph(vf3, _MM_FROUND_NO_EXC));
1215 _mm_storeu_si128((__m128i*) (o + 32), _mm256_cvtps_ph(vf4, _MM_FROUND_NO_EXC));
1216 o += 40;
1217
1218 vacc0 = _mm256_add_ps(vacc0, vf0);
1219 vacc0 = _mm256_add_ps(vacc0, vf1);
1220 vacc0 = _mm256_add_ps(vacc0, vf2);
1221 vacc0 = _mm256_add_ps(vacc0, vf3);
1222 vacc0 = _mm256_add_ps(vacc0, vf4);
1223 }
1224
1225 __m256 vacc = vacc0;
1226 for (; batch >= 8 * sizeof(uint16_t); batch -= 8 * sizeof(uint16_t)) {
1227 const __m256 vi = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
1228 i += 8;
1229
1230 const __m256 vx = _mm256_sub_ps(vi, vi_max);
1231
1232 __m256 vn = _mm256_fmadd_ps(vx, vlog2e, vmagic_bias);
1233
1234 const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
1235
1236 vn = _mm256_sub_ps(vn, vmagic_bias);
1237
1238 __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vx);
1239
1240 const __m256 vp = _mm256_fmadd_ps(vc2, vt, vc1);
1241 vt = _mm256_mul_ps(vt, vs);
1242 __m256 vf = _mm256_fmadd_ps(vt, vp, vs);
1243 vf = _mm256_andnot_ps(_mm256_cmp_ps(vx, vdenorm_cutoff, _CMP_LT_OS), vf);
1244
1245 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vf, _MM_FROUND_NO_EXC));
1246 o += 8;
1247
1248 vacc = _mm256_add_ps(vacc, vf);
1249 }
1250 __m128 vacc_lo = _mm_add_ps(_mm256_castps256_ps128(vacc), _mm256_extractf128_ps(vacc, 1));
1251 if (batch != 0) {
1252 assert(batch >= 1 * sizeof(uint16_t));
1253 assert(batch <= 7 * sizeof(uint16_t));
1254
1255 const __m256 vi = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
1256
1257 const __m256 vx = _mm256_sub_ps(vi, vi_max);
1258
1259 __m256 vn = _mm256_fmadd_ps(vx, vlog2e, vmagic_bias);
1260
1261 const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
1262
1263 vn = _mm256_sub_ps(vn, vmagic_bias);
1264
1265 __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vx);
1266
1267 const __m256 vp = _mm256_fmadd_ps(vc2, vt, vc1);
1268 vt = _mm256_mul_ps(vt, vs);
1269 __m256 vf = _mm256_fmadd_ps(vt, vp, vs);
1270 vf = _mm256_andnot_ps(_mm256_cmp_ps(vx, vdenorm_cutoff, _CMP_LT_OS), vf);
1271
1272 __m128i vh = _mm256_cvtps_ph(vf, _MM_FROUND_NO_EXC);
1273 __m128 vf_lo = _mm256_castps256_ps128(vf);
1274 if (batch & (4 * sizeof(uint16_t))) {
1275 _mm_storel_epi64((__m128i*) o, vh);
1276 vh = _mm_unpackhi_epi64(vh, vh);
1277 vacc_lo = _mm_add_ps(vacc_lo, vf_lo);
1278 vf_lo = _mm256_extractf128_ps(vf, 1);
1279 o += 4;
1280 }
1281 if (batch & (2 * sizeof(uint16_t))) {
1282 _mm_storeu_si32(o, vh);
1283 vh = _mm_srli_epi64(vh, 32);
1284 vacc_lo = _mm_blend_ps(_mm_add_ps(vacc_lo, vf_lo), vacc_lo, 0xC);
1285 vf_lo = _mm_movehl_ps(vf_lo, vf_lo);
1286 o += 2;
1287 }
1288 if (batch & (1 * sizeof(uint16_t))) {
1289 *o = (uint16_t) _mm_extract_epi16(vh, 0);
1290 vacc_lo = _mm_add_ss(vacc_lo, vf_lo);
1291 }
1292 }
1293 vacc_lo = _mm_add_ps(vacc_lo, _mm_movehl_ps(vacc_lo, vacc_lo));
1294 vacc_lo = _mm_add_ss(vacc_lo, _mm_movehdup_ps(vacc_lo));
1295 *((uint16_t*) sum) = (uint16_t) _mm_extract_epi16(_mm_cvtps_ph(vacc_lo, _MM_FROUND_NO_EXC), 0);
1296 _mm256_zeroupper();
1297 }
1298
xnn_f16_velu_ukernel__avx2_rr1_p3_x16(size_t n,const void * input,void * output,const union xnn_f16_elu_params params[restrict XNN_MIN_ELEMENTS (1)])1299 void xnn_f16_velu_ukernel__avx2_rr1_p3_x16(
1300 size_t n,
1301 const void* input,
1302 void* output,
1303 const union xnn_f16_elu_params params[restrict XNN_MIN_ELEMENTS(1)])
1304 {
1305 assert(n % sizeof(uint16_t) == 0);
1306
1307 const __m256 vprescale = _mm256_load_ps(params->avx2_rr1_p3.prescale);
1308 const __m256 vsat_cutoff = _mm256_load_ps(params->avx2_rr1_p3.sat_cutoff);
1309 const __m256 vmagic_bias = _mm256_load_ps(params->avx2_rr1_p3.magic_bias);
1310 const __m256 vlog2e = _mm256_load_ps(params->avx2_rr1_p3.log2e);
1311 const __m256 vminus_ln2 = _mm256_load_ps(params->avx2_rr1_p3.minus_ln2);
1312 const __m256 vc3 = _mm256_load_ps(params->avx2_rr1_p3.c3);
1313 const __m256 vc2 = _mm256_load_ps(params->avx2_rr1_p3.c2);
1314 const __m256 vc1 = _mm256_load_ps(params->avx2_rr1_p3.c1);
1315 const __m256 valpha = _mm256_load_ps(params->avx2_rr1_p3.alpha);
1316 const __m256 vbeta = _mm256_load_ps(params->avx2_rr1_p3.beta);
1317
1318 const uint16_t* i = (const uint16_t*) input;
1319 uint16_t* o = (uint16_t*) output;
1320 for (; n >= 16 * sizeof(uint16_t); n -= 16 * sizeof(uint16_t)) {
1321 __m256 vx0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
1322 __m256 vx1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 8)));
1323 i += 16;
1324
1325 const __m256 vz0 = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx0, vprescale));
1326 const __m256 vz1 = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx1, vprescale));
1327
1328 __m256 vn0 = _mm256_fmadd_ps(vz0, vlog2e, vmagic_bias);
1329 __m256 vn1 = _mm256_fmadd_ps(vz1, vlog2e, vmagic_bias);
1330
1331 __m256 vs0 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn0), 23));
1332 vn0 = _mm256_sub_ps(vn0, vmagic_bias);
1333 __m256 vs1 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn1), 23));
1334 vn1 = _mm256_sub_ps(vn1, vmagic_bias);
1335
1336 __m256 vt0 = _mm256_fmadd_ps(vn0, vminus_ln2, vz0);
1337 __m256 vt1 = _mm256_fmadd_ps(vn1, vminus_ln2, vz1);
1338
1339 __m256 vp0 = _mm256_fmadd_ps(vc3, vt0, vc2);
1340 __m256 vp1 = _mm256_fmadd_ps(vc3, vt1, vc2);
1341
1342 vp0 = _mm256_fmadd_ps(vp0, vt0, vc1);
1343 vt0 = _mm256_mul_ps(vt0, valpha);
1344 vp1 = _mm256_fmadd_ps(vp1, vt1, vc1);
1345 vt1 = _mm256_mul_ps(vt1, valpha);
1346
1347 vt0 = _mm256_mul_ps(vt0, vs0);
1348 vs0 = _mm256_fmsub_ps(vs0, valpha, valpha);
1349 vt1 = _mm256_mul_ps(vt1, vs1);
1350 vs1 = _mm256_fmsub_ps(vs1, valpha, valpha);
1351
1352 const __m256 ve0 = _mm256_fmadd_ps(vp0, vt0, vs0);
1353 vx0 = _mm256_mul_ps(vx0, vbeta);
1354 const __m256 ve1 = _mm256_fmadd_ps(vp1, vt1, vs1);
1355 vx1 = _mm256_mul_ps(vx1, vbeta);
1356
1357 const __m256 vy0 = _mm256_blendv_ps(vx0, ve0, vx0);
1358 const __m256 vy1 = _mm256_blendv_ps(vx1, ve1, vx1);
1359
1360 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vy0, _MM_FROUND_NO_EXC));
1361 _mm_storeu_si128((__m128i*) (o + 8), _mm256_cvtps_ph(vy1, _MM_FROUND_NO_EXC));
1362 o += 16;
1363 }
1364 for (; n >= 8 * sizeof(uint16_t); n -= 8 * sizeof(uint16_t)) {
1365 __m256 vx = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
1366 i += 8;
1367
1368 const __m256 vz = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx, vprescale));
1369
1370 __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
1371 __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
1372 vn = _mm256_sub_ps(vn, vmagic_bias);
1373 __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
1374
1375 __m256 vp = _mm256_fmadd_ps(vc3, vt, vc2);
1376 vp = _mm256_fmadd_ps(vp, vt, vc1);
1377 vt = _mm256_mul_ps(vt, valpha);
1378 vt = _mm256_mul_ps(vt, vs);
1379 vs = _mm256_fmsub_ps(vs, valpha, valpha);
1380 const __m256 ve = _mm256_fmadd_ps(vp, vt, vs);
1381 vx = _mm256_mul_ps(vx, vbeta);
1382 const __m256 vy = _mm256_blendv_ps(vx, ve, vx);
1383
1384 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vy, _MM_FROUND_NO_EXC));
1385 o += 8;
1386 }
1387 if XNN_UNLIKELY(n != 0) {
1388 assert(n >= 1 * sizeof(uint16_t));
1389 assert(n <= 7 * sizeof(uint16_t));
1390 __m256 vx = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
1391
1392 const __m256 vz = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx, vprescale));
1393
1394 __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
1395 __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
1396 vn = _mm256_sub_ps(vn, vmagic_bias);
1397 __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
1398
1399 __m256 vp = _mm256_fmadd_ps(vc3, vt, vc2);
1400 vp = _mm256_fmadd_ps(vp, vt, vc1);
1401 vt = _mm256_mul_ps(vt, valpha);
1402 vt = _mm256_mul_ps(vt, vs);
1403 vs = _mm256_fmsub_ps(vs, valpha, valpha);
1404 const __m256 ve = _mm256_fmadd_ps(vp, vt, vs);
1405 vx = _mm256_mul_ps(vx, vbeta);
1406 const __m256 vy = _mm256_blendv_ps(vx, ve, vx);
1407
1408 __m128i vh = _mm256_cvtps_ph(vy, _MM_FROUND_NO_EXC);
1409 if (n & (4 * sizeof(uint16_t))) {
1410 _mm_storel_epi64((__m128i*) o, vh);
1411 vh = _mm_unpackhi_epi64(vh, vh);
1412 o += 4;
1413 }
1414 if (n & (2 * sizeof(uint16_t))) {
1415 _mm_storeu_si32(o, vh);
1416 vh = _mm_srli_epi64(vh, 32);
1417 o += 2;
1418 }
1419 if (n & (1 * sizeof(uint16_t))) {
1420 *o = (uint16_t) _mm_extract_epi16(vh, 0);
1421 }
1422 }
1423 }
1424
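// f16 sigmoid microkernel, 32 fp16 elements per main-loop iteration, computed in fp32.
// Uses sigmoid(x) = e / (e + 1) with e = exp(-|x|): z = -|x| is formed by ORing in the
// sign bit, and exp(z) is reconstructed with the same magic-bias / shift-into-exponent
// scheme as above, here with a degree-2 polynomial. The reciprocal of (e + 1) comes from
// _mm256_rcp_ps; its ~12-bit relative accuracy is acceptable because the result is
// rounded to fp16 on store. Lanes with z below denorm_cutoff are flushed to zero, and
// positive inputs take 1 - f via the sign-keyed blendv.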
1425 void xnn_f16_vsigmoid_ukernel__avx2_rr1_p2_rcp_x32(
1426 size_t batch,
1427 const void* input,
1428 void* output,
1429 const union xnn_f16_sigmoid_params params[restrict XNN_MIN_ELEMENTS(1)])
1430 {
1431 assert(batch % sizeof(uint16_t) == 0);
1432
1433 const __m256 vsign_mask = _mm256_load_ps(params->avx2_rr1_p2.sign_mask);
1434 const __m256 vmagic_bias = _mm256_load_ps(params->avx2_rr1_p2.magic_bias);
1435 const __m256 vlog2e = _mm256_load_ps(params->avx2_rr1_p2.log2e);
1436 const __m256 vminus_ln2 = _mm256_load_ps(params->avx2_rr1_p2.minus_ln2);
1437 const __m256 vc2 = _mm256_load_ps(params->avx2_rr1_p2.c2);
1438 const __m256 vc1 = _mm256_load_ps(params->avx2_rr1_p2.c1);
1439 const __m256 vone = _mm256_load_ps(params->avx2_rr1_p2.one);
1440 const __m256 vdenorm_cutoff = _mm256_load_ps(params->avx2_rr1_p2.denorm_cutoff);
1441
1442 const uint16_t* i = (const uint16_t*) input;
1443 uint16_t* o = (uint16_t*) output;
1444 for (; batch >= 32 * sizeof(uint16_t); batch -= 32 * sizeof(uint16_t)) {
1445 const __m256 vx0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
1446 const __m256 vx1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 8)));
1447 const __m256 vx2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 16)));
1448 const __m256 vx3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 24)));
1449 i += 32;
1450
1451 const __m256 vz0 = _mm256_or_ps(vx0, vsign_mask);
1452 const __m256 vz1 = _mm256_or_ps(vx1, vsign_mask);
1453 const __m256 vz2 = _mm256_or_ps(vx2, vsign_mask);
1454 const __m256 vz3 = _mm256_or_ps(vx3, vsign_mask);
1455
1456 __m256 vn0 = _mm256_fmadd_ps(vz0, vlog2e, vmagic_bias);
1457 __m256 vn1 = _mm256_fmadd_ps(vz1, vlog2e, vmagic_bias);
1458 __m256 vn2 = _mm256_fmadd_ps(vz2, vlog2e, vmagic_bias);
1459 __m256 vn3 = _mm256_fmadd_ps(vz3, vlog2e, vmagic_bias);
1460
1461 const __m256 vs0 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn0), 23));
1462 const __m256 vs1 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn1), 23));
1463 const __m256 vs2 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn2), 23));
1464 const __m256 vs3 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn3), 23));
1465
1466 vn0 = _mm256_sub_ps(vn0, vmagic_bias);
1467 vn1 = _mm256_sub_ps(vn1, vmagic_bias);
1468 vn2 = _mm256_sub_ps(vn2, vmagic_bias);
1469 vn3 = _mm256_sub_ps(vn3, vmagic_bias);
1470
1471 __m256 vt0 = _mm256_fmadd_ps(vn0, vminus_ln2, vz0);
1472 __m256 vt1 = _mm256_fmadd_ps(vn1, vminus_ln2, vz1);
1473 __m256 vt2 = _mm256_fmadd_ps(vn2, vminus_ln2, vz2);
1474 __m256 vt3 = _mm256_fmadd_ps(vn3, vminus_ln2, vz3);
1475
1476 const __m256 vp0 = _mm256_fmadd_ps(vc2, vt0, vc1);
1477 const __m256 vp1 = _mm256_fmadd_ps(vc2, vt1, vc1);
1478 const __m256 vp2 = _mm256_fmadd_ps(vc2, vt2, vc1);
1479 const __m256 vp3 = _mm256_fmadd_ps(vc2, vt3, vc1);
1480
1481 vt0 = _mm256_mul_ps(vt0, vs0);
1482 vt1 = _mm256_mul_ps(vt1, vs1);
1483 vt2 = _mm256_mul_ps(vt2, vs2);
1484 vt3 = _mm256_mul_ps(vt3, vs3);
1485
1486 const __m256 ve0 = _mm256_fmadd_ps(vt0, vp0, vs0);
1487 const __m256 ve1 = _mm256_fmadd_ps(vt1, vp1, vs1);
1488 const __m256 ve2 = _mm256_fmadd_ps(vt2, vp2, vs2);
1489 const __m256 ve3 = _mm256_fmadd_ps(vt3, vp3, vs3);
1490
1491 const __m256 vd0 = _mm256_add_ps(ve0, vone);
1492 const __m256 vd1 = _mm256_add_ps(ve1, vone);
1493 const __m256 vd2 = _mm256_add_ps(ve2, vone);
1494 const __m256 vd3 = _mm256_add_ps(ve3, vone);
1495
1496 const __m256 vr0 = _mm256_rcp_ps(vd0);
1497 const __m256 vr1 = _mm256_rcp_ps(vd1);
1498 const __m256 vr2 = _mm256_rcp_ps(vd2);
1499 const __m256 vr3 = _mm256_rcp_ps(vd3);
1500
1501 __m256 vf0 = _mm256_mul_ps(ve0, vr0);
1502 __m256 vf1 = _mm256_mul_ps(ve1, vr1);
1503 __m256 vf2 = _mm256_mul_ps(ve2, vr2);
1504 __m256 vf3 = _mm256_mul_ps(ve3, vr3);
1505
1506 vf0 = _mm256_andnot_ps(_mm256_cmp_ps(vz0, vdenorm_cutoff, _CMP_LT_OS), vf0);
1507 vf1 = _mm256_andnot_ps(_mm256_cmp_ps(vz1, vdenorm_cutoff, _CMP_LT_OS), vf1);
1508 vf2 = _mm256_andnot_ps(_mm256_cmp_ps(vz2, vdenorm_cutoff, _CMP_LT_OS), vf2);
1509 vf3 = _mm256_andnot_ps(_mm256_cmp_ps(vz3, vdenorm_cutoff, _CMP_LT_OS), vf3);
1510
1511 vf0 = _mm256_blendv_ps(_mm256_sub_ps(vone, vf0), vf0, vx0);
1512 vf1 = _mm256_blendv_ps(_mm256_sub_ps(vone, vf1), vf1, vx1);
1513 vf2 = _mm256_blendv_ps(_mm256_sub_ps(vone, vf2), vf2, vx2);
1514 vf3 = _mm256_blendv_ps(_mm256_sub_ps(vone, vf3), vf3, vx3);
1515
1516 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vf0, _MM_FROUND_NO_EXC));
1517 _mm_storeu_si128((__m128i*) (o + 8), _mm256_cvtps_ph(vf1, _MM_FROUND_NO_EXC));
1518 _mm_storeu_si128((__m128i*) (o + 16), _mm256_cvtps_ph(vf2, _MM_FROUND_NO_EXC));
1519 _mm_storeu_si128((__m128i*) (o + 24), _mm256_cvtps_ph(vf3, _MM_FROUND_NO_EXC));
1520 o += 32;
1521 }
1522 for (; batch >= 8 * sizeof(uint16_t); batch -= 8 * sizeof(uint16_t)) {
1523 const __m256 vx = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
1524 i += 8;
1525
1526 const __m256 vz = _mm256_or_ps(vx, vsign_mask);
1527
1528 __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
1529 const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
1530 vn = _mm256_sub_ps(vn, vmagic_bias);
1531
1532 __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
1533
1534 const __m256 vp = _mm256_fmadd_ps(vc2, vt, vc1);
1535 vt = _mm256_mul_ps(vt, vs);
1536 const __m256 ve = _mm256_fmadd_ps(vt, vp, vs);
1537
1538 const __m256 vd = _mm256_add_ps(ve, vone);
1539 const __m256 vr = _mm256_rcp_ps(vd);
1540 __m256 vf = _mm256_mul_ps(ve, vr);
1541
1542 vf = _mm256_andnot_ps(_mm256_cmp_ps(vz, vdenorm_cutoff, _CMP_LT_OS), vf);
1543 vf = _mm256_blendv_ps(_mm256_sub_ps(vone, vf), vf, vx);
1544
1545 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vf, _MM_FROUND_NO_EXC));
1546 o += 8;
1547 }
1548 if XNN_UNLIKELY(batch != 0) {
1549 assert(batch >= 1 * sizeof(uint16_t));
1550 assert(batch <= 7 * sizeof(uint16_t));
1551 const __m256 vx = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
1552
1553 const __m256 vz = _mm256_or_ps(vx, vsign_mask);
1554
1555 __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
1556 const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
1557 vn = _mm256_sub_ps(vn, vmagic_bias);
1558
1559 __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
1560
1561 const __m256 vp = _mm256_fmadd_ps(vc2, vt, vc1);
1562 vt = _mm256_mul_ps(vt, vs);
1563 const __m256 ve = _mm256_fmadd_ps(vt, vp, vs);
1564
1565 const __m256 vd = _mm256_add_ps(ve, vone);
1566 const __m256 vr = _mm256_rcp_ps(vd);
1567 __m256 vf = _mm256_mul_ps(ve, vr);
1568
1569 vf = _mm256_andnot_ps(_mm256_cmp_ps(vz, vdenorm_cutoff, _CMP_LT_OS), vf);
1570 vf = _mm256_blendv_ps(_mm256_sub_ps(vone, vf), vf, vx);
1571
1572 __m128i vh = _mm256_cvtps_ph(vf, _MM_FROUND_NO_EXC);
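// Store the remaining 1-7 fp16 values: test the 4-, 2-, and 1-element bits of the
// leftover batch size, narrowing the store width and shifting consumed lanes out of
// the low half of vh after each partial store.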
1573 if (batch & (4 * sizeof(uint16_t))) {
1574 _mm_storel_epi64((__m128i*) o, vh);
1575 vh = _mm_unpackhi_epi64(vh, vh);
1576 o += 4;
1577 }
1578 if (batch & (2 * sizeof(uint16_t))) {
1579 _mm_storeu_si32(o, vh);
1580 vh = _mm_srli_epi64(vh, 32);
1581 o += 2;
1582 }
1583 if (batch & (1 * sizeof(uint16_t))) {
1584 *o = (uint16_t) _mm_extract_epi16(vh, 0);
1585 }
1586 }
1587 }
1588
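// f32 -> qs8 (signed 8-bit) quantization microkernel, 64 floats per main-loop iteration.
// Per element: multiply by scale, clamp above at (output_max - zero_point), convert to
// int32 with _mm256_cvtps_epi32 (round-to-nearest-even), pack pairs of int32 vectors to
// saturated int16 and add the zero point, then pack to saturated int8. Because the
// 256-bit pack instructions operate per 128-bit lane, bytes come out in 0246/1357 order;
// _mm256_permutevar8x32_epi32 with shuffle_mask restores linear order before the lower
// clamp against output_min and the store.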
1589 void xnn_f32_qs8_vcvt_ukernel__avx2_x64(
1590 size_t n,
1591 const float* x,
1592 int8_t* y,
1593 const union xnn_f32_qs8_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
1594 {
1595 assert(n != 0);
1596 assert(n % sizeof(float) == 0);
1597 assert(x != NULL);
1598 assert(y != NULL);
1599
1600 const __m256 vscale = _mm256_load_ps(params->avx2.scale);
1601 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->avx2.output_max_less_zero_point);
1602 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx2.output_zero_point);
1603 const __m256i vshuffle_mask = _mm256_load_si256((const __m256i*) params->avx2.shuffle_mask);
1604 const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->avx2.output_min);
1605
1606 for (; n >= 64 * sizeof(float); n -= 64 * sizeof(float)) {
1607 __m256 vx01 = _mm256_loadu_ps(x);
1608 __m256 vx23 = _mm256_loadu_ps(x + 8);
1609 __m256 vx45 = _mm256_loadu_ps(x + 16);
1610 __m256 vx67 = _mm256_loadu_ps(x + 24);
1611 __m256 vx89 = _mm256_loadu_ps(x + 32);
1612 __m256 vxAB = _mm256_loadu_ps(x + 40);
1613 __m256 vxCD = _mm256_loadu_ps(x + 48);
1614 __m256 vxEF = _mm256_loadu_ps(x + 56);
1615 x += 64;
1616
1617 vx01 = _mm256_mul_ps(vx01, vscale);
1618 vx23 = _mm256_mul_ps(vx23, vscale);
1619 vx45 = _mm256_mul_ps(vx45, vscale);
1620 vx67 = _mm256_mul_ps(vx67, vscale);
1621 vx89 = _mm256_mul_ps(vx89, vscale);
1622 vxAB = _mm256_mul_ps(vxAB, vscale);
1623 vxCD = _mm256_mul_ps(vxCD, vscale);
1624 vxEF = _mm256_mul_ps(vxEF, vscale);
1625
1626 vx01 = _mm256_min_ps(vx01, voutput_max_less_zero_point);
1627 vx23 = _mm256_min_ps(vx23, voutput_max_less_zero_point);
1628 vx45 = _mm256_min_ps(vx45, voutput_max_less_zero_point);
1629 vx67 = _mm256_min_ps(vx67, voutput_max_less_zero_point);
1630 vx89 = _mm256_min_ps(vx89, voutput_max_less_zero_point);
1631 vxAB = _mm256_min_ps(vxAB, voutput_max_less_zero_point);
1632 vxCD = _mm256_min_ps(vxCD, voutput_max_less_zero_point);
1633 vxEF = _mm256_min_ps(vxEF, voutput_max_less_zero_point);
1634
1635 const __m256i vacc01 = _mm256_cvtps_epi32(vx01);
1636 const __m256i vacc23 = _mm256_cvtps_epi32(vx23);
1637 const __m256i vacc45 = _mm256_cvtps_epi32(vx45);
1638 const __m256i vacc67 = _mm256_cvtps_epi32(vx67);
1639 const __m256i vacc89 = _mm256_cvtps_epi32(vx89);
1640 const __m256i vaccAB = _mm256_cvtps_epi32(vxAB);
1641 const __m256i vaccCD = _mm256_cvtps_epi32(vxCD);
1642 const __m256i vaccEF = _mm256_cvtps_epi32(vxEF);
1643
1644 __m256i vacc0213 = _mm256_packs_epi32(vacc01, vacc23);
1645 __m256i vacc4657 = _mm256_packs_epi32(vacc45, vacc67);
1646 __m256i vacc8A9B = _mm256_packs_epi32(vacc89, vaccAB);
1647 __m256i vaccCEDF = _mm256_packs_epi32(vaccCD, vaccEF);
1648
1649 vacc0213 = _mm256_adds_epi16(vacc0213, voutput_zero_point);
1650 vacc4657 = _mm256_adds_epi16(vacc4657, voutput_zero_point);
1651 vacc8A9B = _mm256_adds_epi16(vacc8A9B, voutput_zero_point);
1652 vaccCEDF = _mm256_adds_epi16(vaccCEDF, voutput_zero_point);
1653
1654 const __m256i vy02461357 = _mm256_packs_epi16(vacc0213, vacc4657);
1655 const __m256i vy8ACE9BDF = _mm256_packs_epi16(vacc8A9B, vaccCEDF);
1656
1657 __m256i vy01234567 = _mm256_permutevar8x32_epi32(vy02461357, vshuffle_mask);
1658 __m256i vy89ABCDEF = _mm256_permutevar8x32_epi32(vy8ACE9BDF, vshuffle_mask);
1659
1660 vy01234567 = _mm256_max_epi8(vy01234567, voutput_min);
1661 vy89ABCDEF = _mm256_max_epi8(vy89ABCDEF, voutput_min);
1662
1663 _mm256_storeu_si256((__m256i*) y, vy01234567);
1664 _mm256_storeu_si256((__m256i*) (y + 32), vy89ABCDEF);
1665 y += 64;
1666 }
1667 for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
1668 __m256 vx = _mm256_loadu_ps(x);
1669 vx = _mm256_mul_ps(vx, vscale);
1670 vx = _mm256_min_ps(vx, voutput_max_less_zero_point);
1671 x += 8;
1672
1673 const __m256i vacc = _mm256_cvtps_epi32(vx);
1674
1675 __m128i vy = _mm_packs_epi32(_mm256_castsi256_si128(vacc), _mm256_extracti128_si256(vacc, 1));
1676 vy = _mm_adds_epi16(vy, _mm256_castsi256_si128(voutput_zero_point));
1677 vy = _mm_packs_epi16(vy, vy);
1678 vy = _mm_max_epi8(vy, _mm256_castsi256_si128(voutput_min));
1679
1680 _mm_storel_epi64((__m128i*) y, vy);
1681 y += 8;
1682 }
1683 if XNN_UNLIKELY(n != 0) {
1684 assert(n >= 1 * sizeof(float));
1685 assert(n <= 7 * sizeof(float));
1686 const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx2.mask_table[7] - n));
1687
1688 __m256 vx = _mm256_maskload_ps(x, vmask);
1689 vx = _mm256_mul_ps(vx, vscale);
1690 vx = _mm256_min_ps(vx, voutput_max_less_zero_point);
1691
1692 const __m256i vacc = _mm256_cvtps_epi32(vx);
1693
1694 __m128i vy = _mm_packs_epi32(_mm256_castsi256_si128(vacc), _mm256_extracti128_si256(vacc, 1));
1695 vy = _mm_adds_epi16(vy, _mm256_castsi256_si128(voutput_zero_point));
1696 vy = _mm_packs_epi16(vy, vy);
1697 vy = _mm_max_epi8(vy, _mm256_castsi256_si128(voutput_min));
1698
1699 if (n & (4 * sizeof(float))) {
1700 _mm_storeu_si32(y, vy);
1701 y += 4;
1702 vy = _mm_srli_epi64(vy, 32);
1703 }
1704 if (n & (2 * sizeof(float))) {
1705 _mm_storeu_si16(y, vy);
1706 y += 2;
1707 vy = _mm_srli_epi32(vy, 16);
1708 }
1709 if (n & (1 * sizeof(float))) {
1710 *y = (int8_t) _mm_extract_epi8(vy, 0);
1711 }
1712 }
1713 }
1714
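// f32 -> qu8 (unsigned 8-bit) variant of the conversion above: identical pipeline, except
// the final pack uses unsigned saturation (_mm256_packus_epi16) and the lower clamp uses
// _mm256_max_epu8.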
1715 void xnn_f32_qu8_vcvt_ukernel__avx2_x64(
1716 size_t n,
1717 const float* x,
1718 uint8_t* y,
1719 const union xnn_f32_qu8_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
1720 {
1721 assert(n != 0);
1722 assert(n % sizeof(float) == 0);
1723 assert(x != NULL);
1724 assert(y != NULL);
1725
1726 const __m256 vscale = _mm256_load_ps(params->avx2.scale);
1727 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->avx2.output_max_less_zero_point);
1728 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx2.output_zero_point);
1729 const __m256i vshuffle_mask = _mm256_load_si256((const __m256i*) params->avx2.shuffle_mask);
1730 const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->avx2.output_min);
1731
1732 for (; n >= 64 * sizeof(float); n -= 64 * sizeof(float)) {
1733 __m256 vx01 = _mm256_loadu_ps(x);
1734 __m256 vx23 = _mm256_loadu_ps(x + 8);
1735 __m256 vx45 = _mm256_loadu_ps(x + 16);
1736 __m256 vx67 = _mm256_loadu_ps(x + 24);
1737 __m256 vx89 = _mm256_loadu_ps(x + 32);
1738 __m256 vxAB = _mm256_loadu_ps(x + 40);
1739 __m256 vxCD = _mm256_loadu_ps(x + 48);
1740 __m256 vxEF = _mm256_loadu_ps(x + 56);
1741 x += 64;
1742
1743 vx01 = _mm256_mul_ps(vx01, vscale);
1744 vx23 = _mm256_mul_ps(vx23, vscale);
1745 vx45 = _mm256_mul_ps(vx45, vscale);
1746 vx67 = _mm256_mul_ps(vx67, vscale);
1747 vx89 = _mm256_mul_ps(vx89, vscale);
1748 vxAB = _mm256_mul_ps(vxAB, vscale);
1749 vxCD = _mm256_mul_ps(vxCD, vscale);
1750 vxEF = _mm256_mul_ps(vxEF, vscale);
1751
1752 vx01 = _mm256_min_ps(vx01, voutput_max_less_zero_point);
1753 vx23 = _mm256_min_ps(vx23, voutput_max_less_zero_point);
1754 vx45 = _mm256_min_ps(vx45, voutput_max_less_zero_point);
1755 vx67 = _mm256_min_ps(vx67, voutput_max_less_zero_point);
1756 vx89 = _mm256_min_ps(vx89, voutput_max_less_zero_point);
1757 vxAB = _mm256_min_ps(vxAB, voutput_max_less_zero_point);
1758 vxCD = _mm256_min_ps(vxCD, voutput_max_less_zero_point);
1759 vxEF = _mm256_min_ps(vxEF, voutput_max_less_zero_point);
1760
1761 const __m256i vacc01 = _mm256_cvtps_epi32(vx01);
1762 const __m256i vacc23 = _mm256_cvtps_epi32(vx23);
1763 const __m256i vacc45 = _mm256_cvtps_epi32(vx45);
1764 const __m256i vacc67 = _mm256_cvtps_epi32(vx67);
1765 const __m256i vacc89 = _mm256_cvtps_epi32(vx89);
1766 const __m256i vaccAB = _mm256_cvtps_epi32(vxAB);
1767 const __m256i vaccCD = _mm256_cvtps_epi32(vxCD);
1768 const __m256i vaccEF = _mm256_cvtps_epi32(vxEF);
1769
1770 __m256i vacc0213 = _mm256_packs_epi32(vacc01, vacc23);
1771 __m256i vacc4657 = _mm256_packs_epi32(vacc45, vacc67);
1772 __m256i vacc8A9B = _mm256_packs_epi32(vacc89, vaccAB);
1773 __m256i vaccCEDF = _mm256_packs_epi32(vaccCD, vaccEF);
1774
1775 vacc0213 = _mm256_adds_epi16(vacc0213, voutput_zero_point);
1776 vacc4657 = _mm256_adds_epi16(vacc4657, voutput_zero_point);
1777 vacc8A9B = _mm256_adds_epi16(vacc8A9B, voutput_zero_point);
1778 vaccCEDF = _mm256_adds_epi16(vaccCEDF, voutput_zero_point);
1779
1780 const __m256i vy02461357 = _mm256_packus_epi16(vacc0213, vacc4657);
1781 const __m256i vy8ACE9BDF = _mm256_packus_epi16(vacc8A9B, vaccCEDF);
1782
1783 __m256i vy01234567 = _mm256_permutevar8x32_epi32(vy02461357, vshuffle_mask);
1784 __m256i vy89ABCDEF = _mm256_permutevar8x32_epi32(vy8ACE9BDF, vshuffle_mask);
1785
1786 vy01234567 = _mm256_max_epu8(vy01234567, voutput_min);
1787 vy89ABCDEF = _mm256_max_epu8(vy89ABCDEF, voutput_min);
1788
1789 _mm256_storeu_si256((__m256i*) y, vy01234567);
1790 _mm256_storeu_si256((__m256i*) (y + 32), vy89ABCDEF);
1791 y += 64;
1792 }
1793 for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
1794 __m256 vx = _mm256_loadu_ps(x);
1795 vx = _mm256_mul_ps(vx, vscale);
1796 vx = _mm256_min_ps(vx, voutput_max_less_zero_point);
1797 x += 8;
1798
1799 const __m256i vacc = _mm256_cvtps_epi32(vx);
1800
1801 __m128i vy = _mm_packs_epi32(_mm256_castsi256_si128(vacc), _mm256_extracti128_si256(vacc, 1));
1802 vy = _mm_adds_epi16(vy, _mm256_castsi256_si128(voutput_zero_point));
1803 vy = _mm_packus_epi16(vy, vy);
1804 vy = _mm_max_epu8(vy, _mm256_castsi256_si128(voutput_min));
1805
1806 _mm_storel_epi64((__m128i*) y, vy);
1807 y += 8;
1808 }
1809 if XNN_UNLIKELY(n != 0) {
1810 assert(n >= 1 * sizeof(float));
1811 assert(n <= 7 * sizeof(float));
1812 const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx2.mask_table[7] - n));
1813
1814 __m256 vx = _mm256_maskload_ps(x, vmask);
1815 vx = _mm256_mul_ps(vx, vscale);
1816 vx = _mm256_min_ps(vx, voutput_max_less_zero_point);
1817
1818 const __m256i vacc = _mm256_cvtps_epi32(vx);
1819
1820 __m128i vy = _mm_packs_epi32(_mm256_castsi256_si128(vacc), _mm256_extracti128_si256(vacc, 1));
1821 vy = _mm_adds_epi16(vy, _mm256_castsi256_si128(voutput_zero_point));
1822 vy = _mm_packus_epi16(vy, vy);
1823 vy = _mm_max_epu8(vy, _mm256_castsi256_si128(voutput_min));
1824
1825 if (n & (4 * sizeof(float))) {
1826 _mm_storeu_si32(y, vy);
1827 y += 4;
1828 vy = _mm_srli_epi64(vy, 32);
1829 }
1830 if (n & (2 * sizeof(float))) {
1831 _mm_storeu_si16(y, vy);
1832 y += 2;
1833 vy = _mm_srli_epi32(vy, 16);
1834 }
1835 if (n & (1 * sizeof(float))) {
1836 *y = (uint8_t) _mm_extract_epi8(vy, 0);
1837 }
1838 }
1839 }
1840
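// f32 ELU microkernel, 56 elements per main-loop iteration, "rr1_lut4_p4_perm" scheme:
// n = z * log2(e) is rounded to a multiple of 1/4 with the magic-bias trick. The two
// fractional bits of n select an entry of a 4-element table of 2**(j/4) values via
// _mm256_permutevar_ps, the integer bits are shifted left by 21 into the fp32 exponent
// field, and an integer add of the two pieces reconstructs s = 2**n. The residual
// t = z - n*ln(2) feeds a degree-4 polynomial for expm1, and the same sign-keyed
// alpha/beta blend as in the f16 ELU kernel above produces the final result.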
1841 void xnn_f32_velu_ukernel__avx2_rr1_lut4_p4_perm_x56(
1842 size_t n,
1843 const float* x,
1844 float* y,
1845 const union xnn_f32_elu_params params[restrict XNN_MIN_ELEMENTS(1)])
1846 {
1847 assert(n % sizeof(float) == 0);
1848
1849 const __m256 vprescale = _mm256_load_ps(params->avx2_rr1_lut4_p4.prescale);
1850 const __m256 valpha = _mm256_load_ps(params->avx2_rr1_lut4_p4.alpha);
1851 const __m256 vbeta = _mm256_load_ps(params->avx2_rr1_lut4_p4.beta);
1852 const __m256 vsat_cutoff = _mm256_load_ps(params->avx2_rr1_lut4_p4.sat_cutoff);
1853 const __m256 vmagic_bias = _mm256_load_ps(params->avx2_rr1_lut4_p4.magic_bias);
1854 const __m256 vlog2e = _mm256_load_ps(params->avx2_rr1_lut4_p4.log2e);
1855 const __m256 vtable = _mm256_load_ps(params->avx2_rr1_lut4_p4.table);
1856 const __m256 vminus_ln2 = _mm256_load_ps(params->avx2_rr1_lut4_p4.minus_ln2);
1857 const __m256 vc4 = _mm256_load_ps(params->avx2_rr1_lut4_p4.c4);
1858 const __m256 vc3 = _mm256_load_ps(params->avx2_rr1_lut4_p4.c3);
1859 const __m256 vc2 = _mm256_load_ps(params->avx2_rr1_lut4_p4.c2);
1860
1861 for (; n >= 56 * sizeof(float); n -= 56 * sizeof(float)) {
1862 __m256 vx0 = _mm256_loadu_ps(x);
1863 __m256 vx1 = _mm256_loadu_ps(x + 8);
1864 __m256 vx2 = _mm256_loadu_ps(x + 16);
1865 __m256 vx3 = _mm256_loadu_ps(x + 24);
1866 __m256 vx4 = _mm256_loadu_ps(x + 32);
1867 __m256 vx5 = _mm256_loadu_ps(x + 40);
1868 __m256 vx6 = _mm256_loadu_ps(x + 48);
1869 x += 56;
1870
1871 const __m256 vz0 = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx0, vprescale));
1872 const __m256 vz1 = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx1, vprescale));
1873 const __m256 vz2 = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx2, vprescale));
1874 const __m256 vz3 = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx3, vprescale));
1875 const __m256 vz4 = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx4, vprescale));
1876 const __m256 vz5 = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx5, vprescale));
1877 const __m256 vz6 = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx6, vprescale));
1878
1879 __m256 vn0 = _mm256_fmadd_ps(vz0, vlog2e, vmagic_bias);
1880 __m256 vn1 = _mm256_fmadd_ps(vz1, vlog2e, vmagic_bias);
1881 __m256 vn2 = _mm256_fmadd_ps(vz2, vlog2e, vmagic_bias);
1882 __m256 vn3 = _mm256_fmadd_ps(vz3, vlog2e, vmagic_bias);
1883 __m256 vn4 = _mm256_fmadd_ps(vz4, vlog2e, vmagic_bias);
1884 __m256 vn5 = _mm256_fmadd_ps(vz5, vlog2e, vmagic_bias);
1885 __m256 vn6 = _mm256_fmadd_ps(vz6, vlog2e, vmagic_bias);
1886
1887 const __m256i ven0 = _mm256_slli_epi32(_mm256_castps_si256(vn0), 21);
1888 const __m256i vl0 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn0)));
1889 vn0 = _mm256_sub_ps(vn0, vmagic_bias);
1890 const __m256i ven1 = _mm256_slli_epi32(_mm256_castps_si256(vn1), 21);
1891 const __m256i vl1 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn1)));
1892 vn1 = _mm256_sub_ps(vn1, vmagic_bias);
1893 const __m256i ven2 = _mm256_slli_epi32(_mm256_castps_si256(vn2), 21);
1894 const __m256i vl2 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn2)));
1895 vn2 = _mm256_sub_ps(vn2, vmagic_bias);
1896 const __m256i ven3 = _mm256_slli_epi32(_mm256_castps_si256(vn3), 21);
1897 const __m256i vl3 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn3)));
1898 vn3 = _mm256_sub_ps(vn3, vmagic_bias);
1899 const __m256i ven4 = _mm256_slli_epi32(_mm256_castps_si256(vn4), 21);
1900 const __m256i vl4 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn4)));
1901 vn4 = _mm256_sub_ps(vn4, vmagic_bias);
1902 const __m256i ven5 = _mm256_slli_epi32(_mm256_castps_si256(vn5), 21);
1903 const __m256i vl5 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn5)));
1904 vn5 = _mm256_sub_ps(vn5, vmagic_bias);
1905 const __m256i ven6 = _mm256_slli_epi32(_mm256_castps_si256(vn6), 21);
1906 const __m256i vl6 = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn6)));
1907 vn6 = _mm256_sub_ps(vn6, vmagic_bias);
1908
1909 __m256 vs0 = _mm256_castsi256_ps(_mm256_add_epi32(vl0, ven0));
1910 __m256 vt0 = _mm256_fmadd_ps(vn0, vminus_ln2, vz0);
1911 __m256 vs1 = _mm256_castsi256_ps(_mm256_add_epi32(vl1, ven1));
1912 __m256 vt1 = _mm256_fmadd_ps(vn1, vminus_ln2, vz1);
1913 __m256 vs2 = _mm256_castsi256_ps(_mm256_add_epi32(vl2, ven2));
1914 __m256 vt2 = _mm256_fmadd_ps(vn2, vminus_ln2, vz2);
1915 __m256 vs3 = _mm256_castsi256_ps(_mm256_add_epi32(vl3, ven3));
1916 __m256 vt3 = _mm256_fmadd_ps(vn3, vminus_ln2, vz3);
1917 __m256 vs4 = _mm256_castsi256_ps(_mm256_add_epi32(vl4, ven4));
1918 __m256 vt4 = _mm256_fmadd_ps(vn4, vminus_ln2, vz4);
1919 __m256 vs5 = _mm256_castsi256_ps(_mm256_add_epi32(vl5, ven5));
1920 __m256 vt5 = _mm256_fmadd_ps(vn5, vminus_ln2, vz5);
1921 __m256 vs6 = _mm256_castsi256_ps(_mm256_add_epi32(vl6, ven6));
1922 __m256 vt6 = _mm256_fmadd_ps(vn6, vminus_ln2, vz6);
1923
1924 __m256 vp0 = _mm256_fmadd_ps(vc4, vt0, vc3);
1925 __m256 vp1 = _mm256_fmadd_ps(vc4, vt1, vc3);
1926 __m256 vp2 = _mm256_fmadd_ps(vc4, vt2, vc3);
1927 __m256 vp3 = _mm256_fmadd_ps(vc4, vt3, vc3);
1928 __m256 vp4 = _mm256_fmadd_ps(vc4, vt4, vc3);
1929 __m256 vp5 = _mm256_fmadd_ps(vc4, vt5, vc3);
1930 __m256 vp6 = _mm256_fmadd_ps(vc4, vt6, vc3);
1931
1932 vp0 = _mm256_fmadd_ps(vp0, vt0, vc2);
1933 vp1 = _mm256_fmadd_ps(vp1, vt1, vc2);
1934 vp2 = _mm256_fmadd_ps(vp2, vt2, vc2);
1935 vp3 = _mm256_fmadd_ps(vp3, vt3, vc2);
1936 vp4 = _mm256_fmadd_ps(vp4, vt4, vc2);
1937 vp5 = _mm256_fmadd_ps(vp5, vt5, vc2);
1938 vp6 = _mm256_fmadd_ps(vp6, vt6, vc2);
1939
1940 vp0 = _mm256_mul_ps(vp0, vt0);
1941 vt0 = _mm256_mul_ps(vt0, vs0);
1942 vp1 = _mm256_mul_ps(vp1, vt1);
1943 vt1 = _mm256_mul_ps(vt1, vs1);
1944 vp2 = _mm256_mul_ps(vp2, vt2);
1945 vt2 = _mm256_mul_ps(vt2, vs2);
1946 vp3 = _mm256_mul_ps(vp3, vt3);
1947 vt3 = _mm256_mul_ps(vt3, vs3);
1948 vp4 = _mm256_mul_ps(vp4, vt4);
1949 vt4 = _mm256_mul_ps(vt4, vs4);
1950 vp5 = _mm256_mul_ps(vp5, vt5);
1951 vt5 = _mm256_mul_ps(vt5, vs5);
1952 vp6 = _mm256_mul_ps(vp6, vt6);
1953 vt6 = _mm256_mul_ps(vt6, vs6);
1954
1955 vs0 = _mm256_fmsub_ps(vs0, valpha, valpha);
1956 vp0 = _mm256_fmadd_ps(vp0, vt0, vt0);
1957 vs1 = _mm256_fmsub_ps(vs1, valpha, valpha);
1958 vp1 = _mm256_fmadd_ps(vp1, vt1, vt1);
1959 vs2 = _mm256_fmsub_ps(vs2, valpha, valpha);
1960 vp2 = _mm256_fmadd_ps(vp2, vt2, vt2);
1961 vs3 = _mm256_fmsub_ps(vs3, valpha, valpha);
1962 vp3 = _mm256_fmadd_ps(vp3, vt3, vt3);
1963 vs4 = _mm256_fmsub_ps(vs4, valpha, valpha);
1964 vp4 = _mm256_fmadd_ps(vp4, vt4, vt4);
1965 vs5 = _mm256_fmsub_ps(vs5, valpha, valpha);
1966 vp5 = _mm256_fmadd_ps(vp5, vt5, vt5);
1967 vs6 = _mm256_fmsub_ps(vs6, valpha, valpha);
1968 vp6 = _mm256_fmadd_ps(vp6, vt6, vt6);
1969
1970 const __m256 ve0 = _mm256_fmadd_ps(vp0, valpha, vs0);
1971 vx0 = _mm256_mul_ps(vx0, vbeta);
1972 const __m256 ve1 = _mm256_fmadd_ps(vp1, valpha, vs1);
1973 vx1 = _mm256_mul_ps(vx1, vbeta);
1974 const __m256 ve2 = _mm256_fmadd_ps(vp2, valpha, vs2);
1975 vx2 = _mm256_mul_ps(vx2, vbeta);
1976 const __m256 ve3 = _mm256_fmadd_ps(vp3, valpha, vs3);
1977 vx3 = _mm256_mul_ps(vx3, vbeta);
1978 const __m256 ve4 = _mm256_fmadd_ps(vp4, valpha, vs4);
1979 vx4 = _mm256_mul_ps(vx4, vbeta);
1980 const __m256 ve5 = _mm256_fmadd_ps(vp5, valpha, vs5);
1981 vx5 = _mm256_mul_ps(vx5, vbeta);
1982 const __m256 ve6 = _mm256_fmadd_ps(vp6, valpha, vs6);
1983 vx6 = _mm256_mul_ps(vx6, vbeta);
1984
1985 const __m256 vy0 = _mm256_blendv_ps(vx0, ve0, vx0);
1986 const __m256 vy1 = _mm256_blendv_ps(vx1, ve1, vx1);
1987 const __m256 vy2 = _mm256_blendv_ps(vx2, ve2, vx2);
1988 const __m256 vy3 = _mm256_blendv_ps(vx3, ve3, vx3);
1989 const __m256 vy4 = _mm256_blendv_ps(vx4, ve4, vx4);
1990 const __m256 vy5 = _mm256_blendv_ps(vx5, ve5, vx5);
1991 const __m256 vy6 = _mm256_blendv_ps(vx6, ve6, vx6);
1992
1993 _mm256_storeu_ps(y, vy0);
1994 _mm256_storeu_ps(y + 8, vy1);
1995 _mm256_storeu_ps(y + 16, vy2);
1996 _mm256_storeu_ps(y + 24, vy3);
1997 _mm256_storeu_ps(y + 32, vy4);
1998 _mm256_storeu_ps(y + 40, vy5);
1999 _mm256_storeu_ps(y + 48, vy6);
2000 y += 56;
2001 }
2002 for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
2003 __m256 vx = _mm256_loadu_ps(x);
2004 x += 8;
2005
2006 const __m256 vz = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx, vprescale));
2007
2008 __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
2009 const __m256i ven = _mm256_slli_epi32(_mm256_castps_si256(vn), 21);
2010 const __m256i vl = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn)));
2011 __m256 vs = _mm256_castsi256_ps(_mm256_add_epi32(vl, ven));
2012 vn = _mm256_sub_ps(vn, vmagic_bias);
2013
2014 __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
2015
2016 __m256 vp = _mm256_fmadd_ps(vc4, vt, vc3);
2017 vp = _mm256_fmadd_ps(vp, vt, vc2);
2018 vp = _mm256_mul_ps(vp, vt);
2019
2020 vt = _mm256_mul_ps(vt, vs);
2021 vs = _mm256_fmsub_ps(vs, valpha, valpha);
2022 vp = _mm256_fmadd_ps(vp, vt, vt);
2023 const __m256 ve = _mm256_fmadd_ps(vp, valpha, vs);
2024
2025 vx = _mm256_mul_ps(vx, vbeta);
2026 const __m256 vy = _mm256_blendv_ps(vx, ve, vx);
2027
2028 _mm256_storeu_ps(y, vy);
2029 y += 8;
2030 }
2031 if XNN_UNLIKELY(n != 0) {
2032 assert(n >= 1 * sizeof(float));
2033 assert(n <= 7 * sizeof(float));
2034 const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx2_rr1_lut4_p4.mask_table[7] - n));
2035
2036 __m256 vx = _mm256_maskload_ps(x, vmask);
2037
2038 const __m256 vz = _mm256_max_ps(vsat_cutoff, _mm256_mul_ps(vx, vprescale));
2039
2040 __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
2041 const __m256i ven = _mm256_slli_epi32(_mm256_castps_si256(vn), 21);
2042 const __m256i vl = _mm256_castps_si256(_mm256_permutevar_ps(vtable, _mm256_castps_si256(vn)));
2043 __m256 vs = _mm256_castsi256_ps(_mm256_add_epi32(vl, ven));
2044 vn = _mm256_sub_ps(vn, vmagic_bias);
2045
2046 __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
2047
2048 __m256 vp = _mm256_fmadd_ps(vc4, vt, vc3);
2049 vp = _mm256_fmadd_ps(vp, vt, vc2);
2050 vp = _mm256_mul_ps(vp, vt);
2051
2052 vt = _mm256_mul_ps(vt, vs);
2053 vs = _mm256_fmsub_ps(vs, valpha, valpha);
2054 vp = _mm256_fmadd_ps(vp, vt, vt);
2055 const __m256 ve = _mm256_fmadd_ps(vp, valpha, vs);
2056
2057 vx = _mm256_mul_ps(vx, vbeta);
2058 const __m256 vy = _mm256_blendv_ps(vx, ve, vx);
2059
2060 __m128 vy_lo = _mm256_castps256_ps128(vy);
2061 if (n & (4 * sizeof(float))) {
2062 _mm_storeu_ps(y, vy_lo);
2063 vy_lo = _mm256_extractf128_ps(vy, 1);
2064 y += 4;
2065 }
2066 if (n & (2 * sizeof(float))) {
2067 _mm_storel_pi((__m64*) y, vy_lo);
2068 vy_lo = _mm_movehl_ps(vy_lo, vy_lo);
2069 y += 2;
2070 }
2071 if (n & (1 * sizeof(float))) {
2072 _mm_store_ss(y, vy_lo);
2073 }
2074 }
2075 }
2076
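// f32 sigmoid microkernel, 40 elements per main-loop iteration. Same e / (e + 1)
// formulation as the f16 sigmoid kernel above, but exp(-|x|) uses a degree-5 polynomial
// and the quotient is computed with a true division (_mm256_div_ps) to keep full fp32
// accuracy instead of the faster rcp approximation.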
2077 void xnn_f32_vsigmoid_ukernel__avx2_rr1_p5_div_x40(
2078 size_t n,
2079 const float* x,
2080 float* y,
2081 const union xnn_f32_sigmoid_params params[restrict XNN_MIN_ELEMENTS(1)])
2082 {
2083 assert(n % sizeof(float) == 0);
2084
2085 const __m256 vsign_mask = _mm256_load_ps(params->avx2_rr1_p5.sign_mask);
2086 const __m256 vmagic_bias = _mm256_load_ps(params->avx2_rr1_p5.magic_bias);
2087 const __m256 vlog2e = _mm256_load_ps(params->avx2_rr1_p5.log2e);
2088 const __m256 vminus_ln2 = _mm256_load_ps(params->avx2_rr1_p5.minus_ln2);
2089 const __m256 vc5 = _mm256_load_ps(params->avx2_rr1_p5.c5);
2090 const __m256 vc4 = _mm256_load_ps(params->avx2_rr1_p5.c4);
2091 const __m256 vc3 = _mm256_load_ps(params->avx2_rr1_p5.c3);
2092 const __m256 vc2 = _mm256_load_ps(params->avx2_rr1_p5.c2);
2093 const __m256 vc1 = _mm256_load_ps(params->avx2_rr1_p5.c1);
2094 const __m256 vone = _mm256_load_ps(params->avx2_rr1_p5.one);
2095 const __m256 vdenorm_cutoff = _mm256_load_ps(params->avx2_rr1_p5.denorm_cutoff);
2096
2097 for (; n >= 40 * sizeof(float); n -= 40 * sizeof(float)) {
2098 const __m256 vx0 = _mm256_loadu_ps(x);
2099 const __m256 vx1 = _mm256_loadu_ps(x + 8);
2100 const __m256 vx2 = _mm256_loadu_ps(x + 16);
2101 const __m256 vx3 = _mm256_loadu_ps(x + 24);
2102 const __m256 vx4 = _mm256_loadu_ps(x + 32);
2103 x += 40;
2104
2105 const __m256 vz0 = _mm256_or_ps(vx0, vsign_mask);
2106 const __m256 vz1 = _mm256_or_ps(vx1, vsign_mask);
2107 const __m256 vz2 = _mm256_or_ps(vx2, vsign_mask);
2108 const __m256 vz3 = _mm256_or_ps(vx3, vsign_mask);
2109 const __m256 vz4 = _mm256_or_ps(vx4, vsign_mask);
2110
2111 __m256 vn0 = _mm256_fmadd_ps(vz0, vlog2e, vmagic_bias);
2112 __m256 vn1 = _mm256_fmadd_ps(vz1, vlog2e, vmagic_bias);
2113 __m256 vn2 = _mm256_fmadd_ps(vz2, vlog2e, vmagic_bias);
2114 __m256 vn3 = _mm256_fmadd_ps(vz3, vlog2e, vmagic_bias);
2115 __m256 vn4 = _mm256_fmadd_ps(vz4, vlog2e, vmagic_bias);
2116
2117 const __m256 vs0 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn0), 23));
2118 const __m256 vs1 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn1), 23));
2119 const __m256 vs2 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn2), 23));
2120 const __m256 vs3 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn3), 23));
2121 const __m256 vs4 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn4), 23));
2122
2123 vn0 = _mm256_sub_ps(vn0, vmagic_bias);
2124 vn1 = _mm256_sub_ps(vn1, vmagic_bias);
2125 vn2 = _mm256_sub_ps(vn2, vmagic_bias);
2126 vn3 = _mm256_sub_ps(vn3, vmagic_bias);
2127 vn4 = _mm256_sub_ps(vn4, vmagic_bias);
2128
2129 __m256 vt0 = _mm256_fmadd_ps(vn0, vminus_ln2, vz0);
2130 __m256 vt1 = _mm256_fmadd_ps(vn1, vminus_ln2, vz1);
2131 __m256 vt2 = _mm256_fmadd_ps(vn2, vminus_ln2, vz2);
2132 __m256 vt3 = _mm256_fmadd_ps(vn3, vminus_ln2, vz3);
2133 __m256 vt4 = _mm256_fmadd_ps(vn4, vminus_ln2, vz4);
2134
2135 __m256 vp0 = _mm256_fmadd_ps(vc5, vt0, vc4);
2136 __m256 vp1 = _mm256_fmadd_ps(vc5, vt1, vc4);
2137 __m256 vp2 = _mm256_fmadd_ps(vc5, vt2, vc4);
2138 __m256 vp3 = _mm256_fmadd_ps(vc5, vt3, vc4);
2139 __m256 vp4 = _mm256_fmadd_ps(vc5, vt4, vc4);
2140
2141 vp0 = _mm256_fmadd_ps(vp0, vt0, vc3);
2142 vp1 = _mm256_fmadd_ps(vp1, vt1, vc3);
2143 vp2 = _mm256_fmadd_ps(vp2, vt2, vc3);
2144 vp3 = _mm256_fmadd_ps(vp3, vt3, vc3);
2145 vp4 = _mm256_fmadd_ps(vp4, vt4, vc3);
2146
2147 vp0 = _mm256_fmadd_ps(vp0, vt0, vc2);
2148 vp1 = _mm256_fmadd_ps(vp1, vt1, vc2);
2149 vp2 = _mm256_fmadd_ps(vp2, vt2, vc2);
2150 vp3 = _mm256_fmadd_ps(vp3, vt3, vc2);
2151 vp4 = _mm256_fmadd_ps(vp4, vt4, vc2);
2152
2153 vp0 = _mm256_fmadd_ps(vp0, vt0, vc1);
2154 vp1 = _mm256_fmadd_ps(vp1, vt1, vc1);
2155 vp2 = _mm256_fmadd_ps(vp2, vt2, vc1);
2156 vp3 = _mm256_fmadd_ps(vp3, vt3, vc1);
2157 vp4 = _mm256_fmadd_ps(vp4, vt4, vc1);
2158
2159 vt0 = _mm256_mul_ps(vt0, vs0);
2160 vt1 = _mm256_mul_ps(vt1, vs1);
2161 vt2 = _mm256_mul_ps(vt2, vs2);
2162 vt3 = _mm256_mul_ps(vt3, vs3);
2163 vt4 = _mm256_mul_ps(vt4, vs4);
2164
2165 const __m256 ve0 = _mm256_fmadd_ps(vt0, vp0, vs0);
2166 const __m256 ve1 = _mm256_fmadd_ps(vt1, vp1, vs1);
2167 const __m256 ve2 = _mm256_fmadd_ps(vt2, vp2, vs2);
2168 const __m256 ve3 = _mm256_fmadd_ps(vt3, vp3, vs3);
2169 const __m256 ve4 = _mm256_fmadd_ps(vt4, vp4, vs4);
2170
2171 const __m256 vd0 = _mm256_add_ps(ve0, vone);
2172 const __m256 vd1 = _mm256_add_ps(ve1, vone);
2173 const __m256 vd2 = _mm256_add_ps(ve2, vone);
2174 const __m256 vd3 = _mm256_add_ps(ve3, vone);
2175 const __m256 vd4 = _mm256_add_ps(ve4, vone);
2176
2177 __m256 vf0 = _mm256_div_ps(ve0, vd0);
2178 __m256 vf1 = _mm256_div_ps(ve1, vd1);
2179 __m256 vf2 = _mm256_div_ps(ve2, vd2);
2180 __m256 vf3 = _mm256_div_ps(ve3, vd3);
2181 __m256 vf4 = _mm256_div_ps(ve4, vd4);
2182
2183 vf0 = _mm256_andnot_ps(_mm256_cmp_ps(vz0, vdenorm_cutoff, _CMP_LT_OS), vf0);
2184 vf1 = _mm256_andnot_ps(_mm256_cmp_ps(vz1, vdenorm_cutoff, _CMP_LT_OS), vf1);
2185 vf2 = _mm256_andnot_ps(_mm256_cmp_ps(vz2, vdenorm_cutoff, _CMP_LT_OS), vf2);
2186 vf3 = _mm256_andnot_ps(_mm256_cmp_ps(vz3, vdenorm_cutoff, _CMP_LT_OS), vf3);
2187 vf4 = _mm256_andnot_ps(_mm256_cmp_ps(vz4, vdenorm_cutoff, _CMP_LT_OS), vf4);
2188
2189 vf0 = _mm256_blendv_ps(_mm256_sub_ps(vone, vf0), vf0, vx0);
2190 vf1 = _mm256_blendv_ps(_mm256_sub_ps(vone, vf1), vf1, vx1);
2191 vf2 = _mm256_blendv_ps(_mm256_sub_ps(vone, vf2), vf2, vx2);
2192 vf3 = _mm256_blendv_ps(_mm256_sub_ps(vone, vf3), vf3, vx3);
2193 vf4 = _mm256_blendv_ps(_mm256_sub_ps(vone, vf4), vf4, vx4);
2194
2195 _mm256_storeu_ps(y, vf0);
2196 _mm256_storeu_ps(y + 8, vf1);
2197 _mm256_storeu_ps(y + 16, vf2);
2198 _mm256_storeu_ps(y + 24, vf3);
2199 _mm256_storeu_ps(y + 32, vf4);
2200 y += 40;
2201 }
2202 for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
2203 const __m256 vx = _mm256_loadu_ps(x);
2204 x += 8;
2205
2206 const __m256 vz = _mm256_or_ps(vx, vsign_mask);
2207
2208 __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
2209 const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
2210 vn = _mm256_sub_ps(vn, vmagic_bias);
2211
2212 __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
2213
2214 __m256 vp = _mm256_fmadd_ps(vc5, vt, vc4);
2215 vp = _mm256_fmadd_ps(vp, vt, vc3);
2216 vp = _mm256_fmadd_ps(vp, vt, vc2);
2217 vp = _mm256_fmadd_ps(vp, vt, vc1);
2218
2219 vt = _mm256_mul_ps(vt, vs);
2220 const __m256 ve = _mm256_fmadd_ps(vt, vp, vs);
2221
2222 const __m256 vd = _mm256_add_ps(ve, vone);
2223 __m256 vf = _mm256_div_ps(ve, vd);
2224
2225 vf = _mm256_andnot_ps(_mm256_cmp_ps(vz, vdenorm_cutoff, _CMP_LT_OS), vf);
2226 vf = _mm256_blendv_ps(_mm256_sub_ps(vone, vf), vf, vx);
2227
2228 _mm256_storeu_ps(y, vf);
2229 y += 8;
2230 }
2231 if XNN_UNLIKELY(n != 0) {
2232 assert(n >= 1 * sizeof(float));
2233 assert(n <= 7 * sizeof(float));
2234 const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx2_rr1_p5.mask_table[7] - n));
2235
2236 const __m256 vx = _mm256_maskload_ps(x, vmask);
2237
2238 const __m256 vz = _mm256_or_ps(vx, vsign_mask);
2239
2240 __m256 vn = _mm256_fmadd_ps(vz, vlog2e, vmagic_bias);
2241 const __m256 vs = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_castps_si256(vn), 23));
2242 vn = _mm256_sub_ps(vn, vmagic_bias);
2243
2244 __m256 vt = _mm256_fmadd_ps(vn, vminus_ln2, vz);
2245
2246 __m256 vp = _mm256_fmadd_ps(vc5, vt, vc4);
2247 vp = _mm256_fmadd_ps(vp, vt, vc3);
2248 vp = _mm256_fmadd_ps(vp, vt, vc2);
2249 vp = _mm256_fmadd_ps(vp, vt, vc1);
2250
2251 vt = _mm256_mul_ps(vt, vs);
2252 const __m256 ve = _mm256_fmadd_ps(vt, vp, vs);
2253
2254 const __m256 vd = _mm256_add_ps(ve, vone);
2255 __m256 vf = _mm256_div_ps(ve, vd);
2256
2257 vf = _mm256_andnot_ps(_mm256_cmp_ps(vz, vdenorm_cutoff, _CMP_LT_OS), vf);
2258 vf = _mm256_blendv_ps(_mm256_sub_ps(vone, vf), vf, vx);
2259
2260 __m128 vf_lo = _mm256_castps256_ps128(vf);
2261 if (n & (4 * sizeof(float))) {
2262 _mm_storeu_ps(y, vf_lo);
2263 vf_lo = _mm256_extractf128_ps(vf, 1);
2264 y += 4;
2265 }
2266 if (n & (2 * sizeof(float))) {
2267 _mm_storel_pi((__m64*) y, vf_lo);
2268 vf_lo = _mm_movehl_ps(vf_lo, vf_lo);
2269 y += 2;
2270 }
2271 if (n & (1 * sizeof(float))) {
2272 _mm_store_ss(y, vf_lo);
2273 }
2274 }
2275 }
2276
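// qc8 depthwise convolution microkernel: 25 taps (e.g. a 5x5 kernel), up to 16 channels
// per inner-loop iteration, with fp32 requantization and per-channel quantization
// parameters (the "qc8" layout). Accumulators start from the per-channel int32 bias at
// the head of the packed weights; each tap sign-extends 16 int8 inputs and 16 int8
// weights to int32 (_mm256_cvtepi8_epi32) and accumulates their products with
// _mm256_mullo_epi32.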
2277 void xnn_qc8_dwconv_minmax_fp32_ukernel_up16x25__avx2_mul32(
2278 size_t channels,
2279 size_t output_width,
2280 const int8_t** input,
2281 const void* weights,
2282 int8_t* output,
2283 size_t input_stride,
2284 size_t output_increment,
2285 size_t input_offset,
2286 const int8_t* zero,
2287 const union xnn_qc8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
2288 {
2289 assert(channels != 0);
2290 assert(output_width != 0);
2291
2292 do {
2293 const int8_t* i0 = input[0];
2294 assert(i0 != NULL);
2295 if XNN_UNPREDICTABLE(i0 != zero) {
2296 i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
2297 }
2298 const int8_t* i1 = input[1];
2299 assert(i1 != NULL);
2300 if XNN_UNPREDICTABLE(i1 != zero) {
2301 i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
2302 }
2303 const int8_t* i2 = input[2];
2304 assert(i2 != NULL);
2305 if XNN_UNPREDICTABLE(i2 != zero) {
2306 i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
2307 }
2308 const int8_t* i3 = input[3];
2309 assert(i3 != NULL);
2310 if XNN_UNPREDICTABLE(i3 != zero) {
2311 i3 = (const int8_t*) ((uintptr_t) i3 + input_offset);
2312 }
2313 const int8_t* i4 = input[4];
2314 assert(i4 != NULL);
2315 if XNN_UNPREDICTABLE(i4 != zero) {
2316 i4 = (const int8_t*) ((uintptr_t) i4 + input_offset);
2317 }
2318 const int8_t* i5 = input[5];
2319 assert(i5 != NULL);
2320 if XNN_UNPREDICTABLE(i5 != zero) {
2321 i5 = (const int8_t*) ((uintptr_t) i5 + input_offset);
2322 }
2323 const int8_t* i6 = input[6];
2324 assert(i6 != NULL);
2325 if XNN_UNPREDICTABLE(i6 != zero) {
2326 i6 = (const int8_t*) ((uintptr_t) i6 + input_offset);
2327 }
2328 const int8_t* i7 = input[7];
2329 assert(i7 != NULL);
2330 if XNN_UNPREDICTABLE(i7 != zero) {
2331 i7 = (const int8_t*) ((uintptr_t) i7 + input_offset);
2332 }
2333 const int8_t* i8 = input[8];
2334 assert(i8 != NULL);
2335 if XNN_UNPREDICTABLE(i8 != zero) {
2336 i8 = (const int8_t*) ((uintptr_t) i8 + input_offset);
2337 }
2338 const int8_t* i9 = input[9];
2339 assert(i9 != NULL);
2340 if XNN_UNPREDICTABLE(i9 != zero) {
2341 i9 = (const int8_t*) ((uintptr_t) i9 + input_offset);
2342 }
2343 const int8_t* i10 = input[10];
2344 assert(i10 != NULL);
2345 if XNN_UNPREDICTABLE(i10 != zero) {
2346 i10 = (const int8_t*) ((uintptr_t) i10 + input_offset);
2347 }
2348 const int8_t* i11 = input[11];
2349 assert(i11 != NULL);
2350 if XNN_UNPREDICTABLE(i11 != zero) {
2351 i11 = (const int8_t*) ((uintptr_t) i11 + input_offset);
2352 }
2353 const int8_t* i12 = input[12];
2354 assert(i12 != NULL);
2355 if XNN_UNPREDICTABLE(i12 != zero) {
2356 i12 = (const int8_t*) ((uintptr_t) i12 + input_offset);
2357 }
2358 const int8_t* i13 = input[13];
2359 assert(i13 != NULL);
2360 if XNN_UNPREDICTABLE(i13 != zero) {
2361 i13 = (const int8_t*) ((uintptr_t) i13 + input_offset);
2362 }
2363 const int8_t* i14 = input[14];
2364 assert(i14 != NULL);
2365 if XNN_UNPREDICTABLE(i14 != zero) {
2366 i14 = (const int8_t*) ((uintptr_t) i14 + input_offset);
2367 }
2368 const int8_t* i15 = input[15];
2369 assert(i15 != NULL);
2370 if XNN_UNPREDICTABLE(i15 != zero) {
2371 i15 = (const int8_t*) ((uintptr_t) i15 + input_offset);
2372 }
2373 const int8_t* i16 = input[16];
2374 assert(i16 != NULL);
2375 if XNN_UNPREDICTABLE(i16 != zero) {
2376 i16 = (const int8_t*) ((uintptr_t) i16 + input_offset);
2377 }
2378 const int8_t* i17 = input[17];
2379 assert(i17 != NULL);
2380 if XNN_UNPREDICTABLE(i17 != zero) {
2381 i17 = (const int8_t*) ((uintptr_t) i17 + input_offset);
2382 }
2383 const int8_t* i18 = input[18];
2384 assert(i18 != NULL);
2385 if XNN_UNPREDICTABLE(i18 != zero) {
2386 i18 = (const int8_t*) ((uintptr_t) i18 + input_offset);
2387 }
2388 const int8_t* i19 = input[19];
2389 assert(i19 != NULL);
2390 if XNN_UNPREDICTABLE(i19 != zero) {
2391 i19 = (const int8_t*) ((uintptr_t) i19 + input_offset);
2392 }
2393 const int8_t* i20 = input[20];
2394 assert(i20 != NULL);
2395 if XNN_UNPREDICTABLE(i20 != zero) {
2396 i20 = (const int8_t*) ((uintptr_t) i20 + input_offset);
2397 }
2398 const int8_t* i21 = input[21];
2399 assert(i21 != NULL);
2400 if XNN_UNPREDICTABLE(i21 != zero) {
2401 i21 = (const int8_t*) ((uintptr_t) i21 + input_offset);
2402 }
2403 const int8_t* i22 = input[22];
2404 assert(i22 != NULL);
2405 if XNN_UNPREDICTABLE(i22 != zero) {
2406 i22 = (const int8_t*) ((uintptr_t) i22 + input_offset);
2407 }
2408 const int8_t* i23 = input[23];
2409 assert(i23 != NULL);
2410 if XNN_UNPREDICTABLE(i23 != zero) {
2411 i23 = (const int8_t*) ((uintptr_t) i23 + input_offset);
2412 }
2413 const int8_t* i24 = input[24];
2414 assert(i24 != NULL);
2415 if XNN_UNPREDICTABLE(i24 != zero) {
2416 i24 = (const int8_t*) ((uintptr_t) i24 + input_offset);
2417 }
2418 input = (const int8_t**) ((uintptr_t) input + input_stride);
2419
2420 size_t c = channels;
2421 const void* w = weights;
2422 for (; c >= 16; c -= 16) {
2423 __m256i vacc01234567 = _mm256_loadu_si256((const __m256i*) w);
2424 __m256i vacc89ABCDEF = _mm256_loadu_si256((const __m256i*) ((const int32_t*) w + 8));
2425
2426
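// One block per tap (25 in total): sign-extend 16 int8 inputs and the matching 16 int8
// weights to int32, multiply element-wise, and add into the two 8-channel accumulators.
// Weights for tap k start 16 * sizeof(int32_t) + 16 * k bytes into the packed weight blob.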
2427 const __m256i vi0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i0));
2428 const __m256i vk0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 0 * sizeof(int8_t))));
2429 const __m256i vi0x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i0 + 8)));
2430 const __m256i vk0x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 8 * sizeof(int8_t))));
2431 i0 += 16;
2432
2433 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi0x01234567, vk0x01234567));
2434 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi0x89ABCDEF, vk0x89ABCDEF));
2435
2436 const __m256i vi1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i1));
2437 const __m256i vk1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 16 * sizeof(int8_t))));
2438 const __m256i vi1x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i1 + 8)));
2439 const __m256i vk1x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 24 * sizeof(int8_t))));
2440 i1 += 16;
2441
2442 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi1x01234567, vk1x01234567));
2443 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi1x89ABCDEF, vk1x89ABCDEF));
2444
2445 const __m256i vi2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i2));
2446 const __m256i vk2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 32 * sizeof(int8_t))));
2447 const __m256i vi2x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i2 + 8)));
2448 const __m256i vk2x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 40 * sizeof(int8_t))));
2449 i2 += 16;
2450
2451 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi2x01234567, vk2x01234567));
2452 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi2x89ABCDEF, vk2x89ABCDEF));
2453
2454 const __m256i vi3x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i3));
2455 const __m256i vk3x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 48 * sizeof(int8_t))));
2456 const __m256i vi3x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i3 + 8)));
2457 const __m256i vk3x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 56 * sizeof(int8_t))));
2458 i3 += 16;
2459
2460 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi3x01234567, vk3x01234567));
2461 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi3x89ABCDEF, vk3x89ABCDEF));
2462
2463 const __m256i vi4x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i4));
2464 const __m256i vk4x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 64 * sizeof(int8_t))));
2465 const __m256i vi4x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i4 + 8)));
2466 const __m256i vk4x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 72 * sizeof(int8_t))));
2467 i4 += 16;
2468
2469 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi4x01234567, vk4x01234567));
2470 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi4x89ABCDEF, vk4x89ABCDEF));
2471
2472 const __m256i vi5x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i5));
2473 const __m256i vk5x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 80 * sizeof(int8_t))));
2474 const __m256i vi5x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i5 + 8)));
2475 const __m256i vk5x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 88 * sizeof(int8_t))));
2476 i5 += 16;
2477
2478 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi5x01234567, vk5x01234567));
2479 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi5x89ABCDEF, vk5x89ABCDEF));
2480
2481 const __m256i vi6x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i6));
2482 const __m256i vk6x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 96 * sizeof(int8_t))));
2483 const __m256i vi6x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i6 + 8)));
2484 const __m256i vk6x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 104 * sizeof(int8_t))));
2485 i6 += 16;
2486
2487 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi6x01234567, vk6x01234567));
2488 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi6x89ABCDEF, vk6x89ABCDEF));
2489
2490 const __m256i vi7x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i7));
2491 const __m256i vk7x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 112 * sizeof(int8_t))));
2492 const __m256i vi7x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i7 + 8)));
2493 const __m256i vk7x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 120 * sizeof(int8_t))));
2494 i7 += 16;
2495
2496 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi7x01234567, vk7x01234567));
2497 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi7x89ABCDEF, vk7x89ABCDEF));
2498
2499 const __m256i vi8x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i8));
2500 const __m256i vk8x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 128 * sizeof(int8_t))));
2501 const __m256i vi8x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i8 + 8)));
2502 const __m256i vk8x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 136 * sizeof(int8_t))));
2503 i8 += 16;
2504
2505 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi8x01234567, vk8x01234567));
2506 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi8x89ABCDEF, vk8x89ABCDEF));
2507
2508 const __m256i vi9x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i9));
2509 const __m256i vk9x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 144 * sizeof(int8_t))));
2510 const __m256i vi9x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i9 + 8)));
2511 const __m256i vk9x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 152 * sizeof(int8_t))));
2512 i9 += 16;
2513
2514 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi9x01234567, vk9x01234567));
2515 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi9x89ABCDEF, vk9x89ABCDEF));
2516
2517 const __m256i vi10x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i10));
2518 const __m256i vk10x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 160 * sizeof(int8_t))));
2519 const __m256i vi10x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i10 + 8)));
2520 const __m256i vk10x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 168 * sizeof(int8_t))));
2521 i10 += 16;
2522
2523 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi10x01234567, vk10x01234567));
2524 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi10x89ABCDEF, vk10x89ABCDEF));
2525
2526 const __m256i vi11x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i11));
2527 const __m256i vk11x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 176 * sizeof(int8_t))));
2528 const __m256i vi11x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i11 + 8)));
2529 const __m256i vk11x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 184 * sizeof(int8_t))));
2530 i11 += 16;
2531
2532 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi11x01234567, vk11x01234567));
2533 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi11x89ABCDEF, vk11x89ABCDEF));
2534
2535 const __m256i vi12x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i12));
2536 const __m256i vk12x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 192 * sizeof(int8_t))));
2537 const __m256i vi12x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i12 + 8)));
2538 const __m256i vk12x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 200 * sizeof(int8_t))));
2539 i12 += 16;
2540
2541 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi12x01234567, vk12x01234567));
2542 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi12x89ABCDEF, vk12x89ABCDEF));
2543
2544 const __m256i vi13x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i13));
2545 const __m256i vk13x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 208 * sizeof(int8_t))));
2546 const __m256i vi13x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i13 + 8)));
2547 const __m256i vk13x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 216 * sizeof(int8_t))));
2548 i13 += 16;
2549
2550 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi13x01234567, vk13x01234567));
2551 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi13x89ABCDEF, vk13x89ABCDEF));
2552
2553 const __m256i vi14x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i14));
2554 const __m256i vk14x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 224 * sizeof(int8_t))));
2555 const __m256i vi14x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i14 + 8)));
2556 const __m256i vk14x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 232 * sizeof(int8_t))));
2557 i14 += 16;
2558
2559 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi14x01234567, vk14x01234567));
2560 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi14x89ABCDEF, vk14x89ABCDEF));
2561
2562 const __m256i vi15x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i15));
2563 const __m256i vk15x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 240 * sizeof(int8_t))));
2564 const __m256i vi15x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i15 + 8)));
2565 const __m256i vk15x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 248 * sizeof(int8_t))));
2566 i15 += 16;
2567
2568 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi15x01234567, vk15x01234567));
2569 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi15x89ABCDEF, vk15x89ABCDEF));
2570
2571 const __m256i vi16x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i16));
2572 const __m256i vk16x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 256 * sizeof(int8_t))));
2573 const __m256i vi16x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i16 + 8)));
2574 const __m256i vk16x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 264 * sizeof(int8_t))));
2575 i16 += 16;
2576
2577 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi16x01234567, vk16x01234567));
2578 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi16x89ABCDEF, vk16x89ABCDEF));
2579
2580 const __m256i vi17x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i17));
2581 const __m256i vk17x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 272 * sizeof(int8_t))));
2582 const __m256i vi17x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i17 + 8)));
2583 const __m256i vk17x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 280 * sizeof(int8_t))));
2584 i17 += 16;
2585
2586 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi17x01234567, vk17x01234567));
2587 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi17x89ABCDEF, vk17x89ABCDEF));
2588
2589 const __m256i vi18x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i18));
2590 const __m256i vk18x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 288 * sizeof(int8_t))));
2591 const __m256i vi18x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i18 + 8)));
2592 const __m256i vk18x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 296 * sizeof(int8_t))));
2593 i18 += 16;
2594
2595 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi18x01234567, vk18x01234567));
2596 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi18x89ABCDEF, vk18x89ABCDEF));
2597
2598 const __m256i vi19x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i19));
2599 const __m256i vk19x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 304 * sizeof(int8_t))));
2600 const __m256i vi19x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i19 + 8)));
2601 const __m256i vk19x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 312 * sizeof(int8_t))));
2602 i19 += 16;
2603
2604 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi19x01234567, vk19x01234567));
2605 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi19x89ABCDEF, vk19x89ABCDEF));
2606
2607 const __m256i vi20x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i20));
2608 const __m256i vk20x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 320 * sizeof(int8_t))));
2609 const __m256i vi20x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i20 + 8)));
2610 const __m256i vk20x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 328 * sizeof(int8_t))));
2611 i20 += 16;
2612
2613 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi20x01234567, vk20x01234567));
2614 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi20x89ABCDEF, vk20x89ABCDEF));
2615
2616 const __m256i vi21x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i21));
2617 const __m256i vk21x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 336 * sizeof(int8_t))));
2618 const __m256i vi21x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i21 + 8)));
2619 const __m256i vk21x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 344 * sizeof(int8_t))));
2620 i21 += 16;
2621
2622 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi21x01234567, vk21x01234567));
2623 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi21x89ABCDEF, vk21x89ABCDEF));
2624
2625 const __m256i vi22x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i22));
2626 const __m256i vk22x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 352 * sizeof(int8_t))));
2627 const __m256i vi22x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i22 + 8)));
2628 const __m256i vk22x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 360 * sizeof(int8_t))));
2629 i22 += 16;
2630
2631 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi22x01234567, vk22x01234567));
2632 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi22x89ABCDEF, vk22x89ABCDEF));
2633
2634 const __m256i vi23x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i23));
2635 const __m256i vk23x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 368 * sizeof(int8_t))));
2636 const __m256i vi23x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i23 + 8)));
2637 const __m256i vk23x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 376 * sizeof(int8_t))));
2638 i23 += 16;
2639
2640 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi23x01234567, vk23x01234567));
2641 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi23x89ABCDEF, vk23x89ABCDEF));
2642
2643 const __m256i vi24x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i24));
2644 const __m256i vk24x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 384 * sizeof(int8_t))));
2645 const __m256i vi24x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i24 + 8)));
2646 const __m256i vk24x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 392 * sizeof(int8_t))));
2647 i24 += 16;
2648
2649 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi24x01234567, vk24x01234567));
2650 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi24x89ABCDEF, vk24x89ABCDEF));
2651
2652 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t) + 400 * sizeof(int8_t));
2653
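      // Requantize the 16 int32 accumulators: convert to float, multiply by the
      // per-channel scales that follow the packed weights, clamp against the
      // output maximum (stored as max - zero_point), and convert back to int32
      // with rounding.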
2654 __m256 vscaled01234567 = _mm256_cvtepi32_ps(vacc01234567);
2655 __m256 vscaled89ABCDEF = _mm256_cvtepi32_ps(vacc89ABCDEF);
2656
2657 const __m256 vscale01234567 = _mm256_loadu_ps((const float*) w);
2658 const __m256 vscale89ABCDEF = _mm256_loadu_ps((const float*) w + 8);
2659 w = (const void*) ((const float*) w + 16);
2660 vscaled01234567 = _mm256_mul_ps(vscaled01234567, vscale01234567);
2661 vscaled89ABCDEF = _mm256_mul_ps(vscaled89ABCDEF, vscale89ABCDEF);
2662
2663 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
2664 vscaled01234567 = _mm256_min_ps(vscaled01234567, voutput_max_less_zero_point);
2665 vscaled89ABCDEF = _mm256_min_ps(vscaled89ABCDEF, voutput_max_less_zero_point);
2666
2667 vacc01234567 = _mm256_cvtps_epi32(vscaled01234567);
2668 vacc89ABCDEF = _mm256_cvtps_epi32(vscaled89ABCDEF);
2669
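      // Pack to int8: add the output zero point with saturating 16-bit adds,
      // narrow to 8 bits with signed saturation (the 32-bit shuffle restores
      // channel order after the lane-wise pack), then clamp to the output min.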
2670 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
2671 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(vacc01234567, vacc89ABCDEF), voutput_zero_point);
2672
2673 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
2674
2675 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx2.output_min);
2676 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, voutput_min);
2677
2678 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
2679 output += 16;
2680 }
2681 if XNN_UNLIKELY(c != 0) {
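      // Remainder: process the last 1-15 channels in groups of up to 8, reading
      // the int8 taps through k while the biases and scales are still read via w.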
2682 const int8_t* k = (const int8_t*) ((const int32_t*) w + 16);
2683 do {
2684 __m256i vacc01234567 = _mm256_loadu_si256((const __m256i*) w);
2685
2686
2687 const __m256i vi0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i0));
2688 const __m256i vk0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) k));
2689 i0 += 8;
2690
2691 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi0x01234567, vk0x01234567));
2692
2693 const __m256i vi1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i1));
2694 const __m256i vk1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 16)));
2695 i1 += 8;
2696
2697 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi1x01234567, vk1x01234567));
2698
2699 const __m256i vi2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i2));
2700 const __m256i vk2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 32)));
2701 i2 += 8;
2702
2703 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi2x01234567, vk2x01234567));
2704
2705 const __m256i vi3x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i3));
2706 const __m256i vk3x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 48)));
2707 i3 += 8;
2708
2709 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi3x01234567, vk3x01234567));
2710
2711 const __m256i vi4x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i4));
2712 const __m256i vk4x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 64)));
2713 i4 += 8;
2714
2715 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi4x01234567, vk4x01234567));
2716
2717 const __m256i vi5x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i5));
2718 const __m256i vk5x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 80)));
2719 i5 += 8;
2720
2721 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi5x01234567, vk5x01234567));
2722
2723 const __m256i vi6x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i6));
2724 const __m256i vk6x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 96)));
2725 i6 += 8;
2726
2727 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi6x01234567, vk6x01234567));
2728
2729 const __m256i vi7x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i7));
2730 const __m256i vk7x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 112)));
2731 i7 += 8;
2732
2733 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi7x01234567, vk7x01234567));
2734
2735 const __m256i vi8x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i8));
2736 const __m256i vk8x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 128)));
2737 i8 += 8;
2738
2739 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi8x01234567, vk8x01234567));
2740
2741 const __m256i vi9x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i9));
2742 const __m256i vk9x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 144)));
2743 i9 += 8;
2744
2745 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi9x01234567, vk9x01234567));
2746
2747 const __m256i vi10x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i10));
2748 const __m256i vk10x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 160)));
2749 i10 += 8;
2750
2751 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi10x01234567, vk10x01234567));
2752
2753 const __m256i vi11x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i11));
2754 const __m256i vk11x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 176)));
2755 i11 += 8;
2756
2757 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi11x01234567, vk11x01234567));
2758
2759 const __m256i vi12x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i12));
2760 const __m256i vk12x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 192)));
2761 i12 += 8;
2762
2763 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi12x01234567, vk12x01234567));
2764
2765 const __m256i vi13x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i13));
2766 const __m256i vk13x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 208)));
2767 i13 += 8;
2768
2769 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi13x01234567, vk13x01234567));
2770
2771 const __m256i vi14x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i14));
2772 const __m256i vk14x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 224)));
2773 i14 += 8;
2774
2775 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi14x01234567, vk14x01234567));
2776
2777 const __m256i vi15x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i15));
2778 const __m256i vk15x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 240)));
2779 i15 += 8;
2780
2781 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi15x01234567, vk15x01234567));
2782
2783 const __m256i vi16x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i16));
2784 const __m256i vk16x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 256)));
2785 i16 += 8;
2786
2787 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi16x01234567, vk16x01234567));
2788
2789 const __m256i vi17x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i17));
2790 const __m256i vk17x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 272)));
2791 i17 += 8;
2792
2793 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi17x01234567, vk17x01234567));
2794
2795 const __m256i vi18x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i18));
2796 const __m256i vk18x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 288)));
2797 i18 += 8;
2798
2799 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi18x01234567, vk18x01234567));
2800
2801 const __m256i vi19x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i19));
2802 const __m256i vk19x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 304)));
2803 i19 += 8;
2804
2805 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi19x01234567, vk19x01234567));
2806
2807 const __m256i vi20x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i20));
2808 const __m256i vk20x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 320)));
2809 i20 += 8;
2810
2811 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi20x01234567, vk20x01234567));
2812
2813 const __m256i vi21x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i21));
2814 const __m256i vk21x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 336)));
2815 i21 += 8;
2816
2817 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi21x01234567, vk21x01234567));
2818
2819 const __m256i vi22x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i22));
2820 const __m256i vk22x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 352)));
2821 i22 += 8;
2822
2823 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi22x01234567, vk22x01234567));
2824
2825 const __m256i vi23x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i23));
2826 const __m256i vk23x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 368)));
2827 i23 += 8;
2828
2829 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi23x01234567, vk23x01234567));
2830
2831 const __m256i vi24x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i24));
2832 const __m256i vk24x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 384)));
2833 i24 += 8;
2834
2835 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi24x01234567, vk24x01234567));
2836
2837 k += 8;
2838
2839 __m256 vscaled01234567 = _mm256_cvtepi32_ps(vacc01234567);
2840 const __m256 vscale01234567 = _mm256_loadu_ps((const float*) ((uintptr_t) w + 16 * sizeof(int32_t) + 400 * sizeof(int8_t)));
2841 vscaled01234567 = _mm256_mul_ps(vscaled01234567, vscale01234567);
2842 vscaled01234567 = _mm256_min_ps(vscaled01234567, _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point));
2843 vacc01234567 = _mm256_cvtps_epi32(vscaled01234567);
2844
2845 w = (const void*) ((const int32_t*) w + 8);
2846
2847 const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->fp32_avx2.output_zero_point);
2848 __m128i vout01234567 = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc01234567), _mm256_extracti128_si256(vacc01234567, 1)), voutput_zero_point);
2849
2850 __m128i vout0123456701234567 = _mm_packs_epi16(vout01234567, vout01234567);
2851
2852 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx2.output_min);
2853 vout0123456701234567 = _mm_max_epi8(vout0123456701234567, voutput_min);
2854
2855 if XNN_LIKELY(c >= 8) {
2856 _mm_storel_epi64((__m128i*) output, vout0123456701234567);
2857 output += 8;
2858 c -= 8;
2859 } else {
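          // Partial store of the final 1-7 channels: write 4, 2, then 1 byte(s)
          // as needed, shifting the packed result right after each partial store.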
2860 if (c & 4) {
2861 unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vout0123456701234567));
2862 vout0123456701234567 = _mm_srli_epi64(vout0123456701234567, 32);
2863 output += 4;
2864 }
2865 if (c & 2) {
2866 unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vout0123456701234567, 0));
2867 vout0123456701234567 = _mm_srli_epi32(vout0123456701234567, 16);
2868 output += 2;
2869 }
2870 if (c & 1) {
2871 *output = (int8_t) _mm_extract_epi8(vout0123456701234567, 0);
2872 output += 1;
2873 }
2874 c = 0;
2875 }
2876 } while (c != 0);
2877 }
2878
2879 output = (int8_t*) ((uintptr_t) output + output_increment);
2880 } while (--output_width != 0);
2881 }
2882
2883 void xnn_qc8_dwconv_minmax_fp32_ukernel_up16x3__avx2_mul32(
2884 size_t channels,
2885 size_t output_width,
2886 const int8_t** input,
2887 const void* weights,
2888 int8_t* output,
2889 size_t input_stride,
2890 size_t output_increment,
2891 size_t input_offset,
2892 const int8_t* zero,
2893 const union xnn_qc8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
2894 {
2895 assert(channels != 0);
2896 assert(output_width != 0);
2897
2898 do {
2899 const int8_t* i0 = input[0];
2900 assert(i0 != NULL);
2901 if XNN_UNPREDICTABLE(i0 != zero) {
2902 i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
2903 }
2904 const int8_t* i1 = input[1];
2905 assert(i1 != NULL);
2906 if XNN_UNPREDICTABLE(i1 != zero) {
2907 i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
2908 }
2909 const int8_t* i2 = input[2];
2910 assert(i2 != NULL);
2911 if XNN_UNPREDICTABLE(i2 != zero) {
2912 i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
2913 }
2914 input = (const int8_t**) ((uintptr_t) input + input_stride);
2915
2916 size_t c = channels;
2917 const void* w = weights;
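      // Packed weights per group of 16 channels: 16 int32 biases, then the int8
      // taps for all 3 kernel elements (3 x 16 bytes), then 16 float scales.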
2918 for (; c >= 16; c -= 16) {
2919 __m256i vacc01234567 = _mm256_loadu_si256((const __m256i*) w);
2920 __m256i vacc89ABCDEF = _mm256_loadu_si256((const __m256i*) ((const int32_t*) w + 8));
2921
2922
2923 const __m256i vi0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i0));
2924 const __m256i vk0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 0 * sizeof(int8_t))));
2925 const __m256i vi0x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i0 + 8)));
2926 const __m256i vk0x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 8 * sizeof(int8_t))));
2927 i0 += 16;
2928
2929 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi0x01234567, vk0x01234567));
2930 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi0x89ABCDEF, vk0x89ABCDEF));
2931
2932 const __m256i vi1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i1));
2933 const __m256i vk1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 16 * sizeof(int8_t))));
2934 const __m256i vi1x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i1 + 8)));
2935 const __m256i vk1x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 24 * sizeof(int8_t))));
2936 i1 += 16;
2937
2938 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi1x01234567, vk1x01234567));
2939 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi1x89ABCDEF, vk1x89ABCDEF));
2940
2941 const __m256i vi2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i2));
2942 const __m256i vk2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 32 * sizeof(int8_t))));
2943 const __m256i vi2x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i2 + 8)));
2944 const __m256i vk2x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 40 * sizeof(int8_t))));
2945 i2 += 16;
2946
2947 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi2x01234567, vk2x01234567));
2948 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi2x89ABCDEF, vk2x89ABCDEF));
2949
2950 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t) + 48 * sizeof(int8_t));
2951
2952 __m256 vscaled01234567 = _mm256_cvtepi32_ps(vacc01234567);
2953 __m256 vscaled89ABCDEF = _mm256_cvtepi32_ps(vacc89ABCDEF);
2954
2955 const __m256 vscale01234567 = _mm256_loadu_ps((const float*) w);
2956 const __m256 vscale89ABCDEF = _mm256_loadu_ps((const float*) w + 8);
2957 w = (const void*) ((const float*) w + 16);
2958 vscaled01234567 = _mm256_mul_ps(vscaled01234567, vscale01234567);
2959 vscaled89ABCDEF = _mm256_mul_ps(vscaled89ABCDEF, vscale89ABCDEF);
2960
2961 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
2962 vscaled01234567 = _mm256_min_ps(vscaled01234567, voutput_max_less_zero_point);
2963 vscaled89ABCDEF = _mm256_min_ps(vscaled89ABCDEF, voutput_max_less_zero_point);
2964
2965 vacc01234567 = _mm256_cvtps_epi32(vscaled01234567);
2966 vacc89ABCDEF = _mm256_cvtps_epi32(vscaled89ABCDEF);
2967
2968 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
2969 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(vacc01234567, vacc89ABCDEF), voutput_zero_point);
2970
2971 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
2972
2973 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx2.output_min);
2974 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, voutput_min);
2975
2976 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
2977 output += 16;
2978 }
2979 if XNN_UNLIKELY(c != 0) {
2980 const int8_t* k = (const int8_t*) ((const int32_t*) w + 16);
2981 do {
2982 __m256i vacc01234567 = _mm256_loadu_si256((const __m256i*) w);
2983
2984
2985 const __m256i vi0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i0));
2986 const __m256i vk0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) k));
2987 i0 += 8;
2988
2989 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi0x01234567, vk0x01234567));
2990
2991 const __m256i vi1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i1));
2992 const __m256i vk1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 16)));
2993 i1 += 8;
2994
2995 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi1x01234567, vk1x01234567));
2996
2997 const __m256i vi2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i2));
2998 const __m256i vk2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 32)));
2999 i2 += 8;
3000
3001 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi2x01234567, vk2x01234567));
3002
3003 k += 8;
3004
3005 __m256 vscaled01234567 = _mm256_cvtepi32_ps(vacc01234567);
3006 const __m256 vscale01234567 = _mm256_loadu_ps((const float*) ((uintptr_t) w + 16 * sizeof(int32_t) + 48 * sizeof(int8_t)));
3007 vscaled01234567 = _mm256_mul_ps(vscaled01234567, vscale01234567);
3008 vscaled01234567 = _mm256_min_ps(vscaled01234567, _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point));
3009 vacc01234567 = _mm256_cvtps_epi32(vscaled01234567);
3010
3011 w = (const void*) ((const int32_t*) w + 8);
3012
3013 const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->fp32_avx2.output_zero_point);
3014 __m128i vout01234567 = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc01234567), _mm256_extracti128_si256(vacc01234567, 1)), voutput_zero_point);
3015
3016 __m128i vout0123456701234567 = _mm_packs_epi16(vout01234567, vout01234567);
3017
3018 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx2.output_min);
3019 vout0123456701234567 = _mm_max_epi8(vout0123456701234567, voutput_min);
3020
3021 if XNN_LIKELY(c >= 8) {
3022 _mm_storel_epi64((__m128i*) output, vout0123456701234567);
3023 output += 8;
3024 c -= 8;
3025 } else {
3026 if (c & 4) {
3027 unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vout0123456701234567));
3028 vout0123456701234567 = _mm_srli_epi64(vout0123456701234567, 32);
3029 output += 4;
3030 }
3031 if (c & 2) {
3032 unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vout0123456701234567, 0));
3033 vout0123456701234567 = _mm_srli_epi32(vout0123456701234567, 16);
3034 output += 2;
3035 }
3036 if (c & 1) {
3037 *output = (int8_t) _mm_extract_epi8(vout0123456701234567, 0);
3038 output += 1;
3039 }
3040 c = 0;
3041 }
3042 } while (c != 0);
3043 }
3044
3045 output = (int8_t*) ((uintptr_t) output + output_increment);
3046 } while (--output_width != 0);
3047 }
3048
3049 void xnn_qc8_dwconv_minmax_fp32_ukernel_up16x9__avx2_mul32(
3050 size_t channels,
3051 size_t output_width,
3052 const int8_t** input,
3053 const void* weights,
3054 int8_t* output,
3055 size_t input_stride,
3056 size_t output_increment,
3057 size_t input_offset,
3058 const int8_t* zero,
3059 const union xnn_qc8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
3060 {
3061 assert(channels != 0);
3062 assert(output_width != 0);
3063
3064 do {
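    // 9-tap depthwise kernel: each output pixel reads 9 input pointers from the
    // indirection buffer; entries equal to the zero (padding) buffer are used
    // as-is, all others are adjusted by input_offset.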
3065 const int8_t* i0 = input[0];
3066 assert(i0 != NULL);
3067 if XNN_UNPREDICTABLE(i0 != zero) {
3068 i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
3069 }
3070 const int8_t* i1 = input[1];
3071 assert(i1 != NULL);
3072 if XNN_UNPREDICTABLE(i1 != zero) {
3073 i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
3074 }
3075 const int8_t* i2 = input[2];
3076 assert(i2 != NULL);
3077 if XNN_UNPREDICTABLE(i2 != zero) {
3078 i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
3079 }
3080 const int8_t* i3 = input[3];
3081 assert(i3 != NULL);
3082 if XNN_UNPREDICTABLE(i3 != zero) {
3083 i3 = (const int8_t*) ((uintptr_t) i3 + input_offset);
3084 }
3085 const int8_t* i4 = input[4];
3086 assert(i4 != NULL);
3087 if XNN_UNPREDICTABLE(i4 != zero) {
3088 i4 = (const int8_t*) ((uintptr_t) i4 + input_offset);
3089 }
3090 const int8_t* i5 = input[5];
3091 assert(i5 != NULL);
3092 if XNN_UNPREDICTABLE(i5 != zero) {
3093 i5 = (const int8_t*) ((uintptr_t) i5 + input_offset);
3094 }
3095 const int8_t* i6 = input[6];
3096 assert(i6 != NULL);
3097 if XNN_UNPREDICTABLE(i6 != zero) {
3098 i6 = (const int8_t*) ((uintptr_t) i6 + input_offset);
3099 }
3100 const int8_t* i7 = input[7];
3101 assert(i7 != NULL);
3102 if XNN_UNPREDICTABLE(i7 != zero) {
3103 i7 = (const int8_t*) ((uintptr_t) i7 + input_offset);
3104 }
3105 const int8_t* i8 = input[8];
3106 assert(i8 != NULL);
3107 if XNN_UNPREDICTABLE(i8 != zero) {
3108 i8 = (const int8_t*) ((uintptr_t) i8 + input_offset);
3109 }
3110 input = (const int8_t**) ((uintptr_t) input + input_stride);
3111
3112 size_t c = channels;
3113 const void* w = weights;
3114 for (; c >= 16; c -= 16) {
3115 __m256i vacc01234567 = _mm256_loadu_si256((const __m256i*) w);
3116 __m256i vacc89ABCDEF = _mm256_loadu_si256((const __m256i*) ((const int32_t*) w + 8));
3117
3118
3119 const __m256i vi0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i0));
3120 const __m256i vk0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 0 * sizeof(int8_t))));
3121 const __m256i vi0x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i0 + 8)));
3122 const __m256i vk0x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 8 * sizeof(int8_t))));
3123 i0 += 16;
3124
3125 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi0x01234567, vk0x01234567));
3126 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi0x89ABCDEF, vk0x89ABCDEF));
3127
3128 const __m256i vi1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i1));
3129 const __m256i vk1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 16 * sizeof(int8_t))));
3130 const __m256i vi1x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i1 + 8)));
3131 const __m256i vk1x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 24 * sizeof(int8_t))));
3132 i1 += 16;
3133
3134 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi1x01234567, vk1x01234567));
3135 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi1x89ABCDEF, vk1x89ABCDEF));
3136
3137 const __m256i vi2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i2));
3138 const __m256i vk2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 32 * sizeof(int8_t))));
3139 const __m256i vi2x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i2 + 8)));
3140 const __m256i vk2x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 40 * sizeof(int8_t))));
3141 i2 += 16;
3142
3143 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi2x01234567, vk2x01234567));
3144 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi2x89ABCDEF, vk2x89ABCDEF));
3145
3146 const __m256i vi3x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i3));
3147 const __m256i vk3x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 48 * sizeof(int8_t))));
3148 const __m256i vi3x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i3 + 8)));
3149 const __m256i vk3x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 56 * sizeof(int8_t))));
3150 i3 += 16;
3151
3152 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi3x01234567, vk3x01234567));
3153 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi3x89ABCDEF, vk3x89ABCDEF));
3154
3155 const __m256i vi4x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i4));
3156 const __m256i vk4x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 64 * sizeof(int8_t))));
3157 const __m256i vi4x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i4 + 8)));
3158 const __m256i vk4x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 72 * sizeof(int8_t))));
3159 i4 += 16;
3160
3161 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi4x01234567, vk4x01234567));
3162 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi4x89ABCDEF, vk4x89ABCDEF));
3163
3164 const __m256i vi5x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i5));
3165 const __m256i vk5x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 80 * sizeof(int8_t))));
3166 const __m256i vi5x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i5 + 8)));
3167 const __m256i vk5x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 88 * sizeof(int8_t))));
3168 i5 += 16;
3169
3170 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi5x01234567, vk5x01234567));
3171 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi5x89ABCDEF, vk5x89ABCDEF));
3172
3173 const __m256i vi6x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i6));
3174 const __m256i vk6x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 96 * sizeof(int8_t))));
3175 const __m256i vi6x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i6 + 8)));
3176 const __m256i vk6x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 104 * sizeof(int8_t))));
3177 i6 += 16;
3178
3179 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi6x01234567, vk6x01234567));
3180 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi6x89ABCDEF, vk6x89ABCDEF));
3181
3182 const __m256i vi7x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i7));
3183 const __m256i vk7x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 112 * sizeof(int8_t))));
3184 const __m256i vi7x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i7 + 8)));
3185 const __m256i vk7x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 120 * sizeof(int8_t))));
3186 i7 += 16;
3187
3188 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi7x01234567, vk7x01234567));
3189 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi7x89ABCDEF, vk7x89ABCDEF));
3190
3191 const __m256i vi8x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i8));
3192 const __m256i vk8x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 128 * sizeof(int8_t))));
3193 const __m256i vi8x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i8 + 8)));
3194 const __m256i vk8x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 136 * sizeof(int8_t))));
3195 i8 += 16;
3196
3197 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi8x01234567, vk8x01234567));
3198 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi8x89ABCDEF, vk8x89ABCDEF));
3199
3200 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t) + 144 * sizeof(int8_t));
3201
3202 __m256 vscaled01234567 = _mm256_cvtepi32_ps(vacc01234567);
3203 __m256 vscaled89ABCDEF = _mm256_cvtepi32_ps(vacc89ABCDEF);
3204
3205 const __m256 vscale01234567 = _mm256_loadu_ps((const float*) w);
3206 const __m256 vscale89ABCDEF = _mm256_loadu_ps((const float*) w + 8);
3207 w = (const void*) ((const float*) w + 16);
3208 vscaled01234567 = _mm256_mul_ps(vscaled01234567, vscale01234567);
3209 vscaled89ABCDEF = _mm256_mul_ps(vscaled89ABCDEF, vscale89ABCDEF);
3210
3211 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
3212 vscaled01234567 = _mm256_min_ps(vscaled01234567, voutput_max_less_zero_point);
3213 vscaled89ABCDEF = _mm256_min_ps(vscaled89ABCDEF, voutput_max_less_zero_point);
3214
3215 vacc01234567 = _mm256_cvtps_epi32(vscaled01234567);
3216 vacc89ABCDEF = _mm256_cvtps_epi32(vscaled89ABCDEF);
3217
3218 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
3219 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(vacc01234567, vacc89ABCDEF), voutput_zero_point);
3220
3221 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
3222
3223 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx2.output_min);
3224 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, voutput_min);
3225
3226 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
3227 output += 16;
3228 }
3229 if XNN_UNLIKELY(c != 0) {
3230 const int8_t* k = (const int8_t*) ((const int32_t*) w + 16);
3231 do {
3232 __m256i vacc01234567 = _mm256_loadu_si256((const __m256i*) w);
3233
3234
3235 const __m256i vi0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i0));
3236 const __m256i vk0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) k));
3237 i0 += 8;
3238
3239 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi0x01234567, vk0x01234567));
3240
3241 const __m256i vi1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i1));
3242 const __m256i vk1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 16)));
3243 i1 += 8;
3244
3245 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi1x01234567, vk1x01234567));
3246
3247 const __m256i vi2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i2));
3248 const __m256i vk2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 32)));
3249 i2 += 8;
3250
3251 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi2x01234567, vk2x01234567));
3252
3253 const __m256i vi3x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i3));
3254 const __m256i vk3x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 48)));
3255 i3 += 8;
3256
3257 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi3x01234567, vk3x01234567));
3258
3259 const __m256i vi4x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i4));
3260 const __m256i vk4x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 64)));
3261 i4 += 8;
3262
3263 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi4x01234567, vk4x01234567));
3264
3265 const __m256i vi5x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i5));
3266 const __m256i vk5x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 80)));
3267 i5 += 8;
3268
3269 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi5x01234567, vk5x01234567));
3270
3271 const __m256i vi6x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i6));
3272 const __m256i vk6x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 96)));
3273 i6 += 8;
3274
3275 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi6x01234567, vk6x01234567));
3276
3277 const __m256i vi7x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i7));
3278 const __m256i vk7x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 112)));
3279 i7 += 8;
3280
3281 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi7x01234567, vk7x01234567));
3282
3283 const __m256i vi8x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i8));
3284 const __m256i vk8x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 128)));
3285 i8 += 8;
3286
3287 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi8x01234567, vk8x01234567));
3288
3289 k += 8;
3290
3291 __m256 vscaled01234567 = _mm256_cvtepi32_ps(vacc01234567);
3292 const __m256 vscale01234567 = _mm256_loadu_ps((const float*) ((uintptr_t) w + 16 * sizeof(int32_t) + 144 * sizeof(int8_t)));
3293 vscaled01234567 = _mm256_mul_ps(vscaled01234567, vscale01234567);
3294 vscaled01234567 = _mm256_min_ps(vscaled01234567, _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point));
3295 vacc01234567 = _mm256_cvtps_epi32(vscaled01234567);
3296
3297 w = (const void*) ((const int32_t*) w + 8);
3298
3299 const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->fp32_avx2.output_zero_point);
3300 __m128i vout01234567 = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc01234567), _mm256_extracti128_si256(vacc01234567, 1)), voutput_zero_point);
3301
3302 __m128i vout0123456701234567 = _mm_packs_epi16(vout01234567, vout01234567);
3303
3304 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx2.output_min);
3305 vout0123456701234567 = _mm_max_epi8(vout0123456701234567, voutput_min);
3306
3307 if XNN_LIKELY(c >= 8) {
3308 _mm_storel_epi64((__m128i*) output, vout0123456701234567);
3309 output += 8;
3310 c -= 8;
3311 } else {
3312 if (c & 4) {
3313 unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vout0123456701234567));
3314 vout0123456701234567 = _mm_srli_epi64(vout0123456701234567, 32);
3315 output += 4;
3316 }
3317 if (c & 2) {
3318 unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vout0123456701234567, 0));
3319 vout0123456701234567 = _mm_srli_epi32(vout0123456701234567, 16);
3320 output += 2;
3321 }
3322 if (c & 1) {
3323 *output = (int8_t) _mm_extract_epi8(vout0123456701234567, 0);
3324 output += 1;
3325 }
3326 c = 0;
3327 }
3328 } while (c != 0);
3329 }
3330
3331 output = (int8_t*) ((uintptr_t) output + output_increment);
3332 } while (--output_width != 0);
3333 }
3334
3335 void xnn_qc8_gemm_minmax_fp32_ukernel_1x8c8__avx2(
3336 size_t mr,
3337 size_t nc,
3338 size_t kc,
3339 const int8_t* restrict a,
3340 size_t a_stride,
3341 const void* restrict w,
3342 int8_t* restrict c,
3343 size_t cm_stride,
3344 size_t cn_stride,
3345 const union xnn_qc8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
3346 {
3347 assert(mr != 0);
3348 assert(mr <= 1);
3349 assert(nc != 0);
3350 assert(kc != 0);
3351 assert(kc % sizeof(int8_t) == 0);
3352 assert(a != NULL);
3353 assert(w != NULL);
3354 assert(c != NULL);
3355
3356 kc = round_up_po2(kc, 8);
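  // Round kc up so the K loop always consumes 8 bytes of A per iteration; any
  // over-read past the end of A that this causes falls under the kernel's
  // XNN_OOB_READS annotation.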
3357 const int8_t* a0 = a;
3358 int8_t* c0 = c;
3359
3360 do {
3361 const __m128i vbias0x0 = _mm_cvtsi32_si128(((const int*) w)[0]);
3362 const __m128i vbias0x1 = _mm_cvtsi32_si128(((const int*) w)[1]);
3363 __m256i vacc0x01 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x0), vbias0x1, 1);
3364 const __m128i vbias0x2 = _mm_cvtsi32_si128(((const int*) w)[2]);
3365 const __m128i vbias0x3 = _mm_cvtsi32_si128(((const int*) w)[3]);
3366 __m256i vacc0x23 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x2), vbias0x3, 1);
3367 const __m128i vbias0x4 = _mm_cvtsi32_si128(((const int*) w)[4]);
3368 const __m128i vbias0x5 = _mm_cvtsi32_si128(((const int*) w)[5]);
3369 __m256i vacc0x45 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x4), vbias0x5, 1);
3370 const __m128i vbias0x6 = _mm_cvtsi32_si128(((const int*) w)[6]);
3371 const __m128i vbias0x7 = _mm_cvtsi32_si128(((const int*) w)[7]);
3372 __m256i vacc0x67 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x6), vbias0x7, 1);
3373 w = (const int32_t*) w + 8;
3374
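    // Each 256-bit accumulator holds partial sums for two output columns
    // (4 int32 lanes per column, column n in the low 128-bit lane and column
    // n+1 in the high lane), seeded with the per-column biases above.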
3375 size_t k = 0;
3376 while (k < kc) {
3377 const __m128i va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
3378 const __m256i vxa0 = _mm256_cvtepi8_epi16(va0);
3379 a0 += 8;
3380
3381 const __m128i vb01 = _mm_load_si128((const __m128i*) w);
3382 const __m256i vxb01 = _mm256_cvtepi8_epi16(vb01);
3383
3384 vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
3385 const __m128i vb23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16));
3386 const __m256i vxb23 = _mm256_cvtepi8_epi16(vb23);
3387
3388 vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
3389 const __m128i vb45 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 32));
3390 const __m256i vxb45 = _mm256_cvtepi8_epi16(vb45);
3391
3392 vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
3393 const __m128i vb67 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 48));
3394 const __m256i vxb67 = _mm256_cvtepi8_epi16(vb67);
3395
3396 vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));
3397
3398 w = (const void*) ((const int8_t*) w + 64);
3399 k += 8 * sizeof(int8_t);
3400 }
3401
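    // Reduce: two rounds of horizontal adds collapse the 4-way partial sums per
    // column, and the cross-lane permute restores column order 0..7.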
3402 const __m256i vacc0x0213 = _mm256_hadd_epi32(vacc0x01, vacc0x23);
3403 const __m256i vacc0x4657 = _mm256_hadd_epi32(vacc0x45, vacc0x67);
3404
3405 const __m256i vacc0x02461357 = _mm256_hadd_epi32(vacc0x0213, vacc0x4657);
3406
3407 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
3408 __m256i vacc0x01234567 = _mm256_permutevar8x32_epi32(vacc0x02461357, vpermute_mask);
3409
3410 __m256 vscaled0x01234567 = _mm256_cvtepi32_ps(vacc0x01234567);
3411
3412 const __m256 vscale01234567 = _mm256_load_ps(w);
3413 w = (const void*) ((const float*) w + 8);
3414 vscaled0x01234567 = _mm256_mul_ps(vscaled0x01234567, vscale01234567);
3415
3416 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
3417 vscaled0x01234567 = _mm256_min_ps(vscaled0x01234567, voutput_max_less_zero_point);
3418
3419 vacc0x01234567 = _mm256_cvtps_epi32(vscaled0x01234567);
3420
3421 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
3422 __m256i vacc00x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc0x01234567, vacc0x01234567), voutput_zero_point);
3423
3424 vacc00x01234567 = _mm256_permute4x64_epi64(vacc00x01234567, _MM_SHUFFLE(3, 1, 2, 0));
3425
3426 __m256i vout = _mm256_packs_epi16(vacc00x01234567, vacc00x01234567);
3427
3428 vout = _mm256_max_epi8(vout, _mm256_load_si256((const __m256i*) params->fp32_avx2.output_min));
3429
3430 __m128i vout_lo = _mm256_castsi256_si128(vout);
3431 __m128i vout_hi = _mm256_extracti128_si256(vout, 1);
3432
3433 if (nc >= 8) {
3434 _mm_storel_epi64((__m128i*) c0, vout_lo);
3435
3436 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
3437
3438 a0 = (const int8_t*) ((uintptr_t) a0 - kc);
3439
3440 nc -= 8;
3441 } else {
3442 if (nc & 4) {
3443 _mm_storeu_si32(c0, vout_lo);
3444
3445 c0 += 4;
3446
3447 vout_lo = _mm_srli_epi64(vout_lo, 32);
3448 vout_hi = _mm_srli_epi64(vout_hi, 32);
3449 }
3450 if (nc & 2) {
3451 unaligned_store_u16(c0, (uint16_t) _mm_extract_epi16(vout_lo, 0));
3452
3453 c0 += 2;
3454
3455 vout_lo = _mm_srli_epi32(vout_lo, 16);
3456 vout_hi = _mm_srli_epi32(vout_hi, 16);
3457 }
3458 if (nc & 1) {
3459 *c0 = (int8_t) _mm_extract_epi8(vout_lo, 0);
3460 }
3461
3462 nc = 0;
3463 }
3464 } while (nc != 0);
3465 }
3466
3467 void xnn_qc8_gemm_minmax_fp32_ukernel_3x8c8__avx2(
3468 size_t mr,
3469 size_t nc,
3470 size_t kc,
3471 const int8_t* restrict a,
3472 size_t a_stride,
3473 const void* restrict w,
3474 int8_t* restrict c,
3475 size_t cm_stride,
3476 size_t cn_stride,
3477 const union xnn_qc8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
3478 {
3479 assert(mr != 0);
3480 assert(mr <= 3);
3481 assert(nc != 0);
3482 assert(kc != 0);
3483 assert(kc % sizeof(int8_t) == 0);
3484 assert(a != NULL);
3485 assert(w != NULL);
3486 assert(c != NULL);
3487
3488 kc = round_up_po2(kc, 8);
3489 const int8_t* a0 = a;
3490 int8_t* c0 = c;
3491 const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
3492 int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
3493 if XNN_UNPREDICTABLE(mr < 2) {
3494 a1 = a0;
3495 c1 = c0;
3496 }
3497 const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride);
3498 int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
3499 if XNN_UNPREDICTABLE(mr <= 2) {
3500 a2 = a1;
3501 c2 = c1;
3502 }
3503
3504 do {
3505 const __m128i vbias0x0 = _mm_cvtsi32_si128(((const int*) w)[0]);
3506 const __m128i vbias0x1 = _mm_cvtsi32_si128(((const int*) w)[1]);
3507 __m256i vacc0x01 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x0), vbias0x1, 1);
3508 const __m128i vbias0x2 = _mm_cvtsi32_si128(((const int*) w)[2]);
3509 const __m128i vbias0x3 = _mm_cvtsi32_si128(((const int*) w)[3]);
3510 __m256i vacc0x23 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x2), vbias0x3, 1);
3511 const __m128i vbias0x4 = _mm_cvtsi32_si128(((const int*) w)[4]);
3512 const __m128i vbias0x5 = _mm_cvtsi32_si128(((const int*) w)[5]);
3513 __m256i vacc0x45 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x4), vbias0x5, 1);
3514 const __m128i vbias0x6 = _mm_cvtsi32_si128(((const int*) w)[6]);
3515 const __m128i vbias0x7 = _mm_cvtsi32_si128(((const int*) w)[7]);
3516 __m256i vacc0x67 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x6), vbias0x7, 1);
3517 __m256i vacc1x01 = vacc0x01;
3518 __m256i vacc1x23 = vacc0x23;
3519 __m256i vacc1x45 = vacc0x45;
3520 __m256i vacc1x67 = vacc0x67;
3521 __m256i vacc2x01 = vacc0x01;
3522 __m256i vacc2x23 = vacc0x23;
3523 __m256i vacc2x45 = vacc0x45;
3524 __m256i vacc2x67 = vacc0x67;
3525 w = (const int32_t*) w + 8;
3526
3527 size_t k = 0;
3528 while (k < kc) {
3529 const __m128i va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
3530 const __m256i vxa0 = _mm256_cvtepi8_epi16(va0);
3531 a0 += 8;
3532 const __m128i va1 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a1));
3533 const __m256i vxa1 = _mm256_cvtepi8_epi16(va1);
3534 a1 += 8;
3535 const __m128i va2 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a2));
3536 const __m256i vxa2 = _mm256_cvtepi8_epi16(va2);
3537 a2 += 8;
3538
3539 const __m128i vb01 = _mm_load_si128((const __m128i*) w);
3540 const __m256i vxb01 = _mm256_cvtepi8_epi16(vb01);
3541
3542 vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
3543 vacc1x01 = _mm256_add_epi32(vacc1x01, _mm256_madd_epi16(vxa1, vxb01));
3544 vacc2x01 = _mm256_add_epi32(vacc2x01, _mm256_madd_epi16(vxa2, vxb01));
3545 const __m128i vb23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16));
3546 const __m256i vxb23 = _mm256_cvtepi8_epi16(vb23);
3547
3548 vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
3549 vacc1x23 = _mm256_add_epi32(vacc1x23, _mm256_madd_epi16(vxa1, vxb23));
3550 vacc2x23 = _mm256_add_epi32(vacc2x23, _mm256_madd_epi16(vxa2, vxb23));
3551 const __m128i vb45 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 32));
3552 const __m256i vxb45 = _mm256_cvtepi8_epi16(vb45);
3553
3554 vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
3555 vacc1x45 = _mm256_add_epi32(vacc1x45, _mm256_madd_epi16(vxa1, vxb45));
3556 vacc2x45 = _mm256_add_epi32(vacc2x45, _mm256_madd_epi16(vxa2, vxb45));
3557 const __m128i vb67 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 48));
3558 const __m256i vxb67 = _mm256_cvtepi8_epi16(vb67);
3559
3560 vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));
3561 vacc1x67 = _mm256_add_epi32(vacc1x67, _mm256_madd_epi16(vxa1, vxb67));
3562 vacc2x67 = _mm256_add_epi32(vacc2x67, _mm256_madd_epi16(vxa2, vxb67));
3563
3564 w = (const void*) ((const int8_t*) w + 64);
3565 k += 8 * sizeof(int8_t);
3566 }
3567
3568 const __m256i vacc0x0213 = _mm256_hadd_epi32(vacc0x01, vacc0x23);
3569 const __m256i vacc0x4657 = _mm256_hadd_epi32(vacc0x45, vacc0x67);
3570 const __m256i vacc1x0213 = _mm256_hadd_epi32(vacc1x01, vacc1x23);
3571 const __m256i vacc1x4657 = _mm256_hadd_epi32(vacc1x45, vacc1x67);
3572 const __m256i vacc2x0213 = _mm256_hadd_epi32(vacc2x01, vacc2x23);
3573 const __m256i vacc2x4657 = _mm256_hadd_epi32(vacc2x45, vacc2x67);
3574
3575 const __m256i vacc0x02461357 = _mm256_hadd_epi32(vacc0x0213, vacc0x4657);
3576 const __m256i vacc1x02461357 = _mm256_hadd_epi32(vacc1x0213, vacc1x4657);
3577 const __m256i vacc2x02461357 = _mm256_hadd_epi32(vacc2x0213, vacc2x4657);
3578
3579 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
3580 __m256i vacc0x01234567 = _mm256_permutevar8x32_epi32(vacc0x02461357, vpermute_mask);
3581 __m256i vacc1x01234567 = _mm256_permutevar8x32_epi32(vacc1x02461357, vpermute_mask);
3582 __m256i vacc2x01234567 = _mm256_permutevar8x32_epi32(vacc2x02461357, vpermute_mask);
3583
3584 __m256 vscaled0x01234567 = _mm256_cvtepi32_ps(vacc0x01234567);
3585 __m256 vscaled1x01234567 = _mm256_cvtepi32_ps(vacc1x01234567);
3586 __m256 vscaled2x01234567 = _mm256_cvtepi32_ps(vacc2x01234567);
3587
3588 const __m256 vscale01234567 = _mm256_load_ps(w);
3589 w = (const void*) ((const float*) w + 8);
3590 vscaled0x01234567 = _mm256_mul_ps(vscaled0x01234567, vscale01234567);
3591 vscaled1x01234567 = _mm256_mul_ps(vscaled1x01234567, vscale01234567);
3592 vscaled2x01234567 = _mm256_mul_ps(vscaled2x01234567, vscale01234567);
3593
3594 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
3595 vscaled0x01234567 = _mm256_min_ps(vscaled0x01234567, voutput_max_less_zero_point);
3596 vscaled1x01234567 = _mm256_min_ps(vscaled1x01234567, voutput_max_less_zero_point);
3597 vscaled2x01234567 = _mm256_min_ps(vscaled2x01234567, voutput_max_less_zero_point);
3598
3599 vacc0x01234567 = _mm256_cvtps_epi32(vscaled0x01234567);
3600 vacc1x01234567 = _mm256_cvtps_epi32(vscaled1x01234567);
3601 vacc2x01234567 = _mm256_cvtps_epi32(vscaled2x01234567);
3602
3603 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
3604 __m256i vacc01x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc0x01234567, vacc1x01234567), voutput_zero_point);
3605 __m256i vacc22x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc2x01234567, vacc2x01234567), voutput_zero_point);
3606
3607 vacc01x01234567 = _mm256_permute4x64_epi64(vacc01x01234567, _MM_SHUFFLE(3, 1, 2, 0));
3608 vacc22x01234567 = _mm256_permute4x64_epi64(vacc22x01234567, _MM_SHUFFLE(3, 1, 2, 0));
3609
3610 __m256i vout = _mm256_packs_epi16(vacc01x01234567, vacc22x01234567);
3611
3612 vout = _mm256_max_epi8(vout, _mm256_load_si256((const __m256i*) params->fp32_avx2.output_min));
3613
3614 __m128i vout_lo = _mm256_castsi256_si128(vout);
3615 __m128i vout_hi = _mm256_extracti128_si256(vout, 1);
3616
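    // vout layout: row 0 in the low half of vout_lo, row 1 in the low half of
    // vout_hi, row 2 duplicated in the upper halves of both.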
3617 if (nc >= 8) {
3618 _mm_storel_epi64((__m128i*) c0, vout_lo);
3619 _mm_storel_epi64((__m128i*) c1, vout_hi);
3620 _mm_storeh_pi((__m64*) c2, _mm_castsi128_ps(vout_lo));
3621
3622 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
3623 c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
3624 c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
3625
3626 a0 = (const int8_t*) ((uintptr_t) a0 - kc);
3627 a1 = (const int8_t*) ((uintptr_t) a1 - kc);
3628 a2 = (const int8_t*) ((uintptr_t) a2 - kc);
3629
3630 nc -= 8;
3631 } else {
3632 if (nc & 4) {
3633 _mm_storeu_si32(c0, vout_lo);
3634 _mm_storeu_si32(c1, vout_hi);
3635 unaligned_store_u32(c2, (uint32_t) _mm_extract_epi32(vout_lo, 2));
3636
3637 c0 += 4;
3638 c1 += 4;
3639 c2 += 4;
3640
3641 vout_lo = _mm_srli_epi64(vout_lo, 32);
3642 vout_hi = _mm_srli_epi64(vout_hi, 32);
3643 }
3644 if (nc & 2) {
3645 unaligned_store_u16(c0, (uint16_t) _mm_extract_epi16(vout_lo, 0));
3646 unaligned_store_u16(c1, (uint16_t) _mm_extract_epi16(vout_hi, 0));
3647 unaligned_store_u16(c2, (uint16_t) _mm_extract_epi16(vout_lo, 4));
3648
3649 c0 += 2;
3650 c1 += 2;
3651 c2 += 2;
3652
3653 vout_lo = _mm_srli_epi32(vout_lo, 16);
3654 vout_hi = _mm_srli_epi32(vout_hi, 16);
3655 }
3656 if (nc & 1) {
3657 *c0 = (int8_t) _mm_extract_epi8(vout_lo, 0);
3658 *c1 = (int8_t) _mm_extract_epi8(vout_hi, 0);
3659 *c2 = (int8_t) _mm_extract_epi8(vout_lo, 8);
3660 }
3661
3662 nc = 0;
3663 }
3664 } while (nc != 0);
3665 }
3666
3667 void xnn_qc8_igemm_minmax_fp32_ukernel_1x8c8__avx2(
3668 size_t mr,
3669 size_t nc,
3670 size_t kc,
3671 size_t ks,
3672 const int8_t** restrict a,
3673 const void* restrict w,
3674 int8_t* restrict c,
3675 size_t cm_stride,
3676 size_t cn_stride,
3677 size_t a_offset,
3678 const int8_t* zero,
3679 const union xnn_qc8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
3680 {
3681 assert(mr != 0);
3682 assert(mr <= 1);
3683 assert(nc != 0);
3684 assert(kc != 0);
3685 assert(ks != 0);
3686 assert(ks % (1 * sizeof(void*)) == 0);
3687 assert(a_offset % sizeof(int8_t) == 0);
3688 assert(a != NULL);
3689 assert(w != NULL);
3690 assert(c != NULL);
3691
3692 kc = round_up_po2(kc, 8);
3693 int8_t* c0 = c;
3694
3695 do {
3696 const __m128i vbias0x0 = _mm_cvtsi32_si128(((const int*) w)[0]);
3697 const __m128i vbias0x1 = _mm_cvtsi32_si128(((const int*) w)[1]);
3698 __m256i vacc0x01 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x0), vbias0x1, 1);
3699 const __m128i vbias0x2 = _mm_cvtsi32_si128(((const int*) w)[2]);
3700 const __m128i vbias0x3 = _mm_cvtsi32_si128(((const int*) w)[3]);
3701 __m256i vacc0x23 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x2), vbias0x3, 1);
3702 const __m128i vbias0x4 = _mm_cvtsi32_si128(((const int*) w)[4]);
3703 const __m128i vbias0x5 = _mm_cvtsi32_si128(((const int*) w)[5]);
3704 __m256i vacc0x45 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x4), vbias0x5, 1);
3705 const __m128i vbias0x6 = _mm_cvtsi32_si128(((const int*) w)[6]);
3706 const __m128i vbias0x7 = _mm_cvtsi32_si128(((const int*) w)[7]);
3707 __m256i vacc0x67 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x6), vbias0x7, 1);
3708 w = (const int32_t*) w + 8;
3709
3710 size_t p = ks;
3711 do {
3712 const int8_t* restrict a0 = a[0];
3713 if XNN_UNPREDICTABLE(a0 != zero) {
3714 a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
3715 }
3716 a += 1;
3717
3718 size_t k = 0;
3719 while (k < kc) {
3720 const __m128i va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
3721 const __m256i vxa0 = _mm256_cvtepi8_epi16(va0);
3722 a0 += 8;
3723
3724 const __m128i vb01 = _mm_load_si128((const __m128i*) w);
3725 const __m256i vxb01 = _mm256_cvtepi8_epi16(vb01);
3726
3727 vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
3728 const __m128i vb23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16));
3729 const __m256i vxb23 = _mm256_cvtepi8_epi16(vb23);
3730
3731 vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
3732 const __m128i vb45 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 32));
3733 const __m256i vxb45 = _mm256_cvtepi8_epi16(vb45);
3734
3735 vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
3736 const __m128i vb67 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 48));
3737 const __m256i vxb67 = _mm256_cvtepi8_epi16(vb67);
3738
3739 vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));
3740
3741 w = (const void*) ((const int8_t*) w + 64);
3742 k += 8 * sizeof(int8_t);
3743 }
3744 p -= 1 * sizeof(void*);
3745 } while (p != 0);
3746
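    // Each 256-bit accumulator holds pairwise partial sums for two output columns (one per 128-bit lane).
    // Two rounds of _mm256_hadd_epi32 plus a cross-lane permute reduce them to one int32 sum per column,
    // in column order 0..7.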
3747 const __m256i vacc0x0213 = _mm256_hadd_epi32(vacc0x01, vacc0x23);
3748 const __m256i vacc0x4657 = _mm256_hadd_epi32(vacc0x45, vacc0x67);
3749
3750 const __m256i vacc0x02461357 = _mm256_hadd_epi32(vacc0x0213, vacc0x4657);
3751
3752 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
3753 __m256i vacc0x01234567 = _mm256_permutevar8x32_epi32(vacc0x02461357, vpermute_mask);
3754
3755 __m256 vscaled0x01234567 = _mm256_cvtepi32_ps(vacc0x01234567);
3756
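    // Requantize: convert to float and apply the per-column (QC8) scales, which are packed into w
    // immediately after the int8 weight blocks.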
3757 const __m256 vscale01234567 = _mm256_load_ps(w);
3758 w = (const void*) ((const float*) w + 8);
3759 vscaled0x01234567 = _mm256_mul_ps(vscaled0x01234567, vscale01234567);
3760
3761 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
3762 vscaled0x01234567 = _mm256_min_ps(vscaled0x01234567, voutput_max_less_zero_point);
3763
3764 vacc0x01234567 = _mm256_cvtps_epi32(vscaled0x01234567);
3765
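    // Pack back down to int8: saturating pack to int16 with the output zero point added, fix the lane
    // interleave introduced by _mm256_packs_epi32, pack to int8, and clamp at output_min
    // (output_max was already applied in the float domain).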
3766 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
3767 __m256i vacc00x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc0x01234567, vacc0x01234567), voutput_zero_point);
3768
3769 vacc00x01234567 = _mm256_permute4x64_epi64(vacc00x01234567, _MM_SHUFFLE(3, 1, 2, 0));
3770
3771 __m256i vout = _mm256_packs_epi16(vacc00x01234567, vacc00x01234567);
3772
3773 vout = _mm256_max_epi8(vout, _mm256_load_si256((const __m256i*) params->fp32_avx2.output_min));
3774
3775 __m128i vout_lo = _mm256_castsi256_si128(vout);
3776 __m128i vout_hi = _mm256_extracti128_si256(vout, 1);
3777
3778 if (nc >= 8) {
3779 _mm_storel_epi64((__m128i*) c0, vout_lo);
3780
3781 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
3782
3783 a = (const int8_t**restrict) ((uintptr_t) a - ks);
3784
3785 nc -= 8;
3786 } else {
3787 if (nc & 4) {
3788 _mm_storeu_si32(c0, vout_lo);
3789
3790 c0 += 4;
3791
3792 vout_lo = _mm_srli_epi64(vout_lo, 32);
3793 vout_hi = _mm_srli_epi64(vout_hi, 32);
3794 }
3795 if (nc & 2) {
3796 unaligned_store_u16(c0, (uint16_t) _mm_extract_epi16(vout_lo, 0));
3797
3798 c0 += 2;
3799
3800 vout_lo = _mm_srli_epi32(vout_lo, 16);
3801 vout_hi = _mm_srli_epi32(vout_hi, 16);
3802 }
3803 if (nc & 1) {
3804 *c0 = (int8_t) _mm_extract_epi8(vout_lo, 0);
3805 }
3806
3807 nc = 0;
3808 }
3809 } while (nc != 0);
3810 }
3811
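// QC8 IGEMM microkernel: MR=3 rows, NR=8 columns, K unrolled by 8. Same structure as the 1x8c8 variant
// above, with three row accumulators sharing one set of packed biases, weights, and per-column scales.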
3812 void xnn_qc8_igemm_minmax_fp32_ukernel_3x8c8__avx2(
3813 size_t mr,
3814 size_t nc,
3815 size_t kc,
3816 size_t ks,
3817 const int8_t** restrict a,
3818 const void* restrict w,
3819 int8_t* restrict c,
3820 size_t cm_stride,
3821 size_t cn_stride,
3822 size_t a_offset,
3823 const int8_t* zero,
3824 const union xnn_qc8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
3825 {
3826 assert(mr != 0);
3827 assert(mr <= 3);
3828 assert(nc != 0);
3829 assert(kc != 0);
3830 assert(ks != 0);
3831 assert(ks % (3 * sizeof(void*)) == 0);
3832 assert(a_offset % sizeof(int8_t) == 0);
3833 assert(a != NULL);
3834 assert(w != NULL);
3835 assert(c != NULL);
3836
3837 kc = round_up_po2(kc, 8);
3838 int8_t* c0 = c;
3839 int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
3840 if XNN_UNPREDICTABLE(mr < 2) {
3841 c1 = c0;
3842 }
3843 int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
3844 if XNN_UNPREDICTABLE(mr <= 2) {
3845 c2 = c1;
3846 }
3847
3848 do {
3849 const __m128i vbias0x0 = _mm_cvtsi32_si128(((const int*) w)[0]);
3850 const __m128i vbias0x1 = _mm_cvtsi32_si128(((const int*) w)[1]);
3851 __m256i vacc0x01 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x0), vbias0x1, 1);
3852 const __m128i vbias0x2 = _mm_cvtsi32_si128(((const int*) w)[2]);
3853 const __m128i vbias0x3 = _mm_cvtsi32_si128(((const int*) w)[3]);
3854 __m256i vacc0x23 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x2), vbias0x3, 1);
3855 const __m128i vbias0x4 = _mm_cvtsi32_si128(((const int*) w)[4]);
3856 const __m128i vbias0x5 = _mm_cvtsi32_si128(((const int*) w)[5]);
3857 __m256i vacc0x45 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x4), vbias0x5, 1);
3858 const __m128i vbias0x6 = _mm_cvtsi32_si128(((const int*) w)[6]);
3859 const __m128i vbias0x7 = _mm_cvtsi32_si128(((const int*) w)[7]);
3860 __m256i vacc0x67 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x6), vbias0x7, 1);
3861 __m256i vacc1x01 = vacc0x01;
3862 __m256i vacc1x23 = vacc0x23;
3863 __m256i vacc1x45 = vacc0x45;
3864 __m256i vacc1x67 = vacc0x67;
3865 __m256i vacc2x01 = vacc0x01;
3866 __m256i vacc2x23 = vacc0x23;
3867 __m256i vacc2x45 = vacc0x45;
3868 __m256i vacc2x67 = vacc0x67;
3869 w = (const int32_t*) w + 8;
3870
3871 size_t p = ks;
3872 do {
3873 const int8_t* restrict a0 = a[0];
3874 if XNN_UNPREDICTABLE(a0 != zero) {
3875 a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
3876 }
3877 const int8_t* restrict a1 = a[1];
3878 if XNN_UNPREDICTABLE(a1 != zero) {
3879 a1 = (const int8_t*) ((uintptr_t) a1 + a_offset);
3880 }
3881 const int8_t* restrict a2 = a[2];
3882 if XNN_UNPREDICTABLE(a2 != zero) {
3883 a2 = (const int8_t*) ((uintptr_t) a2 + a_offset);
3884 }
3885 a += 3;
3886
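      // Each iteration consumes 8 int8 values per input row (broadcast to both 128-bit lanes) and
      // 64 bytes of weights (4 column pairs), widening to int16 and accumulating pairwise int32
      // products with _mm256_madd_epi16.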
3887 size_t k = 0;
3888 while (k < kc) {
3889 const __m128i va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
3890 const __m256i vxa0 = _mm256_cvtepi8_epi16(va0);
3891 a0 += 8;
3892 const __m128i va1 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a1));
3893 const __m256i vxa1 = _mm256_cvtepi8_epi16(va1);
3894 a1 += 8;
3895 const __m128i va2 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a2));
3896 const __m256i vxa2 = _mm256_cvtepi8_epi16(va2);
3897 a2 += 8;
3898
3899 const __m128i vb01 = _mm_load_si128((const __m128i*) w);
3900 const __m256i vxb01 = _mm256_cvtepi8_epi16(vb01);
3901
3902 vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
3903 vacc1x01 = _mm256_add_epi32(vacc1x01, _mm256_madd_epi16(vxa1, vxb01));
3904 vacc2x01 = _mm256_add_epi32(vacc2x01, _mm256_madd_epi16(vxa2, vxb01));
3905 const __m128i vb23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16));
3906 const __m256i vxb23 = _mm256_cvtepi8_epi16(vb23);
3907
3908 vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
3909 vacc1x23 = _mm256_add_epi32(vacc1x23, _mm256_madd_epi16(vxa1, vxb23));
3910 vacc2x23 = _mm256_add_epi32(vacc2x23, _mm256_madd_epi16(vxa2, vxb23));
3911 const __m128i vb45 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 32));
3912 const __m256i vxb45 = _mm256_cvtepi8_epi16(vb45);
3913
3914 vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
3915 vacc1x45 = _mm256_add_epi32(vacc1x45, _mm256_madd_epi16(vxa1, vxb45));
3916 vacc2x45 = _mm256_add_epi32(vacc2x45, _mm256_madd_epi16(vxa2, vxb45));
3917 const __m128i vb67 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 48));
3918 const __m256i vxb67 = _mm256_cvtepi8_epi16(vb67);
3919
3920 vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));
3921 vacc1x67 = _mm256_add_epi32(vacc1x67, _mm256_madd_epi16(vxa1, vxb67));
3922 vacc2x67 = _mm256_add_epi32(vacc2x67, _mm256_madd_epi16(vxa2, vxb67));
3923
3924 w = (const void*) ((const int8_t*) w + 64);
3925 k += 8 * sizeof(int8_t);
3926 }
3927 p -= 3 * sizeof(void*);
3928 } while (p != 0);
3929
3930 const __m256i vacc0x0213 = _mm256_hadd_epi32(vacc0x01, vacc0x23);
3931 const __m256i vacc0x4657 = _mm256_hadd_epi32(vacc0x45, vacc0x67);
3932 const __m256i vacc1x0213 = _mm256_hadd_epi32(vacc1x01, vacc1x23);
3933 const __m256i vacc1x4657 = _mm256_hadd_epi32(vacc1x45, vacc1x67);
3934 const __m256i vacc2x0213 = _mm256_hadd_epi32(vacc2x01, vacc2x23);
3935 const __m256i vacc2x4657 = _mm256_hadd_epi32(vacc2x45, vacc2x67);
3936
3937 const __m256i vacc0x02461357 = _mm256_hadd_epi32(vacc0x0213, vacc0x4657);
3938 const __m256i vacc1x02461357 = _mm256_hadd_epi32(vacc1x0213, vacc1x4657);
3939 const __m256i vacc2x02461357 = _mm256_hadd_epi32(vacc2x0213, vacc2x4657);
3940
3941 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
3942 __m256i vacc0x01234567 = _mm256_permutevar8x32_epi32(vacc0x02461357, vpermute_mask);
3943 __m256i vacc1x01234567 = _mm256_permutevar8x32_epi32(vacc1x02461357, vpermute_mask);
3944 __m256i vacc2x01234567 = _mm256_permutevar8x32_epi32(vacc2x02461357, vpermute_mask);
3945
3946 __m256 vscaled0x01234567 = _mm256_cvtepi32_ps(vacc0x01234567);
3947 __m256 vscaled1x01234567 = _mm256_cvtepi32_ps(vacc1x01234567);
3948 __m256 vscaled2x01234567 = _mm256_cvtepi32_ps(vacc2x01234567);
3949
3950 const __m256 vscale01234567 = _mm256_load_ps(w);
3951 w = (const void*) ((const float*) w + 8);
3952 vscaled0x01234567 = _mm256_mul_ps(vscaled0x01234567, vscale01234567);
3953 vscaled1x01234567 = _mm256_mul_ps(vscaled1x01234567, vscale01234567);
3954 vscaled2x01234567 = _mm256_mul_ps(vscaled2x01234567, vscale01234567);
3955
3956 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
3957 vscaled0x01234567 = _mm256_min_ps(vscaled0x01234567, voutput_max_less_zero_point);
3958 vscaled1x01234567 = _mm256_min_ps(vscaled1x01234567, voutput_max_less_zero_point);
3959 vscaled2x01234567 = _mm256_min_ps(vscaled2x01234567, voutput_max_less_zero_point);
3960
3961 vacc0x01234567 = _mm256_cvtps_epi32(vscaled0x01234567);
3962 vacc1x01234567 = _mm256_cvtps_epi32(vscaled1x01234567);
3963 vacc2x01234567 = _mm256_cvtps_epi32(vscaled2x01234567);
3964
3965 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
3966 __m256i vacc01x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc0x01234567, vacc1x01234567), voutput_zero_point);
3967 __m256i vacc22x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc2x01234567, vacc2x01234567), voutput_zero_point);
3968
3969 vacc01x01234567 = _mm256_permute4x64_epi64(vacc01x01234567, _MM_SHUFFLE(3, 1, 2, 0));
3970 vacc22x01234567 = _mm256_permute4x64_epi64(vacc22x01234567, _MM_SHUFFLE(3, 1, 2, 0));
3971
3972 __m256i vout = _mm256_packs_epi16(vacc01x01234567, vacc22x01234567);
3973
3974 vout = _mm256_max_epi8(vout, _mm256_load_si256((const __m256i*) params->fp32_avx2.output_min));
3975
3976 __m128i vout_lo = _mm256_castsi256_si128(vout);
3977 __m128i vout_hi = _mm256_extracti128_si256(vout, 1);
3978
3979 if (nc >= 8) {
3980 _mm_storeh_pi((__m64*) c2, _mm_castsi128_ps(vout_lo));
3981 _mm_storel_epi64((__m128i*) c1, vout_hi);
3982 _mm_storel_epi64((__m128i*) c0, vout_lo);
3983
3984 c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
3985 c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
3986 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
3987
3988 a = (const int8_t**restrict) ((uintptr_t) a - ks);
3989
3990 nc -= 8;
3991 } else {
3992 if (nc & 4) {
3993 unaligned_store_u32(c2, (uint32_t) _mm_extract_epi32(vout_lo, 2));
3994 _mm_storeu_si32(c1, vout_hi);
3995 _mm_storeu_si32(c0, vout_lo);
3996
3997 c2 += 4;
3998 c1 += 4;
3999 c0 += 4;
4000
4001 vout_lo = _mm_srli_epi64(vout_lo, 32);
4002 vout_hi = _mm_srli_epi64(vout_hi, 32);
4003 }
4004 if (nc & 2) {
4005 unaligned_store_u16(c2, (uint16_t) _mm_extract_epi16(vout_lo, 4));
4006 unaligned_store_u16(c1, (uint16_t) _mm_extract_epi16(vout_hi, 0));
4007 unaligned_store_u16(c0, (uint16_t) _mm_extract_epi16(vout_lo, 0));
4008
4009 c2 += 2;
4010 c1 += 2;
4011 c0 += 2;
4012
4013 vout_lo = _mm_srli_epi32(vout_lo, 16);
4014 vout_hi = _mm_srli_epi32(vout_hi, 16);
4015 }
4016 if (nc & 1) {
4017 *c2 = (int8_t) _mm_extract_epi8(vout_lo, 8);
4018 *c1 = (int8_t) _mm_extract_epi8(vout_hi, 0);
4019 *c0 = (int8_t) _mm_extract_epi8(vout_lo, 0);
4020 }
4021
4022 nc = 0;
4023 }
4024 } while (nc != 0);
4025 }
4026
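// QS8 depthwise convolution microkernel: 25 taps, 16 channels per main-loop iteration ("up16x25"),
// using sign-extension to int32 and _mm256_mullo_epi32 ("mul32") for the multiply-accumulate.
// Requantization uses a single per-tensor fp32 scale from params->fp32_avx2.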
4027 void xnn_qs8_dwconv_minmax_fp32_ukernel_up16x25__avx2_mul32(
4028 size_t channels,
4029 size_t output_width,
4030 const int8_t** input,
4031 const void* weights,
4032 int8_t* output,
4033 size_t input_stride,
4034 size_t output_increment,
4035 size_t input_offset,
4036 const int8_t* zero,
4037 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
4038 {
4039 assert(channels != 0);
4040 assert(output_width != 0);
4041
4042 do {
4043 const int8_t* i0 = input[0];
4044 assert(i0 != NULL);
4045 if XNN_UNPREDICTABLE(i0 != zero) {
4046 i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
4047 }
4048 const int8_t* i1 = input[1];
4049 assert(i1 != NULL);
4050 if XNN_UNPREDICTABLE(i1 != zero) {
4051 i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
4052 }
4053 const int8_t* i2 = input[2];
4054 assert(i2 != NULL);
4055 if XNN_UNPREDICTABLE(i2 != zero) {
4056 i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
4057 }
4058 const int8_t* i3 = input[3];
4059 assert(i3 != NULL);
4060 if XNN_UNPREDICTABLE(i3 != zero) {
4061 i3 = (const int8_t*) ((uintptr_t) i3 + input_offset);
4062 }
4063 const int8_t* i4 = input[4];
4064 assert(i4 != NULL);
4065 if XNN_UNPREDICTABLE(i4 != zero) {
4066 i4 = (const int8_t*) ((uintptr_t) i4 + input_offset);
4067 }
4068 const int8_t* i5 = input[5];
4069 assert(i5 != NULL);
4070 if XNN_UNPREDICTABLE(i5 != zero) {
4071 i5 = (const int8_t*) ((uintptr_t) i5 + input_offset);
4072 }
4073 const int8_t* i6 = input[6];
4074 assert(i6 != NULL);
4075 if XNN_UNPREDICTABLE(i6 != zero) {
4076 i6 = (const int8_t*) ((uintptr_t) i6 + input_offset);
4077 }
4078 const int8_t* i7 = input[7];
4079 assert(i7 != NULL);
4080 if XNN_UNPREDICTABLE(i7 != zero) {
4081 i7 = (const int8_t*) ((uintptr_t) i7 + input_offset);
4082 }
4083 const int8_t* i8 = input[8];
4084 assert(i8 != NULL);
4085 if XNN_UNPREDICTABLE(i8 != zero) {
4086 i8 = (const int8_t*) ((uintptr_t) i8 + input_offset);
4087 }
4088 const int8_t* i9 = input[9];
4089 assert(i9 != NULL);
4090 if XNN_UNPREDICTABLE(i9 != zero) {
4091 i9 = (const int8_t*) ((uintptr_t) i9 + input_offset);
4092 }
4093 const int8_t* i10 = input[10];
4094 assert(i10 != NULL);
4095 if XNN_UNPREDICTABLE(i10 != zero) {
4096 i10 = (const int8_t*) ((uintptr_t) i10 + input_offset);
4097 }
4098 const int8_t* i11 = input[11];
4099 assert(i11 != NULL);
4100 if XNN_UNPREDICTABLE(i11 != zero) {
4101 i11 = (const int8_t*) ((uintptr_t) i11 + input_offset);
4102 }
4103 const int8_t* i12 = input[12];
4104 assert(i12 != NULL);
4105 if XNN_UNPREDICTABLE(i12 != zero) {
4106 i12 = (const int8_t*) ((uintptr_t) i12 + input_offset);
4107 }
4108 const int8_t* i13 = input[13];
4109 assert(i13 != NULL);
4110 if XNN_UNPREDICTABLE(i13 != zero) {
4111 i13 = (const int8_t*) ((uintptr_t) i13 + input_offset);
4112 }
4113 const int8_t* i14 = input[14];
4114 assert(i14 != NULL);
4115 if XNN_UNPREDICTABLE(i14 != zero) {
4116 i14 = (const int8_t*) ((uintptr_t) i14 + input_offset);
4117 }
4118 const int8_t* i15 = input[15];
4119 assert(i15 != NULL);
4120 if XNN_UNPREDICTABLE(i15 != zero) {
4121 i15 = (const int8_t*) ((uintptr_t) i15 + input_offset);
4122 }
4123 const int8_t* i16 = input[16];
4124 assert(i16 != NULL);
4125 if XNN_UNPREDICTABLE(i16 != zero) {
4126 i16 = (const int8_t*) ((uintptr_t) i16 + input_offset);
4127 }
4128 const int8_t* i17 = input[17];
4129 assert(i17 != NULL);
4130 if XNN_UNPREDICTABLE(i17 != zero) {
4131 i17 = (const int8_t*) ((uintptr_t) i17 + input_offset);
4132 }
4133 const int8_t* i18 = input[18];
4134 assert(i18 != NULL);
4135 if XNN_UNPREDICTABLE(i18 != zero) {
4136 i18 = (const int8_t*) ((uintptr_t) i18 + input_offset);
4137 }
4138 const int8_t* i19 = input[19];
4139 assert(i19 != NULL);
4140 if XNN_UNPREDICTABLE(i19 != zero) {
4141 i19 = (const int8_t*) ((uintptr_t) i19 + input_offset);
4142 }
4143 const int8_t* i20 = input[20];
4144 assert(i20 != NULL);
4145 if XNN_UNPREDICTABLE(i20 != zero) {
4146 i20 = (const int8_t*) ((uintptr_t) i20 + input_offset);
4147 }
4148 const int8_t* i21 = input[21];
4149 assert(i21 != NULL);
4150 if XNN_UNPREDICTABLE(i21 != zero) {
4151 i21 = (const int8_t*) ((uintptr_t) i21 + input_offset);
4152 }
4153 const int8_t* i22 = input[22];
4154 assert(i22 != NULL);
4155 if XNN_UNPREDICTABLE(i22 != zero) {
4156 i22 = (const int8_t*) ((uintptr_t) i22 + input_offset);
4157 }
4158 const int8_t* i23 = input[23];
4159 assert(i23 != NULL);
4160 if XNN_UNPREDICTABLE(i23 != zero) {
4161 i23 = (const int8_t*) ((uintptr_t) i23 + input_offset);
4162 }
4163 const int8_t* i24 = input[24];
4164 assert(i24 != NULL);
4165 if XNN_UNPREDICTABLE(i24 != zero) {
4166 i24 = (const int8_t*) ((uintptr_t) i24 + input_offset);
4167 }
4168 input = (const int8_t**) ((uintptr_t) input + input_stride);
4169
4170 size_t c = channels;
4171 const void* w = weights;
4172 for (; c >= 16; c -= 16) {
4173 __m256i vacc01234567 = _mm256_loadu_si256((const __m256i*) w);
4174 __m256i vacc89ABCDEF = _mm256_loadu_si256((const __m256i*) ((const int32_t*) w + 8));
4175
4176
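      // Accumulate all 25 taps: each tap loads 16 int8 inputs and 16 int8 weights, sign-extends to
      // int32, and multiply-accumulates with _mm256_mullo_epi32. The packed weights start with
      // 16 int32 biases followed by the taps, 16 int8 per tap.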
4177 const __m256i vi0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i0));
4178 const __m256i vk0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 0 * sizeof(int8_t))));
4179 const __m256i vi0x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i0 + 8)));
4180 const __m256i vk0x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 8 * sizeof(int8_t))));
4181 i0 += 16;
4182
4183 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi0x01234567, vk0x01234567));
4184 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi0x89ABCDEF, vk0x89ABCDEF));
4185
4186 const __m256i vi1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i1));
4187 const __m256i vk1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 16 * sizeof(int8_t))));
4188 const __m256i vi1x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i1 + 8)));
4189 const __m256i vk1x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 24 * sizeof(int8_t))));
4190 i1 += 16;
4191
4192 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi1x01234567, vk1x01234567));
4193 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi1x89ABCDEF, vk1x89ABCDEF));
4194
4195 const __m256i vi2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i2));
4196 const __m256i vk2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 32 * sizeof(int8_t))));
4197 const __m256i vi2x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i2 + 8)));
4198 const __m256i vk2x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 40 * sizeof(int8_t))));
4199 i2 += 16;
4200
4201 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi2x01234567, vk2x01234567));
4202 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi2x89ABCDEF, vk2x89ABCDEF));
4203
4204 const __m256i vi3x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i3));
4205 const __m256i vk3x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 48 * sizeof(int8_t))));
4206 const __m256i vi3x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i3 + 8)));
4207 const __m256i vk3x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 56 * sizeof(int8_t))));
4208 i3 += 16;
4209
4210 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi3x01234567, vk3x01234567));
4211 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi3x89ABCDEF, vk3x89ABCDEF));
4212
4213 const __m256i vi4x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i4));
4214 const __m256i vk4x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 64 * sizeof(int8_t))));
4215 const __m256i vi4x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i4 + 8)));
4216 const __m256i vk4x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 72 * sizeof(int8_t))));
4217 i4 += 16;
4218
4219 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi4x01234567, vk4x01234567));
4220 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi4x89ABCDEF, vk4x89ABCDEF));
4221
4222 const __m256i vi5x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i5));
4223 const __m256i vk5x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 80 * sizeof(int8_t))));
4224 const __m256i vi5x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i5 + 8)));
4225 const __m256i vk5x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 88 * sizeof(int8_t))));
4226 i5 += 16;
4227
4228 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi5x01234567, vk5x01234567));
4229 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi5x89ABCDEF, vk5x89ABCDEF));
4230
4231 const __m256i vi6x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i6));
4232 const __m256i vk6x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 96 * sizeof(int8_t))));
4233 const __m256i vi6x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i6 + 8)));
4234 const __m256i vk6x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 104 * sizeof(int8_t))));
4235 i6 += 16;
4236
4237 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi6x01234567, vk6x01234567));
4238 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi6x89ABCDEF, vk6x89ABCDEF));
4239
4240 const __m256i vi7x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i7));
4241 const __m256i vk7x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 112 * sizeof(int8_t))));
4242 const __m256i vi7x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i7 + 8)));
4243 const __m256i vk7x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 120 * sizeof(int8_t))));
4244 i7 += 16;
4245
4246 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi7x01234567, vk7x01234567));
4247 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi7x89ABCDEF, vk7x89ABCDEF));
4248
4249 const __m256i vi8x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i8));
4250 const __m256i vk8x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 128 * sizeof(int8_t))));
4251 const __m256i vi8x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i8 + 8)));
4252 const __m256i vk8x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 136 * sizeof(int8_t))));
4253 i8 += 16;
4254
4255 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi8x01234567, vk8x01234567));
4256 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi8x89ABCDEF, vk8x89ABCDEF));
4257
4258 const __m256i vi9x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i9));
4259 const __m256i vk9x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 144 * sizeof(int8_t))));
4260 const __m256i vi9x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i9 + 8)));
4261 const __m256i vk9x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 152 * sizeof(int8_t))));
4262 i9 += 16;
4263
4264 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi9x01234567, vk9x01234567));
4265 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi9x89ABCDEF, vk9x89ABCDEF));
4266
4267 const __m256i vi10x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i10));
4268 const __m256i vk10x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 160 * sizeof(int8_t))));
4269 const __m256i vi10x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i10 + 8)));
4270 const __m256i vk10x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 168 * sizeof(int8_t))));
4271 i10 += 16;
4272
4273 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi10x01234567, vk10x01234567));
4274 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi10x89ABCDEF, vk10x89ABCDEF));
4275
4276 const __m256i vi11x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i11));
4277 const __m256i vk11x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 176 * sizeof(int8_t))));
4278 const __m256i vi11x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i11 + 8)));
4279 const __m256i vk11x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 184 * sizeof(int8_t))));
4280 i11 += 16;
4281
4282 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi11x01234567, vk11x01234567));
4283 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi11x89ABCDEF, vk11x89ABCDEF));
4284
4285 const __m256i vi12x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i12));
4286 const __m256i vk12x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 192 * sizeof(int8_t))));
4287 const __m256i vi12x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i12 + 8)));
4288 const __m256i vk12x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 200 * sizeof(int8_t))));
4289 i12 += 16;
4290
4291 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi12x01234567, vk12x01234567));
4292 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi12x89ABCDEF, vk12x89ABCDEF));
4293
4294 const __m256i vi13x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i13));
4295 const __m256i vk13x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 208 * sizeof(int8_t))));
4296 const __m256i vi13x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i13 + 8)));
4297 const __m256i vk13x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 216 * sizeof(int8_t))));
4298 i13 += 16;
4299
4300 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi13x01234567, vk13x01234567));
4301 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi13x89ABCDEF, vk13x89ABCDEF));
4302
4303 const __m256i vi14x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i14));
4304 const __m256i vk14x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 224 * sizeof(int8_t))));
4305 const __m256i vi14x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i14 + 8)));
4306 const __m256i vk14x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 232 * sizeof(int8_t))));
4307 i14 += 16;
4308
4309 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi14x01234567, vk14x01234567));
4310 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi14x89ABCDEF, vk14x89ABCDEF));
4311
4312 const __m256i vi15x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i15));
4313 const __m256i vk15x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 240 * sizeof(int8_t))));
4314 const __m256i vi15x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i15 + 8)));
4315 const __m256i vk15x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 248 * sizeof(int8_t))));
4316 i15 += 16;
4317
4318 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi15x01234567, vk15x01234567));
4319 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi15x89ABCDEF, vk15x89ABCDEF));
4320
4321 const __m256i vi16x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i16));
4322 const __m256i vk16x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 256 * sizeof(int8_t))));
4323 const __m256i vi16x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i16 + 8)));
4324 const __m256i vk16x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 264 * sizeof(int8_t))));
4325 i16 += 16;
4326
4327 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi16x01234567, vk16x01234567));
4328 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi16x89ABCDEF, vk16x89ABCDEF));
4329
4330 const __m256i vi17x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i17));
4331 const __m256i vk17x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 272 * sizeof(int8_t))));
4332 const __m256i vi17x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i17 + 8)));
4333 const __m256i vk17x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 280 * sizeof(int8_t))));
4334 i17 += 16;
4335
4336 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi17x01234567, vk17x01234567));
4337 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi17x89ABCDEF, vk17x89ABCDEF));
4338
4339 const __m256i vi18x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i18));
4340 const __m256i vk18x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 288 * sizeof(int8_t))));
4341 const __m256i vi18x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i18 + 8)));
4342 const __m256i vk18x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 296 * sizeof(int8_t))));
4343 i18 += 16;
4344
4345 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi18x01234567, vk18x01234567));
4346 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi18x89ABCDEF, vk18x89ABCDEF));
4347
4348 const __m256i vi19x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i19));
4349 const __m256i vk19x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 304 * sizeof(int8_t))));
4350 const __m256i vi19x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i19 + 8)));
4351 const __m256i vk19x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 312 * sizeof(int8_t))));
4352 i19 += 16;
4353
4354 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi19x01234567, vk19x01234567));
4355 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi19x89ABCDEF, vk19x89ABCDEF));
4356
4357 const __m256i vi20x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i20));
4358 const __m256i vk20x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 320 * sizeof(int8_t))));
4359 const __m256i vi20x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i20 + 8)));
4360 const __m256i vk20x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 328 * sizeof(int8_t))));
4361 i20 += 16;
4362
4363 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi20x01234567, vk20x01234567));
4364 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi20x89ABCDEF, vk20x89ABCDEF));
4365
4366 const __m256i vi21x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i21));
4367 const __m256i vk21x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 336 * sizeof(int8_t))));
4368 const __m256i vi21x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i21 + 8)));
4369 const __m256i vk21x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 344 * sizeof(int8_t))));
4370 i21 += 16;
4371
4372 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi21x01234567, vk21x01234567));
4373 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi21x89ABCDEF, vk21x89ABCDEF));
4374
4375 const __m256i vi22x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i22));
4376 const __m256i vk22x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 352 * sizeof(int8_t))));
4377 const __m256i vi22x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i22 + 8)));
4378 const __m256i vk22x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 360 * sizeof(int8_t))));
4379 i22 += 16;
4380
4381 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi22x01234567, vk22x01234567));
4382 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi22x89ABCDEF, vk22x89ABCDEF));
4383
4384 const __m256i vi23x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i23));
4385 const __m256i vk23x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 368 * sizeof(int8_t))));
4386 const __m256i vi23x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i23 + 8)));
4387 const __m256i vk23x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 376 * sizeof(int8_t))));
4388 i23 += 16;
4389
4390 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi23x01234567, vk23x01234567));
4391 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi23x89ABCDEF, vk23x89ABCDEF));
4392
4393 const __m256i vi24x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i24));
4394 const __m256i vk24x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 384 * sizeof(int8_t))));
4395 const __m256i vi24x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i24 + 8)));
4396 const __m256i vk24x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 392 * sizeof(int8_t))));
4397 i24 += 16;
4398
4399 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi24x01234567, vk24x01234567));
4400 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi24x89ABCDEF, vk24x89ABCDEF));
4401
4402 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t) + 400 * sizeof(int8_t));
4403
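      // fp32 requantization: scale by the per-tensor scale from params, clamp at
      // output_max_less_zero_point while still in float, then round back to int32.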
4404 __m256 vscaled01234567 = _mm256_cvtepi32_ps(vacc01234567);
4405 __m256 vscaled89ABCDEF = _mm256_cvtepi32_ps(vacc89ABCDEF);
4406
4407 const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
4408 vscaled01234567 = _mm256_mul_ps(vscaled01234567, vscale);
4409 vscaled89ABCDEF = _mm256_mul_ps(vscaled89ABCDEF, vscale);
4410
4411 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
4412 vscaled01234567 = _mm256_min_ps(vscaled01234567, voutput_max_less_zero_point);
4413 vscaled89ABCDEF = _mm256_min_ps(vscaled89ABCDEF, voutput_max_less_zero_point);
4414
4415 vacc01234567 = _mm256_cvtps_epi32(vscaled01234567);
4416 vacc89ABCDEF = _mm256_cvtps_epi32(vscaled89ABCDEF);
4417
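      // _mm256_packs_epi32 works per 128-bit lane, so the int16 result is ordered 0123 89AB 4567 CDEF;
      // the _mm_shuffle_epi32 with _MM_SHUFFLE(3, 1, 2, 0) below restores 0..F order after packing to int8.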
4418 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
4419 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(vacc01234567, vacc89ABCDEF), voutput_zero_point);
4420
4421 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
4422
4423 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx2.output_min);
4424 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, voutput_min);
4425
4426 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
4427 output += 16;
4428 }
4429 if XNN_UNLIKELY(c != 0) {
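      // Remainder path: handle the last 1-15 channels 8 at a time. Per-tap weights are still laid out
      // 16 int8 apart, so k advances by 8 within a group while taps are read at offsets k, k+16, k+32, ...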
4430 const int8_t* k = (const int8_t*) ((const int32_t*) w + 16);
4431 do {
4432 __m256i vacc01234567 = _mm256_loadu_si256((const __m256i*) w);
4433
4434
4435 const __m256i vi0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i0));
4436 const __m256i vk0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) k));
4437 i0 += 8;
4438
4439 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi0x01234567, vk0x01234567));
4440
4441 const __m256i vi1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i1));
4442 const __m256i vk1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 16)));
4443 i1 += 8;
4444
4445 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi1x01234567, vk1x01234567));
4446
4447 const __m256i vi2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i2));
4448 const __m256i vk2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 32)));
4449 i2 += 8;
4450
4451 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi2x01234567, vk2x01234567));
4452
4453 const __m256i vi3x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i3));
4454 const __m256i vk3x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 48)));
4455 i3 += 8;
4456
4457 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi3x01234567, vk3x01234567));
4458
4459 const __m256i vi4x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i4));
4460 const __m256i vk4x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 64)));
4461 i4 += 8;
4462
4463 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi4x01234567, vk4x01234567));
4464
4465 const __m256i vi5x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i5));
4466 const __m256i vk5x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 80)));
4467 i5 += 8;
4468
4469 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi5x01234567, vk5x01234567));
4470
4471 const __m256i vi6x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i6));
4472 const __m256i vk6x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 96)));
4473 i6 += 8;
4474
4475 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi6x01234567, vk6x01234567));
4476
4477 const __m256i vi7x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i7));
4478 const __m256i vk7x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 112)));
4479 i7 += 8;
4480
4481 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi7x01234567, vk7x01234567));
4482
4483 const __m256i vi8x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i8));
4484 const __m256i vk8x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 128)));
4485 i8 += 8;
4486
4487 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi8x01234567, vk8x01234567));
4488
4489 const __m256i vi9x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i9));
4490 const __m256i vk9x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 144)));
4491 i9 += 8;
4492
4493 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi9x01234567, vk9x01234567));
4494
4495 const __m256i vi10x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i10));
4496 const __m256i vk10x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 160)));
4497 i10 += 8;
4498
4499 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi10x01234567, vk10x01234567));
4500
4501 const __m256i vi11x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i11));
4502 const __m256i vk11x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 176)));
4503 i11 += 8;
4504
4505 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi11x01234567, vk11x01234567));
4506
4507 const __m256i vi12x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i12));
4508 const __m256i vk12x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 192)));
4509 i12 += 8;
4510
4511 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi12x01234567, vk12x01234567));
4512
4513 const __m256i vi13x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i13));
4514 const __m256i vk13x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 208)));
4515 i13 += 8;
4516
4517 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi13x01234567, vk13x01234567));
4518
4519 const __m256i vi14x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i14));
4520 const __m256i vk14x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 224)));
4521 i14 += 8;
4522
4523 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi14x01234567, vk14x01234567));
4524
4525 const __m256i vi15x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i15));
4526 const __m256i vk15x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 240)));
4527 i15 += 8;
4528
4529 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi15x01234567, vk15x01234567));
4530
4531 const __m256i vi16x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i16));
4532 const __m256i vk16x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 256)));
4533 i16 += 8;
4534
4535 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi16x01234567, vk16x01234567));
4536
4537 const __m256i vi17x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i17));
4538 const __m256i vk17x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 272)));
4539 i17 += 8;
4540
4541 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi17x01234567, vk17x01234567));
4542
4543 const __m256i vi18x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i18));
4544 const __m256i vk18x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 288)));
4545 i18 += 8;
4546
4547 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi18x01234567, vk18x01234567));
4548
4549 const __m256i vi19x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i19));
4550 const __m256i vk19x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 304)));
4551 i19 += 8;
4552
4553 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi19x01234567, vk19x01234567));
4554
4555 const __m256i vi20x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i20));
4556 const __m256i vk20x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 320)));
4557 i20 += 8;
4558
4559 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi20x01234567, vk20x01234567));
4560
4561 const __m256i vi21x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i21));
4562 const __m256i vk21x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 336)));
4563 i21 += 8;
4564
4565 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi21x01234567, vk21x01234567));
4566
4567 const __m256i vi22x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i22));
4568 const __m256i vk22x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 352)));
4569 i22 += 8;
4570
4571 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi22x01234567, vk22x01234567));
4572
4573 const __m256i vi23x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i23));
4574 const __m256i vk23x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 368)));
4575 i23 += 8;
4576
4577 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi23x01234567, vk23x01234567));
4578
4579 const __m256i vi24x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i24));
4580 const __m256i vk24x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 384)));
4581 i24 += 8;
4582
4583 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi24x01234567, vk24x01234567));
4584
4585 k += 8;
4586
4587 __m256 vscaled01234567 = _mm256_cvtepi32_ps(vacc01234567);
4588 vscaled01234567 = _mm256_mul_ps(vscaled01234567, _mm256_load_ps(params->fp32_avx2.scale));
4589 vscaled01234567 = _mm256_min_ps(vscaled01234567, _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point));
4590 vacc01234567 = _mm256_cvtps_epi32(vscaled01234567);
4591
4592 w = (const void*) ((const int32_t*) w + 8);
4593
4594 const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->fp32_avx2.output_zero_point);
4595 __m128i vout01234567 = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc01234567), _mm256_extracti128_si256(vacc01234567, 1)), voutput_zero_point);
4596
4597 __m128i vout0123456701234567 = _mm_packs_epi16(vout01234567, vout01234567);
4598
4599 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx2.output_min);
4600 vout0123456701234567 = _mm_max_epi8(vout0123456701234567, voutput_min);
4601
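        // Store 8 outputs when a full group remains; otherwise write the 4/2/1-byte tail,
        // shifting the packed result down after each partial store.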
4602 if XNN_LIKELY(c >= 8) {
4603 _mm_storel_epi64((__m128i*) output, vout0123456701234567);
4604 output += 8;
4605 c -= 8;
4606 } else {
4607 if (c & 4) {
4608 unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vout0123456701234567));
4609 vout0123456701234567 = _mm_srli_epi64(vout0123456701234567, 32);
4610 output += 4;
4611 }
4612 if (c & 2) {
4613 unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vout0123456701234567, 0));
4614 vout0123456701234567 = _mm_srli_epi32(vout0123456701234567, 16);
4615 output += 2;
4616 }
4617 if (c & 1) {
4618 *output = (int8_t) _mm_extract_epi8(vout0123456701234567, 0);
4619 output += 1;
4620 }
4621 c = 0;
4622 }
4623 } while (c != 0);
4624 }
4625
4626 output = (int8_t*) ((uintptr_t) output + output_increment);
4627 } while (--output_width != 0);
4628 }
4629
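// QS8 depthwise convolution microkernel: 9 taps, 16 channels per main-loop iteration, otherwise the
// same MUL32 accumulation and fp32 requantization scheme as the 25-tap kernel above.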
4630 void xnn_qs8_dwconv_minmax_fp32_ukernel_up16x9__avx2_mul32(
4631 size_t channels,
4632 size_t output_width,
4633 const int8_t** input,
4634 const void* weights,
4635 int8_t* output,
4636 size_t input_stride,
4637 size_t output_increment,
4638 size_t input_offset,
4639 const int8_t* zero,
4640 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
4641 {
4642 assert(channels != 0);
4643 assert(output_width != 0);
4644
4645 do {
4646 const int8_t* i0 = input[0];
4647 assert(i0 != NULL);
4648 if XNN_UNPREDICTABLE(i0 != zero) {
4649 i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
4650 }
4651 const int8_t* i1 = input[1];
4652 assert(i1 != NULL);
4653 if XNN_UNPREDICTABLE(i1 != zero) {
4654 i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
4655 }
4656 const int8_t* i2 = input[2];
4657 assert(i2 != NULL);
4658 if XNN_UNPREDICTABLE(i2 != zero) {
4659 i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
4660 }
4661 const int8_t* i3 = input[3];
4662 assert(i3 != NULL);
4663 if XNN_UNPREDICTABLE(i3 != zero) {
4664 i3 = (const int8_t*) ((uintptr_t) i3 + input_offset);
4665 }
4666 const int8_t* i4 = input[4];
4667 assert(i4 != NULL);
4668 if XNN_UNPREDICTABLE(i4 != zero) {
4669 i4 = (const int8_t*) ((uintptr_t) i4 + input_offset);
4670 }
4671 const int8_t* i5 = input[5];
4672 assert(i5 != NULL);
4673 if XNN_UNPREDICTABLE(i5 != zero) {
4674 i5 = (const int8_t*) ((uintptr_t) i5 + input_offset);
4675 }
4676 const int8_t* i6 = input[6];
4677 assert(i6 != NULL);
4678 if XNN_UNPREDICTABLE(i6 != zero) {
4679 i6 = (const int8_t*) ((uintptr_t) i6 + input_offset);
4680 }
4681 const int8_t* i7 = input[7];
4682 assert(i7 != NULL);
4683 if XNN_UNPREDICTABLE(i7 != zero) {
4684 i7 = (const int8_t*) ((uintptr_t) i7 + input_offset);
4685 }
4686 const int8_t* i8 = input[8];
4687 assert(i8 != NULL);
4688 if XNN_UNPREDICTABLE(i8 != zero) {
4689 i8 = (const int8_t*) ((uintptr_t) i8 + input_offset);
4690 }
4691 input = (const int8_t**) ((uintptr_t) input + input_stride);
4692
4693 size_t c = channels;
4694 const void* w = weights;
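    // Main loop over 16-channel groups: same MUL32 accumulation pattern as the 25-tap kernel above,
    // with 9 taps (packed weights: 16 int32 biases followed by 9 x 16 int8 taps).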
4695 for (; c >= 16; c -= 16) {
4696 __m256i vacc01234567 = _mm256_loadu_si256((const __m256i*) w);
4697 __m256i vacc89ABCDEF = _mm256_loadu_si256((const __m256i*) ((const int32_t*) w + 8));
4698
4699
4700 const __m256i vi0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i0));
4701 const __m256i vk0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 0 * sizeof(int8_t))));
4702 const __m256i vi0x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i0 + 8)));
4703 const __m256i vk0x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 8 * sizeof(int8_t))));
4704 i0 += 16;
4705
4706 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi0x01234567, vk0x01234567));
4707 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi0x89ABCDEF, vk0x89ABCDEF));
4708
4709 const __m256i vi1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i1));
4710 const __m256i vk1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 16 * sizeof(int8_t))));
4711 const __m256i vi1x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i1 + 8)));
4712 const __m256i vk1x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 24 * sizeof(int8_t))));
4713 i1 += 16;
4714
4715 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi1x01234567, vk1x01234567));
4716 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi1x89ABCDEF, vk1x89ABCDEF));
4717
4718 const __m256i vi2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i2));
4719 const __m256i vk2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 32 * sizeof(int8_t))));
4720 const __m256i vi2x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i2 + 8)));
4721 const __m256i vk2x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 40 * sizeof(int8_t))));
4722 i2 += 16;
4723
4724 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi2x01234567, vk2x01234567));
4725 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi2x89ABCDEF, vk2x89ABCDEF));
4726
4727 const __m256i vi3x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i3));
4728 const __m256i vk3x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 48 * sizeof(int8_t))));
4729 const __m256i vi3x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i3 + 8)));
4730 const __m256i vk3x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 56 * sizeof(int8_t))));
4731 i3 += 16;
4732
4733 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi3x01234567, vk3x01234567));
4734 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi3x89ABCDEF, vk3x89ABCDEF));
4735
4736 const __m256i vi4x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i4));
4737 const __m256i vk4x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 64 * sizeof(int8_t))));
4738 const __m256i vi4x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i4 + 8)));
4739 const __m256i vk4x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 72 * sizeof(int8_t))));
4740 i4 += 16;
4741
4742 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi4x01234567, vk4x01234567));
4743 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi4x89ABCDEF, vk4x89ABCDEF));
4744
4745 const __m256i vi5x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i5));
4746 const __m256i vk5x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 80 * sizeof(int8_t))));
4747 const __m256i vi5x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i5 + 8)));
4748 const __m256i vk5x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 88 * sizeof(int8_t))));
4749 i5 += 16;
4750
4751 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi5x01234567, vk5x01234567));
4752 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi5x89ABCDEF, vk5x89ABCDEF));
4753
4754 const __m256i vi6x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i6));
4755 const __m256i vk6x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 96 * sizeof(int8_t))));
4756 const __m256i vi6x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i6 + 8)));
4757 const __m256i vk6x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 104 * sizeof(int8_t))));
4758 i6 += 16;
4759
4760 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi6x01234567, vk6x01234567));
4761 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi6x89ABCDEF, vk6x89ABCDEF));
4762
4763 const __m256i vi7x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i7));
4764 const __m256i vk7x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 112 * sizeof(int8_t))));
4765 const __m256i vi7x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i7 + 8)));
4766 const __m256i vk7x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 120 * sizeof(int8_t))));
4767 i7 += 16;
4768
4769 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi7x01234567, vk7x01234567));
4770 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi7x89ABCDEF, vk7x89ABCDEF));
4771
4772 const __m256i vi8x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i8));
4773 const __m256i vk8x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 128 * sizeof(int8_t))));
4774 const __m256i vi8x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (i8 + 8)));
4775 const __m256i vk8x89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 136 * sizeof(int8_t))));
4776 i8 += 16;
4777
4778 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi8x01234567, vk8x01234567));
4779 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi8x89ABCDEF, vk8x89ABCDEF));
4780
4781 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t) + 144 * sizeof(int8_t));
4782
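    // Requantize (fp32 path): convert the int32 accumulators to float, multiply by the
    // per-tensor scale, clamp above at (output_max - output_zero_point), round to the
    // nearest int32, then add the zero point with saturation, pack to int8 and clamp
    // below at output_min. Net effect per lane, roughly:
    //   out = clamp(lrintf(acc * scale) + output_zero_point, output_min, output_max)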
4783 __m256 vscaled01234567 = _mm256_cvtepi32_ps(vacc01234567);
4784 __m256 vscaled89ABCDEF = _mm256_cvtepi32_ps(vacc89ABCDEF);
4785
4786 const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
4787 vscaled01234567 = _mm256_mul_ps(vscaled01234567, vscale);
4788 vscaled89ABCDEF = _mm256_mul_ps(vscaled89ABCDEF, vscale);
4789
4790 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
4791 vscaled01234567 = _mm256_min_ps(vscaled01234567, voutput_max_less_zero_point);
4792 vscaled89ABCDEF = _mm256_min_ps(vscaled89ABCDEF, voutput_max_less_zero_point);
4793
4794 vacc01234567 = _mm256_cvtps_epi32(vscaled01234567);
4795 vacc89ABCDEF = _mm256_cvtps_epi32(vscaled89ABCDEF);
4796
4797 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
4798 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(vacc01234567, vacc89ABCDEF), voutput_zero_point);
4799
4800 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
4801
4802 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx2.output_min);
4803 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, voutput_min);
4804
4805 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
4806 output += 16;
4807 }
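    // Remainder: fewer than 16 channels are left. They are processed 8 at a time against
    // the same packed weight block (16 int32 biases followed by 16 int8 weights per tap,
    // per the offsets used above), with partial stores for a final group of 1-7 outputs.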
4808 if XNN_UNLIKELY(c != 0) {
4809 const int8_t* k = (const int8_t*) ((const int32_t*) w + 16);
4810 do {
4811 __m256i vacc01234567 = _mm256_loadu_si256((const __m256i*) w);
4812
4813
4814 const __m256i vi0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i0));
4815 const __m256i vk0x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) k));
4816 i0 += 8;
4817
4818 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi0x01234567, vk0x01234567));
4819
4820 const __m256i vi1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i1));
4821 const __m256i vk1x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 16)));
4822 i1 += 8;
4823
4824 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi1x01234567, vk1x01234567));
4825
4826 const __m256i vi2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i2));
4827 const __m256i vk2x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 32)));
4828 i2 += 8;
4829
4830 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi2x01234567, vk2x01234567));
4831
4832 const __m256i vi3x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i3));
4833 const __m256i vk3x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 48)));
4834 i3 += 8;
4835
4836 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi3x01234567, vk3x01234567));
4837
4838 const __m256i vi4x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i4));
4839 const __m256i vk4x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 64)));
4840 i4 += 8;
4841
4842 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi4x01234567, vk4x01234567));
4843
4844 const __m256i vi5x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i5));
4845 const __m256i vk5x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 80)));
4846 i5 += 8;
4847
4848 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi5x01234567, vk5x01234567));
4849
4850 const __m256i vi6x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i6));
4851 const __m256i vk6x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 96)));
4852 i6 += 8;
4853
4854 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi6x01234567, vk6x01234567));
4855
4856 const __m256i vi7x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i7));
4857 const __m256i vk7x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 112)));
4858 i7 += 8;
4859
4860 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi7x01234567, vk7x01234567));
4861
4862 const __m256i vi8x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) i8));
4863 const __m256i vk8x01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (k + 128)));
4864 i8 += 8;
4865
4866 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi8x01234567, vk8x01234567));
4867
4868 k += 8;
4869
4870 __m256 vscaled01234567 = _mm256_cvtepi32_ps(vacc01234567);
4871 vscaled01234567 = _mm256_mul_ps(vscaled01234567, _mm256_load_ps(params->fp32_avx2.scale));
4872 vscaled01234567 = _mm256_min_ps(vscaled01234567, _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point));
4873 vacc01234567 = _mm256_cvtps_epi32(vscaled01234567);
4874
4875 w = (const void*) ((const int32_t*) w + 8);
4876
4877 const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->fp32_avx2.output_zero_point);
4878 __m128i vout01234567 = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc01234567), _mm256_extracti128_si256(vacc01234567, 1)), voutput_zero_point);
4879
4880 __m128i vout0123456701234567 = _mm_packs_epi16(vout01234567, vout01234567);
4881
4882 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx2.output_min);
4883 vout0123456701234567 = _mm_max_epi8(vout0123456701234567, voutput_min);
4884
4885 if XNN_LIKELY(c >= 8) {
4886 _mm_storel_epi64((__m128i*) output, vout0123456701234567);
4887 output += 8;
4888 c -= 8;
4889 } else {
4890 if (c & 4) {
4891 unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vout0123456701234567));
4892 vout0123456701234567 = _mm_srli_epi64(vout0123456701234567, 32);
4893 output += 4;
4894 }
4895 if (c & 2) {
4896 unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vout0123456701234567, 0));
4897 vout0123456701234567 = _mm_srli_epi32(vout0123456701234567, 16);
4898 output += 2;
4899 }
4900 if (c & 1) {
4901 *output = (int8_t) _mm_extract_epi8(vout0123456701234567, 0);
4902 output += 1;
4903 }
4904 c = 0;
4905 }
4906 } while (c != 0);
4907 }
4908
4909 output = (int8_t*) ((uintptr_t) output + output_increment);
4910 } while (--output_width != 0);
4911 }
4912
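// Converts QS8 (signed 8-bit quantized) values to float, 16 elements per iteration:
// widen to int32, add the precomputed negated zero point, convert to float and scale.
// Scalar sketch of what each lane computes:
//   y[i] = (float) ((int32_t) x[i] - zero_point) * scale;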
4913 void xnn_qs8_f32_vcvt_ukernel__avx2_x16(
4914 size_t n,
4915 const int8_t* x,
4916 float* y,
4917 const union xnn_qs8_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
4918 {
4919 assert(n != 0);
4920 assert(n % sizeof(int8_t) == 0);
4921 assert(x != NULL);
4922 assert(y != NULL);
4923
4924 const __m256i vminus_zero_point = _mm256_load_si256((const __m256i*) params->avx.minus_zero_point);
4925 const __m256 vscale = _mm256_load_ps(params->avx.scale);
4926 for (; n >= 16 * sizeof(int8_t); n -= 16 * sizeof(int8_t)) {
4927 __m256i vx01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) x));
4928 __m256i vx89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (x + 8)));
4929 x += 16;
4930
4931 vx01234567 = _mm256_add_epi32(vx01234567, vminus_zero_point);
4932 vx89ABCDEF = _mm256_add_epi32(vx89ABCDEF, vminus_zero_point);
4933
4934 __m256 vy01234567 = _mm256_cvtepi32_ps(vx01234567);
4935 __m256 vy89ABCDEF = _mm256_cvtepi32_ps(vx89ABCDEF);
4936
4937 vy01234567 = _mm256_mul_ps(vy01234567, vscale);
4938 vy89ABCDEF = _mm256_mul_ps(vy89ABCDEF, vscale);
4939
4940 _mm256_storeu_ps(y, vy01234567);
4941 _mm256_storeu_ps(y + 8, vy89ABCDEF);
4942 y += 16;
4943 }
4944 for (; n >= 8 * sizeof(int8_t); n -= 8 * sizeof(int8_t)) {
4945 __m256i vx = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) x));
4946 vx = _mm256_add_epi32(vx, vminus_zero_point);
4947 x += 8;
4948
4949 __m256 vy = _mm256_cvtepi32_ps(vx);
4950 vy = _mm256_mul_ps(vy, vscale);
4951
4952 _mm256_storeu_ps(y, vy);
4953 y += 8;
4954 }
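  // Tail of 1-7 elements: the 8-byte load may read past the end of x (permitted by
  // XNN_OOB_READS); only the valid lanes are written back via 4-, 2- and 1-element stores.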
4955 if XNN_UNLIKELY(n != 0) {
4956 assert(n >= 1 * sizeof(int8_t));
4957 assert(n <= 7 * sizeof(int8_t));
4958
4959 __m256i vx = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) x));
4960 vx = _mm256_add_epi32(vx, vminus_zero_point);
4961
4962 __m256 vy = _mm256_cvtepi32_ps(vx);
4963 vy = _mm256_mul_ps(vy, vscale);
4964
4965 __m128 vy_lo = _mm256_castps256_ps128(vy);
4966 if (n & (4 * sizeof(int8_t))) {
4967 _mm_storeu_ps(y, vy_lo);
4968 vy_lo = _mm256_extractf128_ps(vy, 1);
4969 y += 4;
4970 }
4971 if (n & (2 * sizeof(int8_t))) {
4972 _mm_storel_pi((__m64*) y, vy_lo);
4973 vy_lo = _mm_movehl_ps(vy_lo, vy_lo);
4974 y += 2;
4975 }
4976 if (n & (1 * sizeof(int8_t))) {
4977 _mm_store_ss(y, vy_lo);
4978 }
4979 }
4980 }
4981
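// QS8 GEMM, 1 row x 8 columns, with 8-element (C8) K unrolling: 8 signed 8-bit A values
// are broadcast per step, sign-extended to 16 bits, and multiplied against pairs of
// packed B columns with _mm256_madd_epi16, so each 256-bit accumulator holds partial
// sums for two output columns (one per 128-bit lane). The per-column totals are formed
// after the K loop, then requantized with the fp32 path.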
4982 void xnn_qs8_gemm_minmax_fp32_ukernel_1x8c8__avx2(
4983 size_t mr,
4984 size_t nc,
4985 size_t kc,
4986 const int8_t* restrict a,
4987 size_t a_stride,
4988 const void* restrict w,
4989 int8_t* restrict c,
4990 size_t cm_stride,
4991 size_t cn_stride,
4992 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
4993 {
4994 assert(mr != 0);
4995 assert(mr <= 1);
4996 assert(nc != 0);
4997 assert(kc != 0);
4998 assert(kc % sizeof(int8_t) == 0);
4999 assert(a != NULL);
5000 assert(w != NULL);
5001 assert(c != NULL);
5002
5003 kc = round_up_po2(kc, 8);
5004 const int8_t* a0 = a;
5005 int8_t* c0 = c;
5006
5007 do {
5008 const __m128i vbias0x0 = _mm_cvtsi32_si128(((const int*) w)[0]);
5009 const __m128i vbias0x1 = _mm_cvtsi32_si128(((const int*) w)[1]);
5010 __m256i vacc0x01 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x0), vbias0x1, 1);
5011 const __m128i vbias0x2 = _mm_cvtsi32_si128(((const int*) w)[2]);
5012 const __m128i vbias0x3 = _mm_cvtsi32_si128(((const int*) w)[3]);
5013 __m256i vacc0x23 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x2), vbias0x3, 1);
5014 const __m128i vbias0x4 = _mm_cvtsi32_si128(((const int*) w)[4]);
5015 const __m128i vbias0x5 = _mm_cvtsi32_si128(((const int*) w)[5]);
5016 __m256i vacc0x45 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x4), vbias0x5, 1);
5017 const __m128i vbias0x6 = _mm_cvtsi32_si128(((const int*) w)[6]);
5018 const __m128i vbias0x7 = _mm_cvtsi32_si128(((const int*) w)[7]);
5019 __m256i vacc0x67 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x6), vbias0x7, 1);
5020 w = (const int32_t*) w + 8;
5021
5022 size_t k = 0;
5023 while (k < kc) {
5024 const __m128i va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
5025 const __m256i vxa0 = _mm256_cvtepi8_epi16(va0);
5026 a0 += 8;
5027
5028 const __m128i vb01 = _mm_load_si128((const __m128i*) w);
5029 const __m256i vxb01 = _mm256_cvtepi8_epi16(vb01);
5030
5031 vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
5032 const __m128i vb23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16));
5033 const __m256i vxb23 = _mm256_cvtepi8_epi16(vb23);
5034
5035 vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
5036 const __m128i vb45 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 32));
5037 const __m256i vxb45 = _mm256_cvtepi8_epi16(vb45);
5038
5039 vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
5040 const __m128i vb67 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 48));
5041 const __m256i vxb67 = _mm256_cvtepi8_epi16(vb67);
5042
5043 vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));
5044
5045 w = (const void*) ((const int8_t*) w + 64);
5046 k += 8 * sizeof(int8_t);
5047 }
5048
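    // Horizontal reduction: each accumulator holds column n in its low 128-bit lane and
    // column n+1 in its high lane. Two rounds of _mm256_hadd_epi32 collapse the partial
    // sums, leaving the eight column totals in the order 0 2 4 6 1 3 5 7, and the
    // permutevar8x32 below restores them to 0..7.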
5049 const __m256i vacc0x0213 = _mm256_hadd_epi32(vacc0x01, vacc0x23);
5050 const __m256i vacc0x4657 = _mm256_hadd_epi32(vacc0x45, vacc0x67);
5051
5052 const __m256i vacc0x02461357 = _mm256_hadd_epi32(vacc0x0213, vacc0x4657);
5053
5054 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
5055 __m256i vacc0x01234567 = _mm256_permutevar8x32_epi32(vacc0x02461357, vpermute_mask);
5056
5057 __m256 vscaled0x01234567 = _mm256_cvtepi32_ps(vacc0x01234567);
5058
5059 const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
5060 vscaled0x01234567 = _mm256_mul_ps(vscaled0x01234567, vscale);
5061
5062 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
5063 vscaled0x01234567 = _mm256_min_ps(vscaled0x01234567, voutput_max_less_zero_point);
5064
5065 vacc0x01234567 = _mm256_cvtps_epi32(vscaled0x01234567);
5066
5067 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
5068 __m256i vacc00x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc0x01234567, vacc0x01234567), voutput_zero_point);
5069
5070 vacc00x01234567 = _mm256_permute4x64_epi64(vacc00x01234567, _MM_SHUFFLE(3, 1, 2, 0));
5071
5072 __m256i vout = _mm256_packs_epi16(vacc00x01234567, vacc00x01234567);
5073
5074 vout = _mm256_max_epi8(vout, _mm256_load_si256((const __m256i*) params->fp32_avx2.output_min));
5075
5076 __m128i vout_lo = _mm256_castsi256_si128(vout);
5077 __m128i vout_hi = _mm256_extracti128_si256(vout, 1);
5078
5079 if (nc >= 8) {
5080 _mm_storel_epi64((__m128i*) c0, vout_lo);
5081
5082 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
5083
5084 a0 = (const int8_t*) ((uintptr_t) a0 - kc);
5085
5086 nc -= 8;
5087 } else {
5088 if (nc & 4) {
5089 _mm_storeu_si32(c0, vout_lo);
5090
5091 c0 += 4;
5092
5093 vout_lo = _mm_srli_epi64(vout_lo, 32);
5094 vout_hi = _mm_srli_epi64(vout_hi, 32);
5095 }
5096 if (nc & 2) {
5097 unaligned_store_u16(c0, (uint16_t) _mm_extract_epi16(vout_lo, 0));
5098
5099 c0 += 2;
5100
5101 vout_lo = _mm_srli_epi32(vout_lo, 16);
5102 vout_hi = _mm_srli_epi32(vout_hi, 16);
5103 }
5104 if (nc & 1) {
5105 *c0 = (int8_t) _mm_extract_epi8(vout_lo, 0);
5106 }
5107
5108 nc = 0;
5109 }
5110 } while (nc != 0);
5111 }
5112
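// QS8 GEMM, 3 rows x 8 columns: the same C8 inner loop and reduction as the 1x8 variant
// above, replicated for three A rows. When mr < 3 the unused row pointers alias the
// previous row, so loads and stores stay within valid memory and the duplicate stores
// simply rewrite the same results.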
5113 void xnn_qs8_gemm_minmax_fp32_ukernel_3x8c8__avx2(
5114 size_t mr,
5115 size_t nc,
5116 size_t kc,
5117 const int8_t* restrict a,
5118 size_t a_stride,
5119 const void* restrict w,
5120 int8_t* restrict c,
5121 size_t cm_stride,
5122 size_t cn_stride,
5123 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
5124 {
5125 assert(mr != 0);
5126 assert(mr <= 3);
5127 assert(nc != 0);
5128 assert(kc != 0);
5129 assert(kc % sizeof(int8_t) == 0);
5130 assert(a != NULL);
5131 assert(w != NULL);
5132 assert(c != NULL);
5133
5134 kc = round_up_po2(kc, 8);
5135 const int8_t* a0 = a;
5136 int8_t* c0 = c;
5137 const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
5138 int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
5139 if XNN_UNPREDICTABLE(mr < 2) {
5140 a1 = a0;
5141 c1 = c0;
5142 }
5143 const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride);
5144 int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
5145 if XNN_UNPREDICTABLE(mr <= 2) {
5146 a2 = a1;
5147 c2 = c1;
5148 }
5149
5150 do {
5151 const __m128i vbias0x0 = _mm_cvtsi32_si128(((const int*) w)[0]);
5152 const __m128i vbias0x1 = _mm_cvtsi32_si128(((const int*) w)[1]);
5153 __m256i vacc0x01 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x0), vbias0x1, 1);
5154 const __m128i vbias0x2 = _mm_cvtsi32_si128(((const int*) w)[2]);
5155 const __m128i vbias0x3 = _mm_cvtsi32_si128(((const int*) w)[3]);
5156 __m256i vacc0x23 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x2), vbias0x3, 1);
5157 const __m128i vbias0x4 = _mm_cvtsi32_si128(((const int*) w)[4]);
5158 const __m128i vbias0x5 = _mm_cvtsi32_si128(((const int*) w)[5]);
5159 __m256i vacc0x45 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x4), vbias0x5, 1);
5160 const __m128i vbias0x6 = _mm_cvtsi32_si128(((const int*) w)[6]);
5161 const __m128i vbias0x7 = _mm_cvtsi32_si128(((const int*) w)[7]);
5162 __m256i vacc0x67 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x6), vbias0x7, 1);
5163 __m256i vacc1x01 = vacc0x01;
5164 __m256i vacc1x23 = vacc0x23;
5165 __m256i vacc1x45 = vacc0x45;
5166 __m256i vacc1x67 = vacc0x67;
5167 __m256i vacc2x01 = vacc0x01;
5168 __m256i vacc2x23 = vacc0x23;
5169 __m256i vacc2x45 = vacc0x45;
5170 __m256i vacc2x67 = vacc0x67;
5171 w = (const int32_t*) w + 8;
5172
5173 size_t k = 0;
5174 while (k < kc) {
5175 const __m128i va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
5176 const __m256i vxa0 = _mm256_cvtepi8_epi16(va0);
5177 a0 += 8;
5178 const __m128i va1 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a1));
5179 const __m256i vxa1 = _mm256_cvtepi8_epi16(va1);
5180 a1 += 8;
5181 const __m128i va2 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a2));
5182 const __m256i vxa2 = _mm256_cvtepi8_epi16(va2);
5183 a2 += 8;
5184
5185 const __m128i vb01 = _mm_load_si128((const __m128i*) w);
5186 const __m256i vxb01 = _mm256_cvtepi8_epi16(vb01);
5187
5188 vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
5189 vacc1x01 = _mm256_add_epi32(vacc1x01, _mm256_madd_epi16(vxa1, vxb01));
5190 vacc2x01 = _mm256_add_epi32(vacc2x01, _mm256_madd_epi16(vxa2, vxb01));
5191 const __m128i vb23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16));
5192 const __m256i vxb23 = _mm256_cvtepi8_epi16(vb23);
5193
5194 vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
5195 vacc1x23 = _mm256_add_epi32(vacc1x23, _mm256_madd_epi16(vxa1, vxb23));
5196 vacc2x23 = _mm256_add_epi32(vacc2x23, _mm256_madd_epi16(vxa2, vxb23));
5197 const __m128i vb45 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 32));
5198 const __m256i vxb45 = _mm256_cvtepi8_epi16(vb45);
5199
5200 vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
5201 vacc1x45 = _mm256_add_epi32(vacc1x45, _mm256_madd_epi16(vxa1, vxb45));
5202 vacc2x45 = _mm256_add_epi32(vacc2x45, _mm256_madd_epi16(vxa2, vxb45));
5203 const __m128i vb67 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 48));
5204 const __m256i vxb67 = _mm256_cvtepi8_epi16(vb67);
5205
5206 vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));
5207 vacc1x67 = _mm256_add_epi32(vacc1x67, _mm256_madd_epi16(vxa1, vxb67));
5208 vacc2x67 = _mm256_add_epi32(vacc2x67, _mm256_madd_epi16(vxa2, vxb67));
5209
5210 w = (const void*) ((const int8_t*) w + 64);
5211 k += 8 * sizeof(int8_t);
5212 }
5213
5214 const __m256i vacc0x0213 = _mm256_hadd_epi32(vacc0x01, vacc0x23);
5215 const __m256i vacc0x4657 = _mm256_hadd_epi32(vacc0x45, vacc0x67);
5216 const __m256i vacc1x0213 = _mm256_hadd_epi32(vacc1x01, vacc1x23);
5217 const __m256i vacc1x4657 = _mm256_hadd_epi32(vacc1x45, vacc1x67);
5218 const __m256i vacc2x0213 = _mm256_hadd_epi32(vacc2x01, vacc2x23);
5219 const __m256i vacc2x4657 = _mm256_hadd_epi32(vacc2x45, vacc2x67);
5220
5221 const __m256i vacc0x02461357 = _mm256_hadd_epi32(vacc0x0213, vacc0x4657);
5222 const __m256i vacc1x02461357 = _mm256_hadd_epi32(vacc1x0213, vacc1x4657);
5223 const __m256i vacc2x02461357 = _mm256_hadd_epi32(vacc2x0213, vacc2x4657);
5224
5225 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
5226 __m256i vacc0x01234567 = _mm256_permutevar8x32_epi32(vacc0x02461357, vpermute_mask);
5227 __m256i vacc1x01234567 = _mm256_permutevar8x32_epi32(vacc1x02461357, vpermute_mask);
5228 __m256i vacc2x01234567 = _mm256_permutevar8x32_epi32(vacc2x02461357, vpermute_mask);
5229
5230 __m256 vscaled0x01234567 = _mm256_cvtepi32_ps(vacc0x01234567);
5231 __m256 vscaled1x01234567 = _mm256_cvtepi32_ps(vacc1x01234567);
5232 __m256 vscaled2x01234567 = _mm256_cvtepi32_ps(vacc2x01234567);
5233
5234 const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
5235 vscaled0x01234567 = _mm256_mul_ps(vscaled0x01234567, vscale);
5236 vscaled1x01234567 = _mm256_mul_ps(vscaled1x01234567, vscale);
5237 vscaled2x01234567 = _mm256_mul_ps(vscaled2x01234567, vscale);
5238
5239 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
5240 vscaled0x01234567 = _mm256_min_ps(vscaled0x01234567, voutput_max_less_zero_point);
5241 vscaled1x01234567 = _mm256_min_ps(vscaled1x01234567, voutput_max_less_zero_point);
5242 vscaled2x01234567 = _mm256_min_ps(vscaled2x01234567, voutput_max_less_zero_point);
5243
5244 vacc0x01234567 = _mm256_cvtps_epi32(vscaled0x01234567);
5245 vacc1x01234567 = _mm256_cvtps_epi32(vscaled1x01234567);
5246 vacc2x01234567 = _mm256_cvtps_epi32(vscaled2x01234567);
5247
5248 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
5249 __m256i vacc01x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc0x01234567, vacc1x01234567), voutput_zero_point);
5250 __m256i vacc22x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc2x01234567, vacc2x01234567), voutput_zero_point);
5251
5252 vacc01x01234567 = _mm256_permute4x64_epi64(vacc01x01234567, _MM_SHUFFLE(3, 1, 2, 0));
5253 vacc22x01234567 = _mm256_permute4x64_epi64(vacc22x01234567, _MM_SHUFFLE(3, 1, 2, 0));
5254
5255 __m256i vout = _mm256_packs_epi16(vacc01x01234567, vacc22x01234567);
5256
5257 vout = _mm256_max_epi8(vout, _mm256_load_si256((const __m256i*) params->fp32_avx2.output_min));
5258
5259 __m128i vout_lo = _mm256_castsi256_si128(vout);
5260 __m128i vout_hi = _mm256_extracti128_si256(vout, 1);
5261
5262 if (nc >= 8) {
5263 _mm_storel_epi64((__m128i*) c0, vout_lo);
5264 _mm_storel_epi64((__m128i*) c1, vout_hi);
5265 _mm_storeh_pi((__m64*) c2, _mm_castsi128_ps(vout_lo));
5266
5267 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
5268 c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
5269 c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
5270
5271 a0 = (const int8_t*) ((uintptr_t) a0 - kc);
5272 a1 = (const int8_t*) ((uintptr_t) a1 - kc);
5273 a2 = (const int8_t*) ((uintptr_t) a2 - kc);
5274
5275 nc -= 8;
5276 } else {
5277 if (nc & 4) {
5278 _mm_storeu_si32(c0, vout_lo);
5279 _mm_storeu_si32(c1, vout_hi);
5280 unaligned_store_u32(c2, (uint32_t) _mm_extract_epi32(vout_lo, 2));
5281
5282 c0 += 4;
5283 c1 += 4;
5284 c2 += 4;
5285
5286 vout_lo = _mm_srli_epi64(vout_lo, 32);
5287 vout_hi = _mm_srli_epi64(vout_hi, 32);
5288 }
5289 if (nc & 2) {
5290 unaligned_store_u16(c0, (uint16_t) _mm_extract_epi16(vout_lo, 0));
5291 unaligned_store_u16(c1, (uint16_t) _mm_extract_epi16(vout_hi, 0));
5292 unaligned_store_u16(c2, (uint16_t) _mm_extract_epi16(vout_lo, 4));
5293
5294 c0 += 2;
5295 c1 += 2;
5296 c2 += 2;
5297
5298 vout_lo = _mm_srli_epi32(vout_lo, 16);
5299 vout_hi = _mm_srli_epi32(vout_hi, 16);
5300 }
5301 if (nc & 1) {
5302 *c0 = (int8_t) _mm_extract_epi8(vout_lo, 0);
5303 *c1 = (int8_t) _mm_extract_epi8(vout_hi, 0);
5304 *c2 = (int8_t) _mm_extract_epi8(vout_lo, 8);
5305 }
5306
5307 nc = 0;
5308 }
5309 } while (nc != 0);
5310 }
5311
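// Indirect variant of the 1x8C8 QS8 GEMM: instead of a strided A matrix, `a` is a list
// of ks row pointers. Pointers equal to `zero` reference the padding buffer and skip
// the a_offset adjustment; everything after the pointer handling matches the direct
// GEMM above.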
5312 void xnn_qs8_igemm_minmax_fp32_ukernel_1x8c8__avx2(
5313 size_t mr,
5314 size_t nc,
5315 size_t kc,
5316 size_t ks,
5317 const int8_t** restrict a,
5318 const void* restrict w,
5319 int8_t* restrict c,
5320 size_t cm_stride,
5321 size_t cn_stride,
5322 size_t a_offset,
5323 const int8_t* zero,
5324 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
5325 {
5326 assert(mr != 0);
5327 assert(mr <= 1);
5328 assert(nc != 0);
5329 assert(kc != 0);
5330 assert(ks != 0);
5331 assert(ks % (1 * sizeof(void*)) == 0);
5332 assert(a_offset % sizeof(int8_t) == 0);
5333 assert(a != NULL);
5334 assert(w != NULL);
5335 assert(c != NULL);
5336
5337 kc = round_up_po2(kc, 8);
5338 int8_t* c0 = c;
5339
5340 do {
5341 const __m128i vbias0x0 = _mm_cvtsi32_si128(((const int*) w)[0]);
5342 const __m128i vbias0x1 = _mm_cvtsi32_si128(((const int*) w)[1]);
5343 __m256i vacc0x01 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x0), vbias0x1, 1);
5344 const __m128i vbias0x2 = _mm_cvtsi32_si128(((const int*) w)[2]);
5345 const __m128i vbias0x3 = _mm_cvtsi32_si128(((const int*) w)[3]);
5346 __m256i vacc0x23 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x2), vbias0x3, 1);
5347 const __m128i vbias0x4 = _mm_cvtsi32_si128(((const int*) w)[4]);
5348 const __m128i vbias0x5 = _mm_cvtsi32_si128(((const int*) w)[5]);
5349 __m256i vacc0x45 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x4), vbias0x5, 1);
5350 const __m128i vbias0x6 = _mm_cvtsi32_si128(((const int*) w)[6]);
5351 const __m128i vbias0x7 = _mm_cvtsi32_si128(((const int*) w)[7]);
5352 __m256i vacc0x67 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x6), vbias0x7, 1);
5353 w = (const int32_t*) w + 8;
5354
5355 size_t p = ks;
5356 do {
5357 const int8_t* restrict a0 = a[0];
5358 if XNN_UNPREDICTABLE(a0 != zero) {
5359 a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
5360 }
5361 a += 1;
5362
5363 size_t k = 0;
5364 while (k < kc) {
5365 const __m128i va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
5366 const __m256i vxa0 = _mm256_cvtepi8_epi16(va0);
5367 a0 += 8;
5368
5369 const __m128i vb01 = _mm_load_si128((const __m128i*) w);
5370 const __m256i vxb01 = _mm256_cvtepi8_epi16(vb01);
5371
5372 vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
5373 const __m128i vb23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16));
5374 const __m256i vxb23 = _mm256_cvtepi8_epi16(vb23);
5375
5376 vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
5377 const __m128i vb45 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 32));
5378 const __m256i vxb45 = _mm256_cvtepi8_epi16(vb45);
5379
5380 vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
5381 const __m128i vb67 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 48));
5382 const __m256i vxb67 = _mm256_cvtepi8_epi16(vb67);
5383
5384 vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));
5385
5386 w = (const void*) ((const int8_t*) w + 64);
5387 k += 8 * sizeof(int8_t);
5388 }
5389 p -= 1 * sizeof(void*);
5390 } while (p != 0);
5391
5392 const __m256i vacc0x0213 = _mm256_hadd_epi32(vacc0x01, vacc0x23);
5393 const __m256i vacc0x4657 = _mm256_hadd_epi32(vacc0x45, vacc0x67);
5394
5395 const __m256i vacc0x02461357 = _mm256_hadd_epi32(vacc0x0213, vacc0x4657);
5396
5397 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
5398 __m256i vacc0x01234567 = _mm256_permutevar8x32_epi32(vacc0x02461357, vpermute_mask);
5399
5400 __m256 vscaled0x01234567 = _mm256_cvtepi32_ps(vacc0x01234567);
5401
5402 const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
5403 vscaled0x01234567 = _mm256_mul_ps(vscaled0x01234567, vscale);
5404
5405 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
5406 vscaled0x01234567 = _mm256_min_ps(vscaled0x01234567, voutput_max_less_zero_point);
5407
5408 vacc0x01234567 = _mm256_cvtps_epi32(vscaled0x01234567);
5409
5410 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
5411 __m256i vacc00x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc0x01234567, vacc0x01234567), voutput_zero_point);
5412
5413 vacc00x01234567 = _mm256_permute4x64_epi64(vacc00x01234567, _MM_SHUFFLE(3, 1, 2, 0));
5414
5415 __m256i vout = _mm256_packs_epi16(vacc00x01234567, vacc00x01234567);
5416
5417 vout = _mm256_max_epi8(vout, _mm256_load_si256((const __m256i*) params->fp32_avx2.output_min));
5418
5419 __m128i vout_lo = _mm256_castsi256_si128(vout);
5420 __m128i vout_hi = _mm256_extracti128_si256(vout, 1);
5421
5422 if (nc >= 8) {
5423 _mm_storel_epi64((__m128i*) c0, vout_lo);
5424
5425 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
5426
5427 a = (const int8_t**restrict) ((uintptr_t) a - ks);
5428
5429 nc -= 8;
5430 } else {
5431 if (nc & 4) {
5432 _mm_storeu_si32(c0, vout_lo);
5433
5434 c0 += 4;
5435
5436 vout_lo = _mm_srli_epi64(vout_lo, 32);
5437 vout_hi = _mm_srli_epi64(vout_hi, 32);
5438 }
5439 if (nc & 2) {
5440 unaligned_store_u16(c0, (uint16_t) _mm_extract_epi16(vout_lo, 0));
5441
5442 c0 += 2;
5443
5444 vout_lo = _mm_srli_epi32(vout_lo, 16);
5445 vout_hi = _mm_srli_epi32(vout_hi, 16);
5446 }
5447 if (nc & 1) {
5448 *c0 = (int8_t) _mm_extract_epi8(vout_lo, 0);
5449 }
5450
5451 nc = 0;
5452 }
5453 } while (nc != 0);
5454 }
5455
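// Indirect 3x8C8 QS8 GEMM: consumes three row pointers per iteration of the ks loop,
// with the same zero-pointer / a_offset handling as the 1x8 indirect kernel above.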
5456 void xnn_qs8_igemm_minmax_fp32_ukernel_3x8c8__avx2(
5457 size_t mr,
5458 size_t nc,
5459 size_t kc,
5460 size_t ks,
5461 const int8_t** restrict a,
5462 const void* restrict w,
5463 int8_t* restrict c,
5464 size_t cm_stride,
5465 size_t cn_stride,
5466 size_t a_offset,
5467 const int8_t* zero,
5468 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
5469 {
5470 assert(mr != 0);
5471 assert(mr <= 3);
5472 assert(nc != 0);
5473 assert(kc != 0);
5474 assert(ks != 0);
5475 assert(ks % (3 * sizeof(void*)) == 0);
5476 assert(a_offset % sizeof(int8_t) == 0);
5477 assert(a != NULL);
5478 assert(w != NULL);
5479 assert(c != NULL);
5480
5481 kc = round_up_po2(kc, 8);
5482 int8_t* c0 = c;
5483 int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
5484 if XNN_UNPREDICTABLE(mr < 2) {
5485 c1 = c0;
5486 }
5487 int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
5488 if XNN_UNPREDICTABLE(mr <= 2) {
5489 c2 = c1;
5490 }
5491
5492 do {
5493 const __m128i vbias0x0 = _mm_cvtsi32_si128(((const int*) w)[0]);
5494 const __m128i vbias0x1 = _mm_cvtsi32_si128(((const int*) w)[1]);
5495 __m256i vacc0x01 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x0), vbias0x1, 1);
5496 const __m128i vbias0x2 = _mm_cvtsi32_si128(((const int*) w)[2]);
5497 const __m128i vbias0x3 = _mm_cvtsi32_si128(((const int*) w)[3]);
5498 __m256i vacc0x23 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x2), vbias0x3, 1);
5499 const __m128i vbias0x4 = _mm_cvtsi32_si128(((const int*) w)[4]);
5500 const __m128i vbias0x5 = _mm_cvtsi32_si128(((const int*) w)[5]);
5501 __m256i vacc0x45 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x4), vbias0x5, 1);
5502 const __m128i vbias0x6 = _mm_cvtsi32_si128(((const int*) w)[6]);
5503 const __m128i vbias0x7 = _mm_cvtsi32_si128(((const int*) w)[7]);
5504 __m256i vacc0x67 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x6), vbias0x7, 1);
5505 __m256i vacc1x01 = vacc0x01;
5506 __m256i vacc1x23 = vacc0x23;
5507 __m256i vacc1x45 = vacc0x45;
5508 __m256i vacc1x67 = vacc0x67;
5509 __m256i vacc2x01 = vacc0x01;
5510 __m256i vacc2x23 = vacc0x23;
5511 __m256i vacc2x45 = vacc0x45;
5512 __m256i vacc2x67 = vacc0x67;
5513 w = (const int32_t*) w + 8;
5514
5515 size_t p = ks;
5516 do {
5517 const int8_t* restrict a0 = a[0];
5518 if XNN_UNPREDICTABLE(a0 != zero) {
5519 a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
5520 }
5521 const int8_t* restrict a1 = a[1];
5522 if XNN_UNPREDICTABLE(a1 != zero) {
5523 a1 = (const int8_t*) ((uintptr_t) a1 + a_offset);
5524 }
5525 const int8_t* restrict a2 = a[2];
5526 if XNN_UNPREDICTABLE(a2 != zero) {
5527 a2 = (const int8_t*) ((uintptr_t) a2 + a_offset);
5528 }
5529 a += 3;
5530
5531 size_t k = 0;
5532 while (k < kc) {
5533 const __m128i va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
5534 const __m256i vxa0 = _mm256_cvtepi8_epi16(va0);
5535 a0 += 8;
5536 const __m128i va1 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a1));
5537 const __m256i vxa1 = _mm256_cvtepi8_epi16(va1);
5538 a1 += 8;
5539 const __m128i va2 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a2));
5540 const __m256i vxa2 = _mm256_cvtepi8_epi16(va2);
5541 a2 += 8;
5542
5543 const __m128i vb01 = _mm_load_si128((const __m128i*) w);
5544 const __m256i vxb01 = _mm256_cvtepi8_epi16(vb01);
5545
5546 vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
5547 vacc1x01 = _mm256_add_epi32(vacc1x01, _mm256_madd_epi16(vxa1, vxb01));
5548 vacc2x01 = _mm256_add_epi32(vacc2x01, _mm256_madd_epi16(vxa2, vxb01));
5549 const __m128i vb23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16));
5550 const __m256i vxb23 = _mm256_cvtepi8_epi16(vb23);
5551
5552 vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
5553 vacc1x23 = _mm256_add_epi32(vacc1x23, _mm256_madd_epi16(vxa1, vxb23));
5554 vacc2x23 = _mm256_add_epi32(vacc2x23, _mm256_madd_epi16(vxa2, vxb23));
5555 const __m128i vb45 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 32));
5556 const __m256i vxb45 = _mm256_cvtepi8_epi16(vb45);
5557
5558 vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
5559 vacc1x45 = _mm256_add_epi32(vacc1x45, _mm256_madd_epi16(vxa1, vxb45));
5560 vacc2x45 = _mm256_add_epi32(vacc2x45, _mm256_madd_epi16(vxa2, vxb45));
5561 const __m128i vb67 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 48));
5562 const __m256i vxb67 = _mm256_cvtepi8_epi16(vb67);
5563
5564 vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));
5565 vacc1x67 = _mm256_add_epi32(vacc1x67, _mm256_madd_epi16(vxa1, vxb67));
5566 vacc2x67 = _mm256_add_epi32(vacc2x67, _mm256_madd_epi16(vxa2, vxb67));
5567
5568 w = (const void*) ((const int8_t*) w + 64);
5569 k += 8 * sizeof(int8_t);
5570 }
5571 p -= 3 * sizeof(void*);
5572 } while (p != 0);
5573
5574 const __m256i vacc0x0213 = _mm256_hadd_epi32(vacc0x01, vacc0x23);
5575 const __m256i vacc0x4657 = _mm256_hadd_epi32(vacc0x45, vacc0x67);
5576 const __m256i vacc1x0213 = _mm256_hadd_epi32(vacc1x01, vacc1x23);
5577 const __m256i vacc1x4657 = _mm256_hadd_epi32(vacc1x45, vacc1x67);
5578 const __m256i vacc2x0213 = _mm256_hadd_epi32(vacc2x01, vacc2x23);
5579 const __m256i vacc2x4657 = _mm256_hadd_epi32(vacc2x45, vacc2x67);
5580
5581 const __m256i vacc0x02461357 = _mm256_hadd_epi32(vacc0x0213, vacc0x4657);
5582 const __m256i vacc1x02461357 = _mm256_hadd_epi32(vacc1x0213, vacc1x4657);
5583 const __m256i vacc2x02461357 = _mm256_hadd_epi32(vacc2x0213, vacc2x4657);
5584
5585 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
5586 __m256i vacc0x01234567 = _mm256_permutevar8x32_epi32(vacc0x02461357, vpermute_mask);
5587 __m256i vacc1x01234567 = _mm256_permutevar8x32_epi32(vacc1x02461357, vpermute_mask);
5588 __m256i vacc2x01234567 = _mm256_permutevar8x32_epi32(vacc2x02461357, vpermute_mask);
5589
5590 __m256 vscaled0x01234567 = _mm256_cvtepi32_ps(vacc0x01234567);
5591 __m256 vscaled1x01234567 = _mm256_cvtepi32_ps(vacc1x01234567);
5592 __m256 vscaled2x01234567 = _mm256_cvtepi32_ps(vacc2x01234567);
5593
5594 const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
5595 vscaled0x01234567 = _mm256_mul_ps(vscaled0x01234567, vscale);
5596 vscaled1x01234567 = _mm256_mul_ps(vscaled1x01234567, vscale);
5597 vscaled2x01234567 = _mm256_mul_ps(vscaled2x01234567, vscale);
5598
5599 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
5600 vscaled0x01234567 = _mm256_min_ps(vscaled0x01234567, voutput_max_less_zero_point);
5601 vscaled1x01234567 = _mm256_min_ps(vscaled1x01234567, voutput_max_less_zero_point);
5602 vscaled2x01234567 = _mm256_min_ps(vscaled2x01234567, voutput_max_less_zero_point);
5603
5604 vacc0x01234567 = _mm256_cvtps_epi32(vscaled0x01234567);
5605 vacc1x01234567 = _mm256_cvtps_epi32(vscaled1x01234567);
5606 vacc2x01234567 = _mm256_cvtps_epi32(vscaled2x01234567);
5607
5608 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
5609 __m256i vacc01x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc0x01234567, vacc1x01234567), voutput_zero_point);
5610 __m256i vacc22x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc2x01234567, vacc2x01234567), voutput_zero_point);
5611
5612 vacc01x01234567 = _mm256_permute4x64_epi64(vacc01x01234567, _MM_SHUFFLE(3, 1, 2, 0));
5613 vacc22x01234567 = _mm256_permute4x64_epi64(vacc22x01234567, _MM_SHUFFLE(3, 1, 2, 0));
5614
5615 __m256i vout = _mm256_packs_epi16(vacc01x01234567, vacc22x01234567);
5616
5617 vout = _mm256_max_epi8(vout, _mm256_load_si256((const __m256i*) params->fp32_avx2.output_min));
5618
5619 __m128i vout_lo = _mm256_castsi256_si128(vout);
5620 __m128i vout_hi = _mm256_extracti128_si256(vout, 1);
5621
5622 if (nc >= 8) {
5623 _mm_storeh_pi((__m64*) c2, _mm_castsi128_ps(vout_lo));
5624 _mm_storel_epi64((__m128i*) c1, vout_hi);
5625 _mm_storel_epi64((__m128i*) c0, vout_lo);
5626
5627 c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
5628 c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
5629 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
5630
5631 a = (const int8_t**restrict) ((uintptr_t) a - ks);
5632
5633 nc -= 8;
5634 } else {
5635 if (nc & 4) {
5636 unaligned_store_u32(c2, (uint32_t) _mm_extract_epi32(vout_lo, 2));
5637 _mm_storeu_si32(c1, vout_hi);
5638 _mm_storeu_si32(c0, vout_lo);
5639
5640 c2 += 4;
5641 c1 += 4;
5642 c0 += 4;
5643
5644 vout_lo = _mm_srli_epi64(vout_lo, 32);
5645 vout_hi = _mm_srli_epi64(vout_hi, 32);
5646 }
5647 if (nc & 2) {
5648 unaligned_store_u16(c2, (uint16_t) _mm_extract_epi16(vout_lo, 4));
5649 unaligned_store_u16(c1, (uint16_t) _mm_extract_epi16(vout_hi, 0));
5650 unaligned_store_u16(c0, (uint16_t) _mm_extract_epi16(vout_lo, 0));
5651
5652 c2 += 2;
5653 c1 += 2;
5654 c0 += 2;
5655
5656 vout_lo = _mm_srli_epi32(vout_lo, 16);
5657 vout_hi = _mm_srli_epi32(vout_hi, 16);
5658 }
5659 if (nc & 1) {
5660 *c2 = (int8_t) _mm_extract_epi8(vout_lo, 8);
5661 *c1 = (int8_t) _mm_extract_epi8(vout_hi, 0);
5662 *c0 = (int8_t) _mm_extract_epi8(vout_lo, 0);
5663 }
5664
5665 nc = 0;
5666 }
5667 } while (nc != 0);
5668 }
5669
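// Elementwise quantized addition, 16 elements per iteration. Each operand is widened to
// int32 and multiplied by its precomputed fixed-point multiplier; the bias (precomputed
// in params) folds in the rounding constant and the input zero-point terms, and an
// arithmetic shift right completes the fixed-point product before requantization.
// Per element, roughly:
//   out = clamp(((bias + a * a_multiplier + b * b_multiplier) >> shift)
//               + output_zero_point, output_min, output_max)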
5670 void xnn_qs8_vadd_minmax_ukernel__avx2_mul32_ld64_x16(
5671 size_t n,
5672 const int8_t* input_a,
5673 const int8_t* input_b,
5674 int8_t* output,
5675 const union xnn_qs8_add_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
5676 {
5677 const __m256i vbias = _mm256_load_si256((const __m256i*) params->avx2.bias);
5678 const __m256i va_multiplier = _mm256_load_si256((const __m256i*) params->avx2.a_multiplier);
5679 const __m256i vb_multiplier = _mm256_load_si256((const __m256i*) params->avx2.b_multiplier);
5680 const __m128i vshift = _mm_load_si128((const __m128i*) params->avx2.shift);
5681 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx2.output_zero_point);
5682 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->avx2.output_min);
5683 const __m128i voutput_max = _mm_load_si128((const __m128i*) params->avx2.output_max);
5684
5685 for (; n >= 16 * sizeof(int8_t); n -= 16 * sizeof(int8_t)) {
5686 const __m256i va01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) input_a));
5687 const __m256i vb01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) input_b));
5688 const __m256i va89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (input_a + 8)));
5689 const __m256i vb89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (input_b + 8)));
5690 input_a += 16;
5691 input_b += 16;
5692
5693 __m256i vacc01234567 = _mm256_add_epi32(vbias, _mm256_mullo_epi32(va01234567, va_multiplier));
5694 __m256i vacc89ABCDEF = _mm256_add_epi32(vbias, _mm256_mullo_epi32(va89ABCDEF, va_multiplier));
5695
5696 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vb01234567, vb_multiplier));
5697 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vb89ABCDEF, vb_multiplier));
5698
5699 vacc01234567 = _mm256_sra_epi32(vacc01234567, vshift);
5700 vacc89ABCDEF = _mm256_sra_epi32(vacc89ABCDEF, vshift);
5701
5702 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(vacc01234567, vacc89ABCDEF), voutput_zero_point);
5703
5704 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
5705
5706 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, voutput_min);
5707
5708 vout0123456789ABCDEF = _mm_min_epi8(vout0123456789ABCDEF, voutput_max);
5709
5710 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
5711 output += 16;
5712 }
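  // Remainder of 1-15 elements: process 8 at a time; the 8-byte loads may read past
  // the end of the inputs (permitted by XNN_OOB_READS), and the final partial group
  // is written with 4-, 2- and 1-byte stores.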
5713 if XNN_UNLIKELY(n != 0) {
5714 do {
5715 const __m256i va01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) input_a));
5716 const __m256i vb01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) input_b));
5717 input_a += 8;
5718 input_b += 8;
5719
5720 __m256i vacc01234567 = _mm256_add_epi32(vbias, _mm256_mullo_epi32(va01234567, va_multiplier));
5721
5722 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vb01234567, vb_multiplier));
5723
5724 vacc01234567 = _mm256_sra_epi32(vacc01234567, vshift);
5725
5726 __m128i vout01234567 = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc01234567), _mm256_extracti128_si256(vacc01234567, 1)), _mm256_castsi256_si128(voutput_zero_point));
5727 __m128i vout0123456701234567 = _mm_packs_epi16(vout01234567, vout01234567);
5728 vout0123456701234567 = _mm_max_epi8(vout0123456701234567, voutput_min);
5729 vout0123456701234567 = _mm_min_epi8(vout0123456701234567, voutput_max);
5730
5731 if XNN_LIKELY(n >= (8 * sizeof(int8_t))) {
5732 _mm_storel_epi64((__m128i*) output, vout0123456701234567);
5733 output += 8;
5734 n -= 8 * sizeof(int8_t);
5735 } else {
5736 if (n & (4 * sizeof(int8_t))) {
5737 _mm_storeu_si32(output, vout0123456701234567);
5738 vout0123456701234567 = _mm_srli_epi64(vout0123456701234567, 32);
5739 output += 4;
5740 }
5741 if (n & (2 * sizeof(int8_t))) {
5742 _mm_storeu_si16(output, vout0123456701234567);
5743 vout0123456701234567 = _mm_srli_epi32(vout0123456701234567, 16);
5744 output += 2;
5745 }
5746 if (n & (1 * sizeof(int8_t))) {
5747 *output = (int8_t) _mm_extract_epi8(vout0123456701234567, 0);
5748 }
5749 n = 0;
5750 }
5751 } while (n != 0);
5752 }
5753 }
5754
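// Variant of the addition kernel above where the second operand is a single scalar
// (*input_b): its contribution is folded into the bias once before the loop, so the
// inner loop only scales and accumulates input_a.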
5755 void xnn_qs8_vaddc_minmax_ukernel__avx2_mul32_ld64_x16(
5756 size_t n,
5757 const int8_t* input_a,
5758 const int8_t* input_b,
5759 int8_t* output,
5760 const union xnn_qs8_add_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
5761 {
5762 const __m256i va_multiplier = _mm256_load_si256((const __m256i*) params->avx2.a_multiplier);
5763 const __m128i vshift = _mm_load_si128((const __m128i*) params->avx2.shift);
5764 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx2.output_zero_point);
5765 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->avx2.output_min);
5766 const __m128i voutput_max = _mm_load_si128((const __m128i*) params->avx2.output_max);
5767
5768 const __m256i vbias = _mm256_add_epi32(
5769 _mm256_broadcastd_epi32(_mm_cvtsi32_si128(params->avx2.b_multiplier[0] * (int32_t) *input_b)),
5770 _mm256_load_si256((const __m256i*) params->avx2.bias));
5771 for (; n >= 16 * sizeof(int8_t); n -= 16 * sizeof(int8_t)) {
5772 const __m256i va01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) input_a));
5773 const __m256i va89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) (input_a + 8)));
5774 input_a += 16;
5775
5776 __m256i vacc01234567 = _mm256_add_epi32(vbias, _mm256_mullo_epi32(va01234567, va_multiplier));
5777 __m256i vacc89ABCDEF = _mm256_add_epi32(vbias, _mm256_mullo_epi32(va89ABCDEF, va_multiplier));
5778
5779 vacc01234567 = _mm256_sra_epi32(vacc01234567, vshift);
5780 vacc89ABCDEF = _mm256_sra_epi32(vacc89ABCDEF, vshift);
5781
5782 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(vacc01234567, vacc89ABCDEF), voutput_zero_point);
5783
5784 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
5785
5786 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, voutput_min);
5787
5788 vout0123456789ABCDEF = _mm_min_epi8(vout0123456789ABCDEF, voutput_max);
5789
5790 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
5791 output += 16;
5792 }
5793 if XNN_UNLIKELY(n != 0) {
5794 do {
5795 const __m256i va01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) input_a));
5796 input_a += 8;
5797
5798 __m256i vacc01234567 = _mm256_add_epi32(vbias, _mm256_mullo_epi32(va01234567, va_multiplier));
5799
5800 vacc01234567 = _mm256_sra_epi32(vacc01234567, vshift);
5801
5802 __m128i vout01234567 = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc01234567), _mm256_extracti128_si256(vacc01234567, 1)), _mm256_castsi256_si128(voutput_zero_point));
5803 __m128i vout0123456701234567 = _mm_packs_epi16(vout01234567, vout01234567);
5804 vout0123456701234567 = _mm_max_epi8(vout0123456701234567, voutput_min);
5805 vout0123456701234567 = _mm_min_epi8(vout0123456701234567, voutput_max);
5806
5807 if XNN_LIKELY(n >= (8 * sizeof(int8_t))) {
5808 _mm_storel_epi64((__m128i*) output, vout0123456701234567);
5809 output += 8;
5810 n -= 8 * sizeof(int8_t);
5811 } else {
5812 if (n & (4 * sizeof(int8_t))) {
5813 _mm_storeu_si32(output, vout0123456701234567);
5814 vout0123456701234567 = _mm_srli_epi64(vout0123456701234567, 32);
5815 output += 4;
5816 }
5817 if (n & (2 * sizeof(int8_t))) {
5818 _mm_storeu_si16(output, vout0123456701234567);
5819 vout0123456701234567 = _mm_srli_epi32(vout0123456701234567, 16);
5820 output += 2;
5821 }
5822 if (n & (1 * sizeof(int8_t))) {
5823 *output = (int8_t) _mm_extract_epi8(vout0123456701234567, 0);
5824 }
5825 n = 0;
5826 }
5827 } while (n != 0);
5828 }
5829 }
5830
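// Requantizes QS8 values to a new scale/zero-point pair, 32 elements per iteration.
// The difference (input_zero_point - x) is widened to 16 bits, shifted left by 7 to
// use the full Q15 range, and scaled with rounding by a signed Q15 multiplier via
// _mm256_mulhrs_epi16 (the multiplier's sign compensates for the reversed subtraction,
// per the packed params); the output zero point is then added with saturation.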
5831 void xnn_qs8_vcvt_ukernel__avx2_x32(
5832 size_t n,
5833 const int8_t* x,
5834 int8_t* y,
5835 const union xnn_qs8_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
5836 {
5837 assert(n != 0);
5838 assert(n % sizeof(int8_t) == 0);
5839 assert(x != NULL);
5840 assert(y != NULL);
5841
5842 const __m256i vinput_zero_point = _mm256_load_si256((const __m256i*) params->avx2.input_zero_point);
5843 const __m256i vmultiplier = _mm256_load_si256((const __m256i*) params->avx2.multiplier);
5844 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx2.output_zero_point);
5845 for (; n >= 32 * sizeof(int8_t); n -= 32 * sizeof(int8_t)) {
5846 __m256i vacc0 = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) x));
5847 __m256i vacc1 = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) (x + 16)));
5848 x += 32;
5849
5850 vacc0 = _mm256_sub_epi16(vinput_zero_point, vacc0);
5851 vacc1 = _mm256_sub_epi16(vinput_zero_point, vacc1);
5852
5853 vacc0 = _mm256_slli_epi16(vacc0, 7);
5854 vacc1 = _mm256_slli_epi16(vacc1, 7);
5855
5856 vacc0 = _mm256_mulhrs_epi16(vacc0, vmultiplier);
5857 vacc1 = _mm256_mulhrs_epi16(vacc1, vmultiplier);
5858
5859 vacc0 = _mm256_adds_epi16(vacc0, voutput_zero_point);
5860 vacc1 = _mm256_adds_epi16(vacc1, voutput_zero_point);
5861
5862 __m256i vy0 = _mm256_packs_epi16(vacc0, vacc1);
5863
5864 vy0 = _mm256_permute4x64_epi64(vy0, _MM_SHUFFLE(3, 1, 2, 0));
5865
5866 _mm256_storeu_si256((__m256i*) y, vy0);
5867 y += 32;
5868 }
5869 for (; n >= 16 * sizeof(int8_t); n -= 16 * sizeof(int8_t)) {
5870 __m256i vacc = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) x));
5871 vacc = _mm256_sub_epi16(vinput_zero_point, vacc);
5872 vacc = _mm256_slli_epi16(vacc, 7);
5873 vacc = _mm256_mulhrs_epi16(vacc, vmultiplier);
5874 vacc = _mm256_adds_epi16(vacc, voutput_zero_point);
5875 x += 16;
5876
5877 const __m128i vacc_hi = _mm256_extracti128_si256(vacc, 1);
5878 const __m128i vy = _mm_packs_epi16(_mm256_castsi256_si128(vacc), vacc_hi);
5879 _mm_storeu_si128((__m128i*) y, vy);
5880 y += 16;
5881 }
5882 if XNN_UNLIKELY(n != 0) {
5883 assert(n >= 1 * sizeof(int8_t));
5884 assert(n <= 15 * sizeof(int8_t));
5885
5886 __m256i vacc = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) x));
5887 vacc = _mm256_sub_epi16(vinput_zero_point, vacc);
5888 vacc = _mm256_slli_epi16(vacc, 7);
5889 vacc = _mm256_mulhrs_epi16(vacc, vmultiplier);
5890 vacc = _mm256_adds_epi16(vacc, voutput_zero_point);
5891
5892 const __m128i vacc_hi = _mm256_extracti128_si256(vacc, 1);
5893 __m128i vy = _mm_packs_epi16(_mm256_castsi256_si128(vacc), vacc_hi);
5894 if (n & (8 * sizeof(int8_t))) {
5895 _mm_storel_epi64((__m128i*) y, vy);
5896 vy = _mm_unpackhi_epi64(vy, vy);
5897 y += 8;
5898 }
5899 if (n & (4 * sizeof(int8_t))) {
5900 _mm_storeu_si32(y, vy);
5901 vy = _mm_srli_epi64(vy, 32);
5902 y += 4;
5903 }
5904 if (n & (2 * sizeof(int8_t))) {
5905 _mm_storeu_si16(y, vy);
5906 vy = _mm_srli_epi32(vy, 16);
5907 y += 2;
5908 }
5909 if (n & (1 * sizeof(int8_t))) {
5910 *y = (int8_t) _mm_extract_epi8(vy, 0);
5911 }
5912 }
5913 }
5914
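// Leaky ReLU in the quantized domain: a 16-bit compare against the input zero point
// selects, per element, between the positive- and negative-slope Q15 multipliers
// (_mm256_blendv_epi8), after which the rescaling follows the same mulhrs-based
// pattern as the conversion kernel above.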
5915 void xnn_qs8_vlrelu_ukernel__avx2_x32(
5916 size_t n,
5917 const int8_t* x,
5918 int8_t* y,
5919 const union xnn_qs8_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
5920 {
5921 assert(n != 0);
5922 assert(n % sizeof(int8_t) == 0);
5923 assert(x != NULL);
5924 assert(y != NULL);
5925
5926 const __m256i vinput_zero_point = _mm256_load_si256((const __m256i*) params->avx2.input_zero_point);
5927 const __m256i vpositive_multiplier = _mm256_load_si256((const __m256i*) params->avx2.positive_multiplier);
5928 const __m256i vnegative_multiplier = _mm256_load_si256((const __m256i*) params->avx2.negative_multiplier);
5929 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx2.output_zero_point);
5930 for (; n >= 32 * sizeof(int8_t); n -= 32 * sizeof(int8_t)) {
5931 __m256i vacc0 = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) x));
5932 __m256i vacc1 = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) (x + 16)));
5933 x += 32;
5934
5935 __m256i vmultiplier0 = _mm256_cmpgt_epi16(vacc0, vinput_zero_point);
5936 vacc0 = _mm256_sub_epi16(vinput_zero_point, vacc0);
5937 __m256i vmultiplier1 = _mm256_cmpgt_epi16(vacc1, vinput_zero_point);
5938 vacc1 = _mm256_sub_epi16(vinput_zero_point, vacc1);
5939
5940 vmultiplier0 = _mm256_blendv_epi8(vnegative_multiplier, vpositive_multiplier, vmultiplier0);
5941 vacc0 = _mm256_slli_epi16(vacc0, 7);
5942 vmultiplier1 = _mm256_blendv_epi8(vnegative_multiplier, vpositive_multiplier, vmultiplier1);
5943 vacc1 = _mm256_slli_epi16(vacc1, 7);
5944
5945 vacc0 = _mm256_mulhrs_epi16(vacc0, vmultiplier0);
5946 vacc1 = _mm256_mulhrs_epi16(vacc1, vmultiplier1);
5947
5948 vacc0 = _mm256_adds_epi16(vacc0, voutput_zero_point);
5949 vacc1 = _mm256_adds_epi16(vacc1, voutput_zero_point);
5950
5951 __m256i vy0 = _mm256_packs_epi16(vacc0, vacc1);
5952
5953 vy0 = _mm256_permute4x64_epi64(vy0, _MM_SHUFFLE(3, 1, 2, 0));
5954
5955 _mm256_storeu_si256((__m256i*) y, vy0);
5956 y += 32;
5957 }
5958 for (; n >= 16 * sizeof(int8_t); n -= 16 * sizeof(int8_t)) {
5959 __m256i vacc = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) x));
5960 __m256i vmultiplier = _mm256_cmpgt_epi16(vacc, vinput_zero_point);
5961 vacc = _mm256_sub_epi16(vinput_zero_point, vacc);
5962 vmultiplier = _mm256_blendv_epi8(vnegative_multiplier, vpositive_multiplier, vmultiplier);
5963 vacc = _mm256_slli_epi16(vacc, 7);
5964 vacc = _mm256_mulhrs_epi16(vacc, vmultiplier);
5965 vacc = _mm256_adds_epi16(vacc, voutput_zero_point);
5966 x += 16;
5967
5968 const __m128i vacc_hi = _mm256_extracti128_si256(vacc, 1);
5969 const __m128i vy = _mm_packs_epi16(_mm256_castsi256_si128(vacc), vacc_hi);
5970 _mm_storeu_si128((__m128i*) y, vy);
5971 y += 16;
5972 }
5973 if XNN_UNLIKELY(n != 0) {
5974 assert(n >= 1 * sizeof(int8_t));
5975 assert(n <= 15 * sizeof(int8_t));
5976
5977 __m256i vacc = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) x));
5978 __m256i vmultiplier = _mm256_cmpgt_epi16(vacc, vinput_zero_point);
5979 vacc = _mm256_sub_epi16(vinput_zero_point, vacc);
5980 vmultiplier = _mm256_blendv_epi8(vnegative_multiplier, vpositive_multiplier, vmultiplier);
5981 vacc = _mm256_slli_epi16(vacc, 7);
5982 vacc = _mm256_mulhrs_epi16(vacc, vmultiplier);
5983 vacc = _mm256_adds_epi16(vacc, voutput_zero_point);
5984
5985 const __m128i vacc_hi = _mm256_extracti128_si256(vacc, 1);
5986 __m128i vy = _mm_packs_epi16(_mm256_castsi256_si128(vacc), vacc_hi);
5987 if (n & (8 * sizeof(int8_t))) {
5988 _mm_storel_epi64((__m128i*) y, vy);
5989 vy = _mm_unpackhi_epi64(vy, vy);
5990 y += 8;
5991 }
5992 if (n & (4 * sizeof(int8_t))) {
5993 _mm_storeu_si32(y, vy);
5994 vy = _mm_srli_epi64(vy, 32);
5995 y += 4;
5996 }
5997 if (n & (2 * sizeof(int8_t))) {
5998 _mm_storeu_si16(y, vy);
5999 vy = _mm_srli_epi32(vy, 16);
6000 y += 2;
6001 }
6002 if (n & (1 * sizeof(int8_t))) {
6003 *y = (int8_t) _mm_extract_epi8(vy, 0);
6004 }
6005 }
6006 }
6007
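// QU8 depthwise convolution, 25 taps (typically a 5x5 window) x up-to-16 channels per
// output pixel. The block below gathers the 25 input row pointers, redirecting entries
// that equal `zero` to the padding buffer without applying input_offset; the kernel
// zero point broadcast once up front is used to re-center the unsigned weights in the
// accumulation loop that follows.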
6008 void xnn_qu8_dwconv_minmax_fp32_ukernel_up16x25__avx2_mul32(
6009 size_t channels,
6010 size_t output_width,
6011 const uint8_t** input,
6012 const void* weights,
6013 uint8_t* output,
6014 size_t input_stride,
6015 size_t output_increment,
6016 size_t input_offset,
6017 const uint8_t* zero,
6018 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
6019 {
6020 assert(channels != 0);
6021 assert(output_width != 0);
6022
6023 const __m256i vk_zero_point = _mm256_cvtepu16_epi32(_mm_load_si128((const __m128i*) params->fp32_avx2.kernel_zero_point));
6024 do {
6025 const uint8_t* i0 = input[0];
6026 assert(i0 != NULL);
6027 if XNN_UNPREDICTABLE(i0 != zero) {
6028 i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
6029 }
6030 const uint8_t* i1 = input[1];
6031 assert(i1 != NULL);
6032 if XNN_UNPREDICTABLE(i1 != zero) {
6033 i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
6034 }
6035 const uint8_t* i2 = input[2];
6036 assert(i2 != NULL);
6037 if XNN_UNPREDICTABLE(i2 != zero) {
6038 i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
6039 }
6040 const uint8_t* i3 = input[3];
6041 assert(i3 != NULL);
6042 if XNN_UNPREDICTABLE(i3 != zero) {
6043 i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
6044 }
6045 const uint8_t* i4 = input[4];
6046 assert(i4 != NULL);
6047 if XNN_UNPREDICTABLE(i4 != zero) {
6048 i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
6049 }
6050 const uint8_t* i5 = input[5];
6051 assert(i5 != NULL);
6052 if XNN_UNPREDICTABLE(i5 != zero) {
6053 i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
6054 }
6055 const uint8_t* i6 = input[6];
6056 assert(i6 != NULL);
6057 if XNN_UNPREDICTABLE(i6 != zero) {
6058 i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
6059 }
6060 const uint8_t* i7 = input[7];
6061 assert(i7 != NULL);
6062 if XNN_UNPREDICTABLE(i7 != zero) {
6063 i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
6064 }
6065 const uint8_t* i8 = input[8];
6066 assert(i8 != NULL);
6067 if XNN_UNPREDICTABLE(i8 != zero) {
6068 i8 = (const uint8_t*) ((uintptr_t) i8 + input_offset);
6069 }
6070 const uint8_t* i9 = input[9];
6071 assert(i9 != NULL);
6072 if XNN_UNPREDICTABLE(i9 != zero) {
6073 i9 = (const uint8_t*) ((uintptr_t) i9 + input_offset);
6074 }
6075 const uint8_t* i10 = input[10];
6076 assert(i10 != NULL);
6077 if XNN_UNPREDICTABLE(i10 != zero) {
6078 i10 = (const uint8_t*) ((uintptr_t) i10 + input_offset);
6079 }
6080 const uint8_t* i11 = input[11];
6081 assert(i11 != NULL);
6082 if XNN_UNPREDICTABLE(i11 != zero) {
6083 i11 = (const uint8_t*) ((uintptr_t) i11 + input_offset);
6084 }
6085 const uint8_t* i12 = input[12];
6086 assert(i12 != NULL);
6087 if XNN_UNPREDICTABLE(i12 != zero) {
6088 i12 = (const uint8_t*) ((uintptr_t) i12 + input_offset);
6089 }
6090 const uint8_t* i13 = input[13];
6091 assert(i13 != NULL);
6092 if XNN_UNPREDICTABLE(i13 != zero) {
6093 i13 = (const uint8_t*) ((uintptr_t) i13 + input_offset);
6094 }
6095 const uint8_t* i14 = input[14];
6096 assert(i14 != NULL);
6097 if XNN_UNPREDICTABLE(i14 != zero) {
6098 i14 = (const uint8_t*) ((uintptr_t) i14 + input_offset);
6099 }
6100 const uint8_t* i15 = input[15];
6101 assert(i15 != NULL);
6102 if XNN_UNPREDICTABLE(i15 != zero) {
6103 i15 = (const uint8_t*) ((uintptr_t) i15 + input_offset);
6104 }
6105 const uint8_t* i16 = input[16];
6106 assert(i16 != NULL);
6107 if XNN_UNPREDICTABLE(i16 != zero) {
6108 i16 = (const uint8_t*) ((uintptr_t) i16 + input_offset);
6109 }
6110 const uint8_t* i17 = input[17];
6111 assert(i17 != NULL);
6112 if XNN_UNPREDICTABLE(i17 != zero) {
6113 i17 = (const uint8_t*) ((uintptr_t) i17 + input_offset);
6114 }
6115 const uint8_t* i18 = input[18];
6116 assert(i18 != NULL);
6117 if XNN_UNPREDICTABLE(i18 != zero) {
6118 i18 = (const uint8_t*) ((uintptr_t) i18 + input_offset);
6119 }
6120 const uint8_t* i19 = input[19];
6121 assert(i19 != NULL);
6122 if XNN_UNPREDICTABLE(i19 != zero) {
6123 i19 = (const uint8_t*) ((uintptr_t) i19 + input_offset);
6124 }
6125 const uint8_t* i20 = input[20];
6126 assert(i20 != NULL);
6127 if XNN_UNPREDICTABLE(i20 != zero) {
6128 i20 = (const uint8_t*) ((uintptr_t) i20 + input_offset);
6129 }
6130 const uint8_t* i21 = input[21];
6131 assert(i21 != NULL);
6132 if XNN_UNPREDICTABLE(i21 != zero) {
6133 i21 = (const uint8_t*) ((uintptr_t) i21 + input_offset);
6134 }
6135 const uint8_t* i22 = input[22];
6136 assert(i22 != NULL);
6137 if XNN_UNPREDICTABLE(i22 != zero) {
6138 i22 = (const uint8_t*) ((uintptr_t) i22 + input_offset);
6139 }
6140 const uint8_t* i23 = input[23];
6141 assert(i23 != NULL);
6142 if XNN_UNPREDICTABLE(i23 != zero) {
6143 i23 = (const uint8_t*) ((uintptr_t) i23 + input_offset);
6144 }
6145 const uint8_t* i24 = input[24];
6146 assert(i24 != NULL);
6147 if XNN_UNPREDICTABLE(i24 != zero) {
6148 i24 = (const uint8_t*) ((uintptr_t) i24 + input_offset);
6149 }
6150 input = (const uint8_t**) ((uintptr_t) input + input_stride);
6151
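    // Main channel loop: 16 channels per iteration, accumulating all 25 taps
    // into two vectors of eight int32 lanes each before requantizing.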
6152 size_t c = channels;
6153 const void* w = weights;
6154 for (; c >= 16; c -= 16) {
6155 __m256i vacc01234567 = _mm256_loadu_si256((const __m256i*) w);
6156 __m256i vacc89ABCDEF = _mm256_loadu_si256((const __m256i*) ((const int32_t*) w + 8));
6157
6158
6159 const __m256i vi0x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i0));
6160 const __m256i vk0x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 0 * sizeof(uint8_t)))), vk_zero_point);
6161 const __m256i vi0x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i0 + 8)));
6162 const __m256i vk0x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 8 * sizeof(uint8_t)))), vk_zero_point);
6163 i0 += 16;
6164
6165 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi0x01234567, vk0x01234567));
6166 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi0x89ABCDEF, vk0x89ABCDEF));
6167
6168 const __m256i vi1x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i1));
6169 const __m256i vk1x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 16 * sizeof(uint8_t)))), vk_zero_point);
6170 const __m256i vi1x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i1 + 8)));
6171 const __m256i vk1x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 24 * sizeof(uint8_t)))), vk_zero_point);
6172 i1 += 16;
6173
6174 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi1x01234567, vk1x01234567));
6175 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi1x89ABCDEF, vk1x89ABCDEF));
6176
6177 const __m256i vi2x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i2));
6178 const __m256i vk2x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 32 * sizeof(uint8_t)))), vk_zero_point);
6179 const __m256i vi2x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i2 + 8)));
6180 const __m256i vk2x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 40 * sizeof(uint8_t)))), vk_zero_point);
6181 i2 += 16;
6182
6183 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi2x01234567, vk2x01234567));
6184 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi2x89ABCDEF, vk2x89ABCDEF));
6185
6186 const __m256i vi3x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i3));
6187 const __m256i vk3x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 48 * sizeof(uint8_t)))), vk_zero_point);
6188 const __m256i vi3x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i3 + 8)));
6189 const __m256i vk3x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 56 * sizeof(uint8_t)))), vk_zero_point);
6190 i3 += 16;
6191
6192 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi3x01234567, vk3x01234567));
6193 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi3x89ABCDEF, vk3x89ABCDEF));
6194
6195 const __m256i vi4x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i4));
6196 const __m256i vk4x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 64 * sizeof(uint8_t)))), vk_zero_point);
6197 const __m256i vi4x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i4 + 8)));
6198 const __m256i vk4x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 72 * sizeof(uint8_t)))), vk_zero_point);
6199 i4 += 16;
6200
6201 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi4x01234567, vk4x01234567));
6202 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi4x89ABCDEF, vk4x89ABCDEF));
6203
6204 const __m256i vi5x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i5));
6205 const __m256i vk5x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 80 * sizeof(uint8_t)))), vk_zero_point);
6206 const __m256i vi5x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i5 + 8)));
6207 const __m256i vk5x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 88 * sizeof(uint8_t)))), vk_zero_point);
6208 i5 += 16;
6209
6210 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi5x01234567, vk5x01234567));
6211 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi5x89ABCDEF, vk5x89ABCDEF));
6212
6213 const __m256i vi6x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i6));
6214 const __m256i vk6x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 96 * sizeof(uint8_t)))), vk_zero_point);
6215 const __m256i vi6x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i6 + 8)));
6216 const __m256i vk6x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 104 * sizeof(uint8_t)))), vk_zero_point);
6217 i6 += 16;
6218
6219 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi6x01234567, vk6x01234567));
6220 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi6x89ABCDEF, vk6x89ABCDEF));
6221
6222 const __m256i vi7x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i7));
6223 const __m256i vk7x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 112 * sizeof(uint8_t)))), vk_zero_point);
6224 const __m256i vi7x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i7 + 8)));
6225 const __m256i vk7x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 120 * sizeof(uint8_t)))), vk_zero_point);
6226 i7 += 16;
6227
6228 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi7x01234567, vk7x01234567));
6229 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi7x89ABCDEF, vk7x89ABCDEF));
6230
6231 const __m256i vi8x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i8));
6232 const __m256i vk8x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 128 * sizeof(uint8_t)))), vk_zero_point);
6233 const __m256i vi8x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i8 + 8)));
6234 const __m256i vk8x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 136 * sizeof(uint8_t)))), vk_zero_point);
6235 i8 += 16;
6236
6237 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi8x01234567, vk8x01234567));
6238 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi8x89ABCDEF, vk8x89ABCDEF));
6239
6240 const __m256i vi9x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i9));
6241 const __m256i vk9x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 144 * sizeof(uint8_t)))), vk_zero_point);
6242 const __m256i vi9x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i9 + 8)));
6243 const __m256i vk9x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 152 * sizeof(uint8_t)))), vk_zero_point);
6244 i9 += 16;
6245
6246 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi9x01234567, vk9x01234567));
6247 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi9x89ABCDEF, vk9x89ABCDEF));
6248
6249 const __m256i vi10x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i10));
6250 const __m256i vk10x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 160 * sizeof(uint8_t)))), vk_zero_point);
6251 const __m256i vi10x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i10 + 8)));
6252 const __m256i vk10x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 168 * sizeof(uint8_t)))), vk_zero_point);
6253 i10 += 16;
6254
6255 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi10x01234567, vk10x01234567));
6256 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi10x89ABCDEF, vk10x89ABCDEF));
6257
6258 const __m256i vi11x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i11));
6259 const __m256i vk11x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 176 * sizeof(uint8_t)))), vk_zero_point);
6260 const __m256i vi11x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i11 + 8)));
6261 const __m256i vk11x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 184 * sizeof(uint8_t)))), vk_zero_point);
6262 i11 += 16;
6263
6264 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi11x01234567, vk11x01234567));
6265 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi11x89ABCDEF, vk11x89ABCDEF));
6266
6267 const __m256i vi12x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i12));
6268 const __m256i vk12x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 192 * sizeof(uint8_t)))), vk_zero_point);
6269 const __m256i vi12x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i12 + 8)));
6270 const __m256i vk12x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 200 * sizeof(uint8_t)))), vk_zero_point);
6271 i12 += 16;
6272
6273 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi12x01234567, vk12x01234567));
6274 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi12x89ABCDEF, vk12x89ABCDEF));
6275
6276 const __m256i vi13x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i13));
6277 const __m256i vk13x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 208 * sizeof(uint8_t)))), vk_zero_point);
6278 const __m256i vi13x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i13 + 8)));
6279 const __m256i vk13x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 216 * sizeof(uint8_t)))), vk_zero_point);
6280 i13 += 16;
6281
6282 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi13x01234567, vk13x01234567));
6283 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi13x89ABCDEF, vk13x89ABCDEF));
6284
6285 const __m256i vi14x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i14));
6286 const __m256i vk14x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 224 * sizeof(uint8_t)))), vk_zero_point);
6287 const __m256i vi14x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i14 + 8)));
6288 const __m256i vk14x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 232 * sizeof(uint8_t)))), vk_zero_point);
6289 i14 += 16;
6290
6291 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi14x01234567, vk14x01234567));
6292 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi14x89ABCDEF, vk14x89ABCDEF));
6293
6294 const __m256i vi15x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i15));
6295 const __m256i vk15x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 240 * sizeof(uint8_t)))), vk_zero_point);
6296 const __m256i vi15x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i15 + 8)));
6297 const __m256i vk15x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 248 * sizeof(uint8_t)))), vk_zero_point);
6298 i15 += 16;
6299
6300 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi15x01234567, vk15x01234567));
6301 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi15x89ABCDEF, vk15x89ABCDEF));
6302
6303 const __m256i vi16x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i16));
6304 const __m256i vk16x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 256 * sizeof(uint8_t)))), vk_zero_point);
6305 const __m256i vi16x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i16 + 8)));
6306 const __m256i vk16x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 264 * sizeof(uint8_t)))), vk_zero_point);
6307 i16 += 16;
6308
6309 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi16x01234567, vk16x01234567));
6310 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi16x89ABCDEF, vk16x89ABCDEF));
6311
6312 const __m256i vi17x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i17));
6313 const __m256i vk17x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 272 * sizeof(uint8_t)))), vk_zero_point);
6314 const __m256i vi17x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i17 + 8)));
6315 const __m256i vk17x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 280 * sizeof(uint8_t)))), vk_zero_point);
6316 i17 += 16;
6317
6318 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi17x01234567, vk17x01234567));
6319 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi17x89ABCDEF, vk17x89ABCDEF));
6320
6321 const __m256i vi18x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i18));
6322 const __m256i vk18x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 288 * sizeof(uint8_t)))), vk_zero_point);
6323 const __m256i vi18x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i18 + 8)));
6324 const __m256i vk18x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 296 * sizeof(uint8_t)))), vk_zero_point);
6325 i18 += 16;
6326
6327 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi18x01234567, vk18x01234567));
6328 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi18x89ABCDEF, vk18x89ABCDEF));
6329
6330 const __m256i vi19x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i19));
6331 const __m256i vk19x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 304 * sizeof(uint8_t)))), vk_zero_point);
6332 const __m256i vi19x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i19 + 8)));
6333 const __m256i vk19x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 312 * sizeof(uint8_t)))), vk_zero_point);
6334 i19 += 16;
6335
6336 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi19x01234567, vk19x01234567));
6337 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi19x89ABCDEF, vk19x89ABCDEF));
6338
6339 const __m256i vi20x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i20));
6340 const __m256i vk20x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 320 * sizeof(uint8_t)))), vk_zero_point);
6341 const __m256i vi20x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i20 + 8)));
6342 const __m256i vk20x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 328 * sizeof(uint8_t)))), vk_zero_point);
6343 i20 += 16;
6344
6345 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi20x01234567, vk20x01234567));
6346 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi20x89ABCDEF, vk20x89ABCDEF));
6347
6348 const __m256i vi21x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i21));
6349 const __m256i vk21x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 336 * sizeof(uint8_t)))), vk_zero_point);
6350 const __m256i vi21x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i21 + 8)));
6351 const __m256i vk21x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 344 * sizeof(uint8_t)))), vk_zero_point);
6352 i21 += 16;
6353
6354 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi21x01234567, vk21x01234567));
6355 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi21x89ABCDEF, vk21x89ABCDEF));
6356
6357 const __m256i vi22x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i22));
6358 const __m256i vk22x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 352 * sizeof(uint8_t)))), vk_zero_point);
6359 const __m256i vi22x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i22 + 8)));
6360 const __m256i vk22x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 360 * sizeof(uint8_t)))), vk_zero_point);
6361 i22 += 16;
6362
6363 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi22x01234567, vk22x01234567));
6364 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi22x89ABCDEF, vk22x89ABCDEF));
6365
6366 const __m256i vi23x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i23));
6367 const __m256i vk23x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 368 * sizeof(uint8_t)))), vk_zero_point);
6368 const __m256i vi23x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i23 + 8)));
6369 const __m256i vk23x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 376 * sizeof(uint8_t)))), vk_zero_point);
6370 i23 += 16;
6371
6372 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi23x01234567, vk23x01234567));
6373 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi23x89ABCDEF, vk23x89ABCDEF));
6374
6375 const __m256i vi24x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i24));
6376 const __m256i vk24x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 384 * sizeof(uint8_t)))), vk_zero_point);
6377 const __m256i vi24x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i24 + 8)));
6378 const __m256i vk24x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 392 * sizeof(uint8_t)))), vk_zero_point);
6379 i24 += 16;
6380
6381 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi24x01234567, vk24x01234567));
6382 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi24x89ABCDEF, vk24x89ABCDEF));
6383
6384 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t) + 400 * sizeof(uint8_t));
6385
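      // Requantization: scale the int32 accumulators in fp32, clamp to the
      // output max, convert back with rounding, add the output zero point with
      // signed saturation, and pack to uint8 (the final max applies output_min).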
6386 __m256 vscaled01234567 = _mm256_cvtepi32_ps(vacc01234567);
6387 __m256 vscaled89ABCDEF = _mm256_cvtepi32_ps(vacc89ABCDEF);
6388
6389 const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
6390 vscaled01234567 = _mm256_mul_ps(vscaled01234567, vscale);
6391 vscaled89ABCDEF = _mm256_mul_ps(vscaled89ABCDEF, vscale);
6392
6393 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
6394 vscaled01234567 = _mm256_min_ps(vscaled01234567, voutput_max_less_zero_point);
6395 vscaled89ABCDEF = _mm256_min_ps(vscaled89ABCDEF, voutput_max_less_zero_point);
6396
6397 vacc01234567 = _mm256_cvtps_epi32(vscaled01234567);
6398 vacc89ABCDEF = _mm256_cvtps_epi32(vscaled89ABCDEF);
6399
6400 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
6401 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(vacc01234567, vacc89ABCDEF), voutput_zero_point);
6402
6403 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packus_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
6404
6405 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx2.output_min);
6406 vout0123456789ABCDEF = _mm_max_epu8(vout0123456789ABCDEF, voutput_min);
6407
6408 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
6409 output += 16;
6410 }
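    // Remainder path: up to 15 channels left, processed 8 at a time; partial
    // groups of fewer than 8 are stored in 4/2/1-byte pieces.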
6411 if XNN_UNLIKELY(c != 0) {
6412 const uint8_t* k = (const uint8_t*) ((const int32_t*) w + 16);
6413 do {
6414 __m256i vacc01234567 = _mm256_loadu_si256((const __m256i*) w);
6415
6416
6417 const __m256i vi0x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i0));
6418 const __m256i vk0x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) k)), vk_zero_point);
6419 i0 += 8;
6420
6421 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi0x01234567, vk0x01234567));
6422
6423 const __m256i vi1x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i1));
6424 const __m256i vk1x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 16))), vk_zero_point);
6425 i1 += 8;
6426
6427 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi1x01234567, vk1x01234567));
6428
6429 const __m256i vi2x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i2));
6430 const __m256i vk2x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 32))), vk_zero_point);
6431 i2 += 8;
6432
6433 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi2x01234567, vk2x01234567));
6434
6435 const __m256i vi3x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i3));
6436 const __m256i vk3x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 48))), vk_zero_point);
6437 i3 += 8;
6438
6439 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi3x01234567, vk3x01234567));
6440
6441 const __m256i vi4x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i4));
6442 const __m256i vk4x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 64))), vk_zero_point);
6443 i4 += 8;
6444
6445 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi4x01234567, vk4x01234567));
6446
6447 const __m256i vi5x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i5));
6448 const __m256i vk5x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 80))), vk_zero_point);
6449 i5 += 8;
6450
6451 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi5x01234567, vk5x01234567));
6452
6453 const __m256i vi6x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i6));
6454 const __m256i vk6x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 96))), vk_zero_point);
6455 i6 += 8;
6456
6457 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi6x01234567, vk6x01234567));
6458
6459 const __m256i vi7x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i7));
6460 const __m256i vk7x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 112))), vk_zero_point);
6461 i7 += 8;
6462
6463 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi7x01234567, vk7x01234567));
6464
6465 const __m256i vi8x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i8));
6466 const __m256i vk8x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 128))), vk_zero_point);
6467 i8 += 8;
6468
6469 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi8x01234567, vk8x01234567));
6470
6471 const __m256i vi9x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i9));
6472 const __m256i vk9x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 144))), vk_zero_point);
6473 i9 += 8;
6474
6475 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi9x01234567, vk9x01234567));
6476
6477 const __m256i vi10x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i10));
6478 const __m256i vk10x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 160))), vk_zero_point);
6479 i10 += 8;
6480
6481 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi10x01234567, vk10x01234567));
6482
6483 const __m256i vi11x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i11));
6484 const __m256i vk11x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 176))), vk_zero_point);
6485 i11 += 8;
6486
6487 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi11x01234567, vk11x01234567));
6488
6489 const __m256i vi12x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i12));
6490 const __m256i vk12x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 192))), vk_zero_point);
6491 i12 += 8;
6492
6493 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi12x01234567, vk12x01234567));
6494
6495 const __m256i vi13x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i13));
6496 const __m256i vk13x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 208))), vk_zero_point);
6497 i13 += 8;
6498
6499 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi13x01234567, vk13x01234567));
6500
6501 const __m256i vi14x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i14));
6502 const __m256i vk14x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 224))), vk_zero_point);
6503 i14 += 8;
6504
6505 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi14x01234567, vk14x01234567));
6506
6507 const __m256i vi15x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i15));
6508 const __m256i vk15x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 240))), vk_zero_point);
6509 i15 += 8;
6510
6511 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi15x01234567, vk15x01234567));
6512
6513 const __m256i vi16x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i16));
6514 const __m256i vk16x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 256))), vk_zero_point);
6515 i16 += 8;
6516
6517 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi16x01234567, vk16x01234567));
6518
6519 const __m256i vi17x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i17));
6520 const __m256i vk17x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 272))), vk_zero_point);
6521 i17 += 8;
6522
6523 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi17x01234567, vk17x01234567));
6524
6525 const __m256i vi18x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i18));
6526 const __m256i vk18x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 288))), vk_zero_point);
6527 i18 += 8;
6528
6529 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi18x01234567, vk18x01234567));
6530
6531 const __m256i vi19x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i19));
6532 const __m256i vk19x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 304))), vk_zero_point);
6533 i19 += 8;
6534
6535 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi19x01234567, vk19x01234567));
6536
6537 const __m256i vi20x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i20));
6538 const __m256i vk20x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 320))), vk_zero_point);
6539 i20 += 8;
6540
6541 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi20x01234567, vk20x01234567));
6542
6543 const __m256i vi21x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i21));
6544 const __m256i vk21x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 336))), vk_zero_point);
6545 i21 += 8;
6546
6547 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi21x01234567, vk21x01234567));
6548
6549 const __m256i vi22x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i22));
6550 const __m256i vk22x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 352))), vk_zero_point);
6551 i22 += 8;
6552
6553 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi22x01234567, vk22x01234567));
6554
6555 const __m256i vi23x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i23));
6556 const __m256i vk23x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 368))), vk_zero_point);
6557 i23 += 8;
6558
6559 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi23x01234567, vk23x01234567));
6560
6561 const __m256i vi24x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i24));
6562 const __m256i vk24x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 384))), vk_zero_point);
6563 i24 += 8;
6564
6565 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi24x01234567, vk24x01234567));
6566
6567 k += 8;
6568
6569 __m256 vscaled01234567 = _mm256_cvtepi32_ps(vacc01234567);
6570 vscaled01234567 = _mm256_mul_ps(vscaled01234567, _mm256_load_ps(params->fp32_avx2.scale));
6571 vscaled01234567 = _mm256_min_ps(vscaled01234567, _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point));
6572 vacc01234567 = _mm256_cvtps_epi32(vscaled01234567);
6573
6574 w = (const void*) ((const int32_t*) w + 8);
6575
6576 const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->fp32_avx2.output_zero_point);
6577 __m128i vout01234567 = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc01234567), _mm256_extracti128_si256(vacc01234567, 1)), voutput_zero_point);
6578
6579 __m128i vout0123456701234567 = _mm_packus_epi16(vout01234567, vout01234567);
6580
6581 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx2.output_min);
6582 vout0123456701234567 = _mm_max_epu8(vout0123456701234567, voutput_min);
6583
6584 if XNN_LIKELY(c >= 8) {
6585 _mm_storel_epi64((__m128i*) output, vout0123456701234567);
6586 output += 8;
6587 c -= 8;
6588 } else {
6589 if (c & 4) {
6590 unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vout0123456701234567));
6591 vout0123456701234567 = _mm_srli_epi64(vout0123456701234567, 32);
6592 output += 4;
6593 }
6594 if (c & 2) {
6595 unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vout0123456701234567, 0));
6596 vout0123456701234567 = _mm_srli_epi32(vout0123456701234567, 16);
6597 output += 2;
6598 }
6599 if (c & 1) {
6600 *output = (uint8_t) _mm_extract_epi8(vout0123456701234567, 0);
6601 output += 1;
6602 }
6603 c = 0;
6604 }
6605 } while (c != 0);
6606 }
6607
6608 output = (uint8_t*) ((uintptr_t) output + output_increment);
6609 } while (--output_width != 0);
6610 }
6611
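// 9-tap variant (e.g. a 3x3 kernel) of the depthwise convolution above; the
// structure is identical, with nine input row pointers instead of twenty-five.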
6612 void xnn_qu8_dwconv_minmax_fp32_ukernel_up16x9__avx2_mul32(
6613 size_t channels,
6614 size_t output_width,
6615 const uint8_t** input,
6616 const void* weights,
6617 uint8_t* output,
6618 size_t input_stride,
6619 size_t output_increment,
6620 size_t input_offset,
6621 const uint8_t* zero,
6622 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
6623 {
6624 assert(channels != 0);
6625 assert(output_width != 0);
6626
6627 const __m256i vk_zero_point = _mm256_cvtepu16_epi32(_mm_load_si128((const __m128i*) params->fp32_avx2.kernel_zero_point));
6628 do {
6629 const uint8_t* i0 = input[0];
6630 assert(i0 != NULL);
6631 if XNN_UNPREDICTABLE(i0 != zero) {
6632 i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
6633 }
6634 const uint8_t* i1 = input[1];
6635 assert(i1 != NULL);
6636 if XNN_UNPREDICTABLE(i1 != zero) {
6637 i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
6638 }
6639 const uint8_t* i2 = input[2];
6640 assert(i2 != NULL);
6641 if XNN_UNPREDICTABLE(i2 != zero) {
6642 i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
6643 }
6644 const uint8_t* i3 = input[3];
6645 assert(i3 != NULL);
6646 if XNN_UNPREDICTABLE(i3 != zero) {
6647 i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
6648 }
6649 const uint8_t* i4 = input[4];
6650 assert(i4 != NULL);
6651 if XNN_UNPREDICTABLE(i4 != zero) {
6652 i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
6653 }
6654 const uint8_t* i5 = input[5];
6655 assert(i5 != NULL);
6656 if XNN_UNPREDICTABLE(i5 != zero) {
6657 i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
6658 }
6659 const uint8_t* i6 = input[6];
6660 assert(i6 != NULL);
6661 if XNN_UNPREDICTABLE(i6 != zero) {
6662 i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
6663 }
6664 const uint8_t* i7 = input[7];
6665 assert(i7 != NULL);
6666 if XNN_UNPREDICTABLE(i7 != zero) {
6667 i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
6668 }
6669 const uint8_t* i8 = input[8];
6670 assert(i8 != NULL);
6671 if XNN_UNPREDICTABLE(i8 != zero) {
6672 i8 = (const uint8_t*) ((uintptr_t) i8 + input_offset);
6673 }
6674 input = (const uint8_t**) ((uintptr_t) input + input_stride);
6675
6676 size_t c = channels;
6677 const void* w = weights;
6678 for (; c >= 16; c -= 16) {
6679 __m256i vacc01234567 = _mm256_loadu_si256((const __m256i*) w);
6680 __m256i vacc89ABCDEF = _mm256_loadu_si256((const __m256i*) ((const int32_t*) w + 8));
6681
6682
6683 const __m256i vi0x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i0));
6684 const __m256i vk0x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 0 * sizeof(uint8_t)))), vk_zero_point);
6685 const __m256i vi0x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i0 + 8)));
6686 const __m256i vk0x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 8 * sizeof(uint8_t)))), vk_zero_point);
6687 i0 += 16;
6688
6689 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi0x01234567, vk0x01234567));
6690 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi0x89ABCDEF, vk0x89ABCDEF));
6691
6692 const __m256i vi1x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i1));
6693 const __m256i vk1x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 16 * sizeof(uint8_t)))), vk_zero_point);
6694 const __m256i vi1x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i1 + 8)));
6695 const __m256i vk1x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 24 * sizeof(uint8_t)))), vk_zero_point);
6696 i1 += 16;
6697
6698 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi1x01234567, vk1x01234567));
6699 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi1x89ABCDEF, vk1x89ABCDEF));
6700
6701 const __m256i vi2x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i2));
6702 const __m256i vk2x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 32 * sizeof(uint8_t)))), vk_zero_point);
6703 const __m256i vi2x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i2 + 8)));
6704 const __m256i vk2x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 40 * sizeof(uint8_t)))), vk_zero_point);
6705 i2 += 16;
6706
6707 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi2x01234567, vk2x01234567));
6708 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi2x89ABCDEF, vk2x89ABCDEF));
6709
6710 const __m256i vi3x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i3));
6711 const __m256i vk3x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 48 * sizeof(uint8_t)))), vk_zero_point);
6712 const __m256i vi3x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i3 + 8)));
6713 const __m256i vk3x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 56 * sizeof(uint8_t)))), vk_zero_point);
6714 i3 += 16;
6715
6716 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi3x01234567, vk3x01234567));
6717 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi3x89ABCDEF, vk3x89ABCDEF));
6718
6719 const __m256i vi4x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i4));
6720 const __m256i vk4x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 64 * sizeof(uint8_t)))), vk_zero_point);
6721 const __m256i vi4x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i4 + 8)));
6722 const __m256i vk4x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 72 * sizeof(uint8_t)))), vk_zero_point);
6723 i4 += 16;
6724
6725 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi4x01234567, vk4x01234567));
6726 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi4x89ABCDEF, vk4x89ABCDEF));
6727
6728 const __m256i vi5x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i5));
6729 const __m256i vk5x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 80 * sizeof(uint8_t)))), vk_zero_point);
6730 const __m256i vi5x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i5 + 8)));
6731 const __m256i vk5x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 88 * sizeof(uint8_t)))), vk_zero_point);
6732 i5 += 16;
6733
6734 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi5x01234567, vk5x01234567));
6735 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi5x89ABCDEF, vk5x89ABCDEF));
6736
6737 const __m256i vi6x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i6));
6738 const __m256i vk6x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 96 * sizeof(uint8_t)))), vk_zero_point);
6739 const __m256i vi6x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i6 + 8)));
6740 const __m256i vk6x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 104 * sizeof(uint8_t)))), vk_zero_point);
6741 i6 += 16;
6742
6743 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi6x01234567, vk6x01234567));
6744 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi6x89ABCDEF, vk6x89ABCDEF));
6745
6746 const __m256i vi7x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i7));
6747 const __m256i vk7x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 112 * sizeof(uint8_t)))), vk_zero_point);
6748 const __m256i vi7x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i7 + 8)));
6749 const __m256i vk7x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 120 * sizeof(uint8_t)))), vk_zero_point);
6750 i7 += 16;
6751
6752 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi7x01234567, vk7x01234567));
6753 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi7x89ABCDEF, vk7x89ABCDEF));
6754
6755 const __m256i vi8x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i8));
6756 const __m256i vk8x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 128 * sizeof(uint8_t)))), vk_zero_point);
6757 const __m256i vi8x89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (i8 + 8)));
6758 const __m256i vk8x89ABCDEF = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) ((uintptr_t) w + 16 * sizeof(int32_t) + 136 * sizeof(uint8_t)))), vk_zero_point);
6759 i8 += 16;
6760
6761 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi8x01234567, vk8x01234567));
6762 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vi8x89ABCDEF, vk8x89ABCDEF));
6763
6764 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t) + 144 * sizeof(uint8_t));
6765
6766 __m256 vscaled01234567 = _mm256_cvtepi32_ps(vacc01234567);
6767 __m256 vscaled89ABCDEF = _mm256_cvtepi32_ps(vacc89ABCDEF);
6768
6769 const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
6770 vscaled01234567 = _mm256_mul_ps(vscaled01234567, vscale);
6771 vscaled89ABCDEF = _mm256_mul_ps(vscaled89ABCDEF, vscale);
6772
6773 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
6774 vscaled01234567 = _mm256_min_ps(vscaled01234567, voutput_max_less_zero_point);
6775 vscaled89ABCDEF = _mm256_min_ps(vscaled89ABCDEF, voutput_max_less_zero_point);
6776
6777 vacc01234567 = _mm256_cvtps_epi32(vscaled01234567);
6778 vacc89ABCDEF = _mm256_cvtps_epi32(vscaled89ABCDEF);
6779
6780 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
6781 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(vacc01234567, vacc89ABCDEF), voutput_zero_point);
6782
6783 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packus_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
6784
6785 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx2.output_min);
6786 vout0123456789ABCDEF = _mm_max_epu8(vout0123456789ABCDEF, voutput_min);
6787
6788 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
6789 output += 16;
6790 }
6791 if XNN_UNLIKELY(c != 0) {
6792 const uint8_t* k = (const uint8_t*) ((const int32_t*) w + 16);
6793 do {
6794 __m256i vacc01234567 = _mm256_loadu_si256((const __m256i*) w);
6795
6796
6797 const __m256i vi0x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i0));
6798 const __m256i vk0x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) k)), vk_zero_point);
6799 i0 += 8;
6800
6801 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi0x01234567, vk0x01234567));
6802
6803 const __m256i vi1x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i1));
6804 const __m256i vk1x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 16))), vk_zero_point);
6805 i1 += 8;
6806
6807 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi1x01234567, vk1x01234567));
6808
6809 const __m256i vi2x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i2));
6810 const __m256i vk2x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 32))), vk_zero_point);
6811 i2 += 8;
6812
6813 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi2x01234567, vk2x01234567));
6814
6815 const __m256i vi3x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i3));
6816 const __m256i vk3x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 48))), vk_zero_point);
6817 i3 += 8;
6818
6819 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi3x01234567, vk3x01234567));
6820
6821 const __m256i vi4x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i4));
6822 const __m256i vk4x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 64))), vk_zero_point);
6823 i4 += 8;
6824
6825 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi4x01234567, vk4x01234567));
6826
6827 const __m256i vi5x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i5));
6828 const __m256i vk5x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 80))), vk_zero_point);
6829 i5 += 8;
6830
6831 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi5x01234567, vk5x01234567));
6832
6833 const __m256i vi6x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i6));
6834 const __m256i vk6x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 96))), vk_zero_point);
6835 i6 += 8;
6836
6837 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi6x01234567, vk6x01234567));
6838
6839 const __m256i vi7x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i7));
6840 const __m256i vk7x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 112))), vk_zero_point);
6841 i7 += 8;
6842
6843 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi7x01234567, vk7x01234567));
6844
6845 const __m256i vi8x01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) i8));
6846 const __m256i vk8x01234567 = _mm256_sub_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (k + 128))), vk_zero_point);
6847 i8 += 8;
6848
6849 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vi8x01234567, vk8x01234567));
6850
6851 k += 8;
6852
6853 __m256 vscaled01234567 = _mm256_cvtepi32_ps(vacc01234567);
6854 vscaled01234567 = _mm256_mul_ps(vscaled01234567, _mm256_load_ps(params->fp32_avx2.scale));
6855 vscaled01234567 = _mm256_min_ps(vscaled01234567, _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point));
6856 vacc01234567 = _mm256_cvtps_epi32(vscaled01234567);
6857
6858 w = (const void*) ((const int32_t*) w + 8);
6859
6860 const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->fp32_avx2.output_zero_point);
6861 __m128i vout01234567 = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc01234567), _mm256_extracti128_si256(vacc01234567, 1)), voutput_zero_point);
6862
6863 __m128i vout0123456701234567 = _mm_packus_epi16(vout01234567, vout01234567);
6864
6865 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx2.output_min);
6866 vout0123456701234567 = _mm_max_epu8(vout0123456701234567, voutput_min);
6867
6868 if XNN_LIKELY(c >= 8) {
6869 _mm_storel_epi64((__m128i*) output, vout0123456701234567);
6870 output += 8;
6871 c -= 8;
6872 } else {
6873 if (c & 4) {
6874 unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vout0123456701234567));
6875 vout0123456701234567 = _mm_srli_epi64(vout0123456701234567, 32);
6876 output += 4;
6877 }
6878 if (c & 2) {
6879 unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vout0123456701234567, 0));
6880 vout0123456701234567 = _mm_srli_epi32(vout0123456701234567, 16);
6881 output += 2;
6882 }
6883 if (c & 1) {
6884 *output = (uint8_t) _mm_extract_epi8(vout0123456701234567, 0);
6885 output += 1;
6886 }
6887 c = 0;
6888 }
6889 } while (c != 0);
6890 }
6891
6892 output = (uint8_t*) ((uintptr_t) output + output_increment);
6893 } while (--output_width != 0);
6894 }
6895
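// Dequantize QU8 to float: widen each byte to int32, add the precomputed
// negated zero point, convert to float, and multiply by the scale. Processes
// 16 elements per iteration, then 8, then a 1-7 element tail.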
6896 void xnn_qu8_f32_vcvt_ukernel__avx2_x16(
6897 size_t n,
6898 const uint8_t* x,
6899 float* y,
6900 const union xnn_qu8_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
6901 {
6902 assert(n != 0);
6903 assert(n % sizeof(uint8_t) == 0);
6904 assert(x != NULL);
6905 assert(y != NULL);
6906
6907 const __m256i vminus_zero_point = _mm256_load_si256((const __m256i*) params->avx.minus_zero_point);
6908 const __m256 vscale = _mm256_load_ps(params->avx.scale);
6909 for (; n >= 16 * sizeof(uint8_t); n -= 16 * sizeof(uint8_t)) {
6910 __m256i vx01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) x));
6911 __m256i vx89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (x + 8)));
6912 x += 16;
6913
6914 vx01234567 = _mm256_add_epi32(vx01234567, vminus_zero_point);
6915 vx89ABCDEF = _mm256_add_epi32(vx89ABCDEF, vminus_zero_point);
6916
6917 __m256 vy01234567 = _mm256_cvtepi32_ps(vx01234567);
6918 __m256 vy89ABCDEF = _mm256_cvtepi32_ps(vx89ABCDEF);
6919
6920 vy01234567 = _mm256_mul_ps(vy01234567, vscale);
6921 vy89ABCDEF = _mm256_mul_ps(vy89ABCDEF, vscale);
6922
6923 _mm256_storeu_ps(y, vy01234567);
6924 _mm256_storeu_ps(y + 8, vy89ABCDEF);
6925 y += 16;
6926 }
6927 for (; n >= 8 * sizeof(uint8_t); n -= 8 * sizeof(uint8_t)) {
6928 __m256i vx = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) x));
6929 vx = _mm256_add_epi32(vx, vminus_zero_point);
6930 x += 8;
6931
6932 __m256 vy = _mm256_cvtepi32_ps(vx);
6933 vy = _mm256_mul_ps(vy, vscale);
6934
6935 _mm256_storeu_ps(y, vy);
6936 y += 8;
6937 }
6938 if XNN_UNLIKELY(n != 0) {
6939 assert(n >= 1 * sizeof(uint8_t));
6940 assert(n <= 7 * sizeof(uint8_t));
6941
6942 __m256i vx = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) x));
6943 vx = _mm256_add_epi32(vx, vminus_zero_point);
6944
6945 __m256 vy = _mm256_cvtepi32_ps(vx);
6946 vy = _mm256_mul_ps(vy, vscale);
6947
6948 __m128 vy_lo = _mm256_castps256_ps128(vy);
6949 if (n & (4 * sizeof(uint8_t))) {
6950 _mm_storeu_ps(y, vy_lo);
6951 vy_lo = _mm256_extractf128_ps(vy, 1);
6952 y += 4;
6953 }
6954 if (n & (2 * sizeof(uint8_t))) {
6955 _mm_storel_pi((__m64*) y, vy_lo);
6956 vy_lo = _mm_movehl_ps(vy_lo, vy_lo);
6957 y += 2;
6958 }
6959 if (n & (1 * sizeof(uint8_t))) {
6960 _mm_store_ss(y, vy_lo);
6961 }
6962 }
6963 }
6964
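// QU8 GEMM, 1 row x 8 columns, 8 bytes of K per step. Products are formed as
// int16 via _mm256_madd_epi16 against zero-point-adjusted weights; each
// 128-bit lane holds four int32 partial sums for one column, reduced after
// the K loop.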
6965 void xnn_qu8_gemm_minmax_fp32_ukernel_1x8c8__avx2(
6966 size_t mr,
6967 size_t nc,
6968 size_t kc,
6969 const uint8_t* restrict a,
6970 size_t a_stride,
6971 const void* restrict w,
6972 uint8_t* restrict c,
6973 size_t cm_stride,
6974 size_t cn_stride,
6975 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
6976 {
6977 assert(mr != 0);
6978 assert(mr <= 1);
6979 assert(nc != 0);
6980 assert(kc != 0);
6981 assert(kc % sizeof(uint8_t) == 0);
6982 assert(a != NULL);
6983 assert(w != NULL);
6984 assert(c != NULL);
6985
6986 kc = round_up_po2(kc, 8);
6987 const uint8_t* a0 = a;
6988 uint8_t* c0 = c;
6989
6990 do {
6991 const __m128i vbias0x0 = _mm_cvtsi32_si128(((const int*) w)[0]);
6992 const __m128i vbias0x1 = _mm_cvtsi32_si128(((const int*) w)[1]);
6993 __m256i vacc0x01 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x0), vbias0x1, 1);
6994 const __m128i vbias0x2 = _mm_cvtsi32_si128(((const int*) w)[2]);
6995 const __m128i vbias0x3 = _mm_cvtsi32_si128(((const int*) w)[3]);
6996 __m256i vacc0x23 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x2), vbias0x3, 1);
6997 const __m128i vbias0x4 = _mm_cvtsi32_si128(((const int*) w)[4]);
6998 const __m128i vbias0x5 = _mm_cvtsi32_si128(((const int*) w)[5]);
6999 __m256i vacc0x45 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x4), vbias0x5, 1);
7000 const __m128i vbias0x6 = _mm_cvtsi32_si128(((const int*) w)[6]);
7001 const __m128i vbias0x7 = _mm_cvtsi32_si128(((const int*) w)[7]);
7002 __m256i vacc0x67 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x6), vbias0x7, 1);
7003 w = (const int32_t*) w + 8;
7004
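    // Each 256-bit accumulator covers two output columns: a column's bias sits in
    // element 0 of its 128-bit lane, and the madd products below pile partial sums
    // into the remaining elements; the hadd reduction after the loop folds each
    // lane back into a single per-column total.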
7005 size_t k = 0;
7006 const __m256i vb_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.kernel_zero_point);
7007 while (k < kc) {
7008 const __m128i va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
7009 const __m256i vxa0 = _mm256_cvtepu8_epi16(va0);
7010 a0 += 8;
7011
7012 const __m128i vb01 = _mm_load_si128((const __m128i*) w);
7013 const __m256i vxb01 = _mm256_sub_epi16(_mm256_cvtepu8_epi16(vb01), vb_zero_point);
7014
7015 vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
7016 const __m128i vb23 = _mm_load_si128((const __m128i*) ((const uint8_t*) w + 16));
7017 const __m256i vxb23 = _mm256_sub_epi16(_mm256_cvtepu8_epi16(vb23), vb_zero_point);
7018
7019 vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
7020 const __m128i vb45 = _mm_load_si128((const __m128i*) ((const uint8_t*) w + 32));
7021 const __m256i vxb45 = _mm256_sub_epi16(_mm256_cvtepu8_epi16(vb45), vb_zero_point);
7022
7023 vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
7024 const __m128i vb67 = _mm_load_si128((const __m128i*) ((const uint8_t*) w + 48));
7025 const __m256i vxb67 = _mm256_sub_epi16(_mm256_cvtepu8_epi16(vb67), vb_zero_point);
7026
7027 vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));
7028
7029 w = (const void*) ((const uint8_t*) w + 64);
7030 k += 8 * sizeof(uint8_t);
7031 }
7032
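    // Reduce: two rounds of _mm256_hadd_epi32 plus a cross-lane permute collapse
    // the four partial sums per column into one int32 each, leaving columns 0..7
    // in order.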
7033 const __m256i vacc0x0213 = _mm256_hadd_epi32(vacc0x01, vacc0x23);
7034 const __m256i vacc0x4657 = _mm256_hadd_epi32(vacc0x45, vacc0x67);
7035
7036 const __m256i vacc0x02461357 = _mm256_hadd_epi32(vacc0x0213, vacc0x4657);
7037
7038 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
7039 __m256i vacc0x01234567 = _mm256_permutevar8x32_epi32(vacc0x02461357, vpermute_mask);
7040
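    // FP32 requantization, roughly:
    //   out = clamp(round(acc * scale) + output_zero_point, output_min, output_max)
    // The max clamp is applied in float space (relative to the zero point); the
    // min clamp is applied after packing to uint8.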
7041 __m256 vscaled0x01234567 = _mm256_cvtepi32_ps(vacc0x01234567);
7042
7043 const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
7044 vscaled0x01234567 = _mm256_mul_ps(vscaled0x01234567, vscale);
7045
7046 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
7047 vscaled0x01234567 = _mm256_min_ps(vscaled0x01234567, voutput_max_less_zero_point);
7048
7049 vacc0x01234567 = _mm256_cvtps_epi32(vscaled0x01234567);
7050
7051 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
7052 __m256i vacc00x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc0x01234567, vacc0x01234567), voutput_zero_point);
7053
7054 vacc00x01234567 = _mm256_permute4x64_epi64(vacc00x01234567, _MM_SHUFFLE(3, 1, 2, 0));
7055
7056 __m256i vout = _mm256_packus_epi16(vacc00x01234567, vacc00x01234567);
7057
7058 vout = _mm256_max_epu8(vout, _mm256_load_si256((const __m256i*) params->fp32_avx2.output_min));
7059
7060 __m128i vout_lo = _mm256_castsi256_si128(vout);
7061 __m128i vout_hi = _mm256_extracti128_si256(vout, 1);
7062
7063 if (nc >= 8) {
7064 _mm_storel_epi64((__m128i*) c0, vout_lo);
7065
7066 c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
7067
7068 a0 = (const uint8_t*) ((uintptr_t) a0 - kc);
7069
7070 nc -= 8;
7071 } else {
7072 if (nc & 4) {
7073 _mm_storeu_si32(c0, vout_lo);
7074
7075 c0 += 4;
7076
7077 vout_lo = _mm_srli_epi64(vout_lo, 32);
7078 vout_hi = _mm_srli_epi64(vout_hi, 32);
7079 }
7080 if (nc & 2) {
7081 unaligned_store_u16(c0, (uint16_t) _mm_extract_epi16(vout_lo, 0));
7082
7083 c0 += 2;
7084
7085 vout_lo = _mm_srli_epi32(vout_lo, 16);
7086 vout_hi = _mm_srli_epi32(vout_hi, 16);
7087 }
7088 if (nc & 1) {
7089 *c0 = (uint8_t) _mm_extract_epi8(vout_lo, 0);
7090 }
7091
7092 nc = 0;
7093 }
7094 } while (nc != 0);
7095 }
7096
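// Same QU8 GEMM as the 1x8c8 kernel above, but for three rows of output at once.
// When mr < 3 the unused row pointers alias the previous row, so the extra loads
// and stores are harmless duplicates.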
7097 void xnn_qu8_gemm_minmax_fp32_ukernel_3x8c8__avx2(
7098 size_t mr,
7099 size_t nc,
7100 size_t kc,
7101 const uint8_t* restrict a,
7102 size_t a_stride,
7103 const void* restrict w,
7104 uint8_t* restrict c,
7105 size_t cm_stride,
7106 size_t cn_stride,
7107 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
7108 {
7109 assert(mr != 0);
7110 assert(mr <= 3);
7111 assert(nc != 0);
7112 assert(kc != 0);
7113 assert(kc % sizeof(uint8_t) == 0);
7114 assert(a != NULL);
7115 assert(w != NULL);
7116 assert(c != NULL);
7117
7118 kc = round_up_po2(kc, 8);
7119 const uint8_t* a0 = a;
7120 uint8_t* c0 = c;
7121 const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride);
7122 uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
7123 if XNN_UNPREDICTABLE(mr < 2) {
7124 a1 = a0;
7125 c1 = c0;
7126 }
7127 const uint8_t* a2 = (const uint8_t*) ((uintptr_t) a1 + a_stride);
7128 uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
7129 if XNN_UNPREDICTABLE(mr <= 2) {
7130 a2 = a1;
7131 c2 = c1;
7132 }
7133
7134 do {
7135 const __m128i vbias0x0 = _mm_cvtsi32_si128(((const int*) w)[0]);
7136 const __m128i vbias0x1 = _mm_cvtsi32_si128(((const int*) w)[1]);
7137 __m256i vacc0x01 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x0), vbias0x1, 1);
7138 const __m128i vbias0x2 = _mm_cvtsi32_si128(((const int*) w)[2]);
7139 const __m128i vbias0x3 = _mm_cvtsi32_si128(((const int*) w)[3]);
7140 __m256i vacc0x23 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x2), vbias0x3, 1);
7141 const __m128i vbias0x4 = _mm_cvtsi32_si128(((const int*) w)[4]);
7142 const __m128i vbias0x5 = _mm_cvtsi32_si128(((const int*) w)[5]);
7143 __m256i vacc0x45 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x4), vbias0x5, 1);
7144 const __m128i vbias0x6 = _mm_cvtsi32_si128(((const int*) w)[6]);
7145 const __m128i vbias0x7 = _mm_cvtsi32_si128(((const int*) w)[7]);
7146 __m256i vacc0x67 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x6), vbias0x7, 1);
7147 __m256i vacc1x01 = vacc0x01;
7148 __m256i vacc1x23 = vacc0x23;
7149 __m256i vacc1x45 = vacc0x45;
7150 __m256i vacc1x67 = vacc0x67;
7151 __m256i vacc2x01 = vacc0x01;
7152 __m256i vacc2x23 = vacc0x23;
7153 __m256i vacc2x45 = vacc0x45;
7154 __m256i vacc2x67 = vacc0x67;
7155 w = (const int32_t*) w + 8;
7156
7157 size_t k = 0;
7158 const __m256i vb_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.kernel_zero_point);
7159 while (k < kc) {
7160 const __m128i va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
7161 const __m256i vxa0 = _mm256_cvtepu8_epi16(va0);
7162 a0 += 8;
7163 const __m128i va1 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a1));
7164 const __m256i vxa1 = _mm256_cvtepu8_epi16(va1);
7165 a1 += 8;
7166 const __m128i va2 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a2));
7167 const __m256i vxa2 = _mm256_cvtepu8_epi16(va2);
7168 a2 += 8;
7169
7170 const __m128i vb01 = _mm_load_si128((const __m128i*) w);
7171 const __m256i vxb01 = _mm256_sub_epi16(_mm256_cvtepu8_epi16(vb01), vb_zero_point);
7172
7173 vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
7174 vacc1x01 = _mm256_add_epi32(vacc1x01, _mm256_madd_epi16(vxa1, vxb01));
7175 vacc2x01 = _mm256_add_epi32(vacc2x01, _mm256_madd_epi16(vxa2, vxb01));
7176 const __m128i vb23 = _mm_load_si128((const __m128i*) ((const uint8_t*) w + 16));
7177 const __m256i vxb23 = _mm256_sub_epi16(_mm256_cvtepu8_epi16(vb23), vb_zero_point);
7178
7179 vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
7180 vacc1x23 = _mm256_add_epi32(vacc1x23, _mm256_madd_epi16(vxa1, vxb23));
7181 vacc2x23 = _mm256_add_epi32(vacc2x23, _mm256_madd_epi16(vxa2, vxb23));
7182 const __m128i vb45 = _mm_load_si128((const __m128i*) ((const uint8_t*) w + 32));
7183 const __m256i vxb45 = _mm256_sub_epi16(_mm256_cvtepu8_epi16(vb45), vb_zero_point);
7184
7185 vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
7186 vacc1x45 = _mm256_add_epi32(vacc1x45, _mm256_madd_epi16(vxa1, vxb45));
7187 vacc2x45 = _mm256_add_epi32(vacc2x45, _mm256_madd_epi16(vxa2, vxb45));
7188 const __m128i vb67 = _mm_load_si128((const __m128i*) ((const uint8_t*) w + 48));
7189 const __m256i vxb67 = _mm256_sub_epi16(_mm256_cvtepu8_epi16(vb67), vb_zero_point);
7190
7191 vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));
7192 vacc1x67 = _mm256_add_epi32(vacc1x67, _mm256_madd_epi16(vxa1, vxb67));
7193 vacc2x67 = _mm256_add_epi32(vacc2x67, _mm256_madd_epi16(vxa2, vxb67));
7194
7195 w = (const void*) ((const uint8_t*) w + 64);
7196 k += 8 * sizeof(uint8_t);
7197 }
7198
7199 const __m256i vacc0x0213 = _mm256_hadd_epi32(vacc0x01, vacc0x23);
7200 const __m256i vacc0x4657 = _mm256_hadd_epi32(vacc0x45, vacc0x67);
7201 const __m256i vacc1x0213 = _mm256_hadd_epi32(vacc1x01, vacc1x23);
7202 const __m256i vacc1x4657 = _mm256_hadd_epi32(vacc1x45, vacc1x67);
7203 const __m256i vacc2x0213 = _mm256_hadd_epi32(vacc2x01, vacc2x23);
7204 const __m256i vacc2x4657 = _mm256_hadd_epi32(vacc2x45, vacc2x67);
7205
7206 const __m256i vacc0x02461357 = _mm256_hadd_epi32(vacc0x0213, vacc0x4657);
7207 const __m256i vacc1x02461357 = _mm256_hadd_epi32(vacc1x0213, vacc1x4657);
7208 const __m256i vacc2x02461357 = _mm256_hadd_epi32(vacc2x0213, vacc2x4657);
7209
7210 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
7211 __m256i vacc0x01234567 = _mm256_permutevar8x32_epi32(vacc0x02461357, vpermute_mask);
7212 __m256i vacc1x01234567 = _mm256_permutevar8x32_epi32(vacc1x02461357, vpermute_mask);
7213 __m256i vacc2x01234567 = _mm256_permutevar8x32_epi32(vacc2x02461357, vpermute_mask);
7214
7215 __m256 vscaled0x01234567 = _mm256_cvtepi32_ps(vacc0x01234567);
7216 __m256 vscaled1x01234567 = _mm256_cvtepi32_ps(vacc1x01234567);
7217 __m256 vscaled2x01234567 = _mm256_cvtepi32_ps(vacc2x01234567);
7218
7219 const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
7220 vscaled0x01234567 = _mm256_mul_ps(vscaled0x01234567, vscale);
7221 vscaled1x01234567 = _mm256_mul_ps(vscaled1x01234567, vscale);
7222 vscaled2x01234567 = _mm256_mul_ps(vscaled2x01234567, vscale);
7223
7224 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
7225 vscaled0x01234567 = _mm256_min_ps(vscaled0x01234567, voutput_max_less_zero_point);
7226 vscaled1x01234567 = _mm256_min_ps(vscaled1x01234567, voutput_max_less_zero_point);
7227 vscaled2x01234567 = _mm256_min_ps(vscaled2x01234567, voutput_max_less_zero_point);
7228
7229 vacc0x01234567 = _mm256_cvtps_epi32(vscaled0x01234567);
7230 vacc1x01234567 = _mm256_cvtps_epi32(vscaled1x01234567);
7231 vacc2x01234567 = _mm256_cvtps_epi32(vscaled2x01234567);
7232
7233 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
7234 __m256i vacc01x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc0x01234567, vacc1x01234567), voutput_zero_point);
7235 __m256i vacc22x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc2x01234567, vacc2x01234567), voutput_zero_point);
7236
7237 vacc01x01234567 = _mm256_permute4x64_epi64(vacc01x01234567, _MM_SHUFFLE(3, 1, 2, 0));
7238 vacc22x01234567 = _mm256_permute4x64_epi64(vacc22x01234567, _MM_SHUFFLE(3, 1, 2, 0));
7239
7240 __m256i vout = _mm256_packus_epi16(vacc01x01234567, vacc22x01234567);
7241
7242 vout = _mm256_max_epu8(vout, _mm256_load_si256((const __m256i*) params->fp32_avx2.output_min));
7243
7244 __m128i vout_lo = _mm256_castsi256_si128(vout);
7245 __m128i vout_hi = _mm256_extracti128_si256(vout, 1);
7246
7247 if (nc >= 8) {
7248 _mm_storel_epi64((__m128i*) c0, vout_lo);
7249 _mm_storel_epi64((__m128i*) c1, vout_hi);
7250 _mm_storeh_pi((__m64*) c2, _mm_castsi128_ps(vout_lo));
7251
7252 c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
7253 c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
7254 c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride);
7255
7256 a0 = (const uint8_t*) ((uintptr_t) a0 - kc);
7257 a1 = (const uint8_t*) ((uintptr_t) a1 - kc);
7258 a2 = (const uint8_t*) ((uintptr_t) a2 - kc);
7259
7260 nc -= 8;
7261 } else {
7262 if (nc & 4) {
7263 _mm_storeu_si32(c0, vout_lo);
7264 _mm_storeu_si32(c1, vout_hi);
7265 unaligned_store_u32(c2, (uint32_t) _mm_extract_epi32(vout_lo, 2));
7266
7267 c0 += 4;
7268 c1 += 4;
7269 c2 += 4;
7270
7271 vout_lo = _mm_srli_epi64(vout_lo, 32);
7272 vout_hi = _mm_srli_epi64(vout_hi, 32);
7273 }
7274 if (nc & 2) {
7275 unaligned_store_u16(c0, (uint16_t) _mm_extract_epi16(vout_lo, 0));
7276 unaligned_store_u16(c1, (uint16_t) _mm_extract_epi16(vout_hi, 0));
7277 unaligned_store_u16(c2, (uint16_t) _mm_extract_epi16(vout_lo, 4));
7278
7279 c0 += 2;
7280 c1 += 2;
7281 c2 += 2;
7282
7283 vout_lo = _mm_srli_epi32(vout_lo, 16);
7284 vout_hi = _mm_srli_epi32(vout_hi, 16);
7285 }
7286 if (nc & 1) {
7287 *c0 = (uint8_t) _mm_extract_epi8(vout_lo, 0);
7288 *c1 = (uint8_t) _mm_extract_epi8(vout_hi, 0);
7289 *c2 = (uint8_t) _mm_extract_epi8(vout_lo, 8);
7290 }
7291
7292 nc = 0;
7293 }
7294 } while (nc != 0);
7295 }
7296
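// Indirect (IGEMM) variant of the 1x8c8 kernel: `a` holds ks row pointers, one per
// convolution tap. Pointers equal to `zero` are left pointing at the shared
// padding buffer and skip the a_offset adjustment.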
7297 void xnn_qu8_igemm_minmax_fp32_ukernel_1x8c8__avx2(
7298 size_t mr,
7299 size_t nc,
7300 size_t kc,
7301 size_t ks,
7302 const uint8_t** restrict a,
7303 const void* restrict w,
7304 uint8_t* restrict c,
7305 size_t cm_stride,
7306 size_t cn_stride,
7307 size_t a_offset,
7308 const uint8_t* zero,
7309 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
7310 {
7311 assert(mr != 0);
7312 assert(mr <= 1);
7313 assert(nc != 0);
7314 assert(kc != 0);
7315 assert(ks != 0);
7316 assert(ks % (1 * sizeof(void*)) == 0);
7317 assert(a_offset % sizeof(uint8_t) == 0);
7318 assert(a != NULL);
7319 assert(w != NULL);
7320 assert(c != NULL);
7321
7322 kc = round_up_po2(kc, 8);
7323 uint8_t* c0 = c;
7324
7325 do {
7326 const __m128i vbias0x0 = _mm_cvtsi32_si128(((const int*) w)[0]);
7327 const __m128i vbias0x1 = _mm_cvtsi32_si128(((const int*) w)[1]);
7328 __m256i vacc0x01 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x0), vbias0x1, 1);
7329 const __m128i vbias0x2 = _mm_cvtsi32_si128(((const int*) w)[2]);
7330 const __m128i vbias0x3 = _mm_cvtsi32_si128(((const int*) w)[3]);
7331 __m256i vacc0x23 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x2), vbias0x3, 1);
7332 const __m128i vbias0x4 = _mm_cvtsi32_si128(((const int*) w)[4]);
7333 const __m128i vbias0x5 = _mm_cvtsi32_si128(((const int*) w)[5]);
7334 __m256i vacc0x45 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x4), vbias0x5, 1);
7335 const __m128i vbias0x6 = _mm_cvtsi32_si128(((const int*) w)[6]);
7336 const __m128i vbias0x7 = _mm_cvtsi32_si128(((const int*) w)[7]);
7337 __m256i vacc0x67 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x6), vbias0x7, 1);
7338 w = (const int32_t*) w + 8;
7339
7340 size_t p = ks;
7341 const __m256i vb_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.kernel_zero_point);
7342 do {
7343 const uint8_t* restrict a0 = a[0];
7344 if XNN_UNPREDICTABLE(a0 != zero) {
7345 a0 = (const uint8_t*) ((uintptr_t) a0 + a_offset);
7346 }
7347 a += 1;
7348
7349 size_t k = 0;
7350 while (k < kc) {
7351 const __m128i va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
7352 const __m256i vxa0 = _mm256_cvtepu8_epi16(va0);
7353 a0 += 8;
7354
7355 const __m128i vb01 = _mm_load_si128((const __m128i*) w);
7356 const __m256i vxb01 = _mm256_sub_epi16(_mm256_cvtepu8_epi16(vb01), vb_zero_point);
7357
7358 vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
7359 const __m128i vb23 = _mm_load_si128((const __m128i*) ((const uint8_t*) w + 16));
7360 const __m256i vxb23 = _mm256_sub_epi16(_mm256_cvtepu8_epi16(vb23), vb_zero_point);
7361
7362 vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
7363 const __m128i vb45 = _mm_load_si128((const __m128i*) ((const uint8_t*) w + 32));
7364 const __m256i vxb45 = _mm256_sub_epi16(_mm256_cvtepu8_epi16(vb45), vb_zero_point);
7365
7366 vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
7367 const __m128i vb67 = _mm_load_si128((const __m128i*) ((const uint8_t*) w + 48));
7368 const __m256i vxb67 = _mm256_sub_epi16(_mm256_cvtepu8_epi16(vb67), vb_zero_point);
7369
7370 vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));
7371
7372 w = (const void*) ((const uint8_t*) w + 64);
7373 k += 8 * sizeof(uint8_t);
7374 }
7375 p -= 1 * sizeof(void*);
7376 } while (p != 0);
7377
7378 const __m256i vacc0x0213 = _mm256_hadd_epi32(vacc0x01, vacc0x23);
7379 const __m256i vacc0x4657 = _mm256_hadd_epi32(vacc0x45, vacc0x67);
7380
7381 const __m256i vacc0x02461357 = _mm256_hadd_epi32(vacc0x0213, vacc0x4657);
7382
7383 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
7384 __m256i vacc0x01234567 = _mm256_permutevar8x32_epi32(vacc0x02461357, vpermute_mask);
7385
7386 __m256 vscaled0x01234567 = _mm256_cvtepi32_ps(vacc0x01234567);
7387
7388 const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
7389 vscaled0x01234567 = _mm256_mul_ps(vscaled0x01234567, vscale);
7390
7391 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
7392 vscaled0x01234567 = _mm256_min_ps(vscaled0x01234567, voutput_max_less_zero_point);
7393
7394 vacc0x01234567 = _mm256_cvtps_epi32(vscaled0x01234567);
7395
7396 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
7397 __m256i vacc00x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc0x01234567, vacc0x01234567), voutput_zero_point);
7398
7399 vacc00x01234567 = _mm256_permute4x64_epi64(vacc00x01234567, _MM_SHUFFLE(3, 1, 2, 0));
7400
7401 __m256i vout = _mm256_packus_epi16(vacc00x01234567, vacc00x01234567);
7402
7403 vout = _mm256_max_epu8(vout, _mm256_load_si256((const __m256i*) params->fp32_avx2.output_min));
7404
7405 __m128i vout_lo = _mm256_castsi256_si128(vout);
7406 __m128i vout_hi = _mm256_extracti128_si256(vout, 1);
7407
7408 if (nc >= 8) {
7409 _mm_storel_epi64((__m128i*) c0, vout_lo);
7410
7411 c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
7412
7413 a = (const uint8_t**restrict) ((uintptr_t) a - ks);
7414
7415 nc -= 8;
7416 } else {
7417 if (nc & 4) {
7418 _mm_storeu_si32(c0, vout_lo);
7419
7420 c0 += 4;
7421
7422 vout_lo = _mm_srli_epi64(vout_lo, 32);
7423 vout_hi = _mm_srli_epi64(vout_hi, 32);
7424 }
7425 if (nc & 2) {
7426 unaligned_store_u16(c0, (uint16_t) _mm_extract_epi16(vout_lo, 0));
7427
7428 c0 += 2;
7429
7430 vout_lo = _mm_srli_epi32(vout_lo, 16);
7431 vout_hi = _mm_srli_epi32(vout_hi, 16);
7432 }
7433 if (nc & 1) {
7434 *c0 = (uint8_t) _mm_extract_epi8(vout_lo, 0);
7435 }
7436
7437 nc = 0;
7438 }
7439 } while (nc != 0);
7440 }
7441
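// Indirect (IGEMM) variant of the 3x8c8 kernel above: consumes three row pointers
// per convolution tap and otherwise mirrors the direct GEMM.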
7442 void xnn_qu8_igemm_minmax_fp32_ukernel_3x8c8__avx2(
7443 size_t mr,
7444 size_t nc,
7445 size_t kc,
7446 size_t ks,
7447 const uint8_t** restrict a,
7448 const void* restrict w,
7449 uint8_t* restrict c,
7450 size_t cm_stride,
7451 size_t cn_stride,
7452 size_t a_offset,
7453 const uint8_t* zero,
7454 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
7455 {
7456 assert(mr != 0);
7457 assert(mr <= 3);
7458 assert(nc != 0);
7459 assert(kc != 0);
7460 assert(ks != 0);
7461 assert(ks % (3 * sizeof(void*)) == 0);
7462 assert(a_offset % sizeof(uint8_t) == 0);
7463 assert(a != NULL);
7464 assert(w != NULL);
7465 assert(c != NULL);
7466
7467 kc = round_up_po2(kc, 8);
7468 uint8_t* c0 = c;
7469 uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
7470 if XNN_UNPREDICTABLE(mr < 2) {
7471 c1 = c0;
7472 }
7473 uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
7474 if XNN_UNPREDICTABLE(mr <= 2) {
7475 c2 = c1;
7476 }
7477
7478 do {
7479 const __m128i vbias0x0 = _mm_cvtsi32_si128(((const int*) w)[0]);
7480 const __m128i vbias0x1 = _mm_cvtsi32_si128(((const int*) w)[1]);
7481 __m256i vacc0x01 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x0), vbias0x1, 1);
7482 const __m128i vbias0x2 = _mm_cvtsi32_si128(((const int*) w)[2]);
7483 const __m128i vbias0x3 = _mm_cvtsi32_si128(((const int*) w)[3]);
7484 __m256i vacc0x23 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x2), vbias0x3, 1);
7485 const __m128i vbias0x4 = _mm_cvtsi32_si128(((const int*) w)[4]);
7486 const __m128i vbias0x5 = _mm_cvtsi32_si128(((const int*) w)[5]);
7487 __m256i vacc0x45 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x4), vbias0x5, 1);
7488 const __m128i vbias0x6 = _mm_cvtsi32_si128(((const int*) w)[6]);
7489 const __m128i vbias0x7 = _mm_cvtsi32_si128(((const int*) w)[7]);
7490 __m256i vacc0x67 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x6), vbias0x7, 1);
7491 __m256i vacc1x01 = vacc0x01;
7492 __m256i vacc1x23 = vacc0x23;
7493 __m256i vacc1x45 = vacc0x45;
7494 __m256i vacc1x67 = vacc0x67;
7495 __m256i vacc2x01 = vacc0x01;
7496 __m256i vacc2x23 = vacc0x23;
7497 __m256i vacc2x45 = vacc0x45;
7498 __m256i vacc2x67 = vacc0x67;
7499 w = (const int32_t*) w + 8;
7500
7501 size_t p = ks;
7502 const __m256i vb_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.kernel_zero_point);
7503 do {
7504 const uint8_t* restrict a0 = a[0];
7505 if XNN_UNPREDICTABLE(a0 != zero) {
7506 a0 = (const uint8_t*) ((uintptr_t) a0 + a_offset);
7507 }
7508 const uint8_t* restrict a1 = a[1];
7509 if XNN_UNPREDICTABLE(a1 != zero) {
7510 a1 = (const uint8_t*) ((uintptr_t) a1 + a_offset);
7511 }
7512 const uint8_t* restrict a2 = a[2];
7513 if XNN_UNPREDICTABLE(a2 != zero) {
7514 a2 = (const uint8_t*) ((uintptr_t) a2 + a_offset);
7515 }
7516 a += 3;
7517
7518 size_t k = 0;
7519 while (k < kc) {
7520 const __m128i va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
7521 const __m256i vxa0 = _mm256_cvtepu8_epi16(va0);
7522 a0 += 8;
7523 const __m128i va1 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a1));
7524 const __m256i vxa1 = _mm256_cvtepu8_epi16(va1);
7525 a1 += 8;
7526 const __m128i va2 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a2));
7527 const __m256i vxa2 = _mm256_cvtepu8_epi16(va2);
7528 a2 += 8;
7529
7530 const __m128i vb01 = _mm_load_si128((const __m128i*) w);
7531 const __m256i vxb01 = _mm256_sub_epi16(_mm256_cvtepu8_epi16(vb01), vb_zero_point);
7532
7533 vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
7534 vacc1x01 = _mm256_add_epi32(vacc1x01, _mm256_madd_epi16(vxa1, vxb01));
7535 vacc2x01 = _mm256_add_epi32(vacc2x01, _mm256_madd_epi16(vxa2, vxb01));
7536 const __m128i vb23 = _mm_load_si128((const __m128i*) ((const uint8_t*) w + 16));
7537 const __m256i vxb23 = _mm256_sub_epi16(_mm256_cvtepu8_epi16(vb23), vb_zero_point);
7538
7539 vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
7540 vacc1x23 = _mm256_add_epi32(vacc1x23, _mm256_madd_epi16(vxa1, vxb23));
7541 vacc2x23 = _mm256_add_epi32(vacc2x23, _mm256_madd_epi16(vxa2, vxb23));
7542 const __m128i vb45 = _mm_load_si128((const __m128i*) ((const uint8_t*) w + 32));
7543 const __m256i vxb45 = _mm256_sub_epi16(_mm256_cvtepu8_epi16(vb45), vb_zero_point);
7544
7545 vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
7546 vacc1x45 = _mm256_add_epi32(vacc1x45, _mm256_madd_epi16(vxa1, vxb45));
7547 vacc2x45 = _mm256_add_epi32(vacc2x45, _mm256_madd_epi16(vxa2, vxb45));
7548 const __m128i vb67 = _mm_load_si128((const __m128i*) ((const uint8_t*) w + 48));
7549 const __m256i vxb67 = _mm256_sub_epi16(_mm256_cvtepu8_epi16(vb67), vb_zero_point);
7550
7551 vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));
7552 vacc1x67 = _mm256_add_epi32(vacc1x67, _mm256_madd_epi16(vxa1, vxb67));
7553 vacc2x67 = _mm256_add_epi32(vacc2x67, _mm256_madd_epi16(vxa2, vxb67));
7554
7555 w = (const void*) ((const uint8_t*) w + 64);
7556 k += 8 * sizeof(uint8_t);
7557 }
7558 p -= 3 * sizeof(void*);
7559 } while (p != 0);
7560
7561 const __m256i vacc0x0213 = _mm256_hadd_epi32(vacc0x01, vacc0x23);
7562 const __m256i vacc0x4657 = _mm256_hadd_epi32(vacc0x45, vacc0x67);
7563 const __m256i vacc1x0213 = _mm256_hadd_epi32(vacc1x01, vacc1x23);
7564 const __m256i vacc1x4657 = _mm256_hadd_epi32(vacc1x45, vacc1x67);
7565 const __m256i vacc2x0213 = _mm256_hadd_epi32(vacc2x01, vacc2x23);
7566 const __m256i vacc2x4657 = _mm256_hadd_epi32(vacc2x45, vacc2x67);
7567
7568 const __m256i vacc0x02461357 = _mm256_hadd_epi32(vacc0x0213, vacc0x4657);
7569 const __m256i vacc1x02461357 = _mm256_hadd_epi32(vacc1x0213, vacc1x4657);
7570 const __m256i vacc2x02461357 = _mm256_hadd_epi32(vacc2x0213, vacc2x4657);
7571
7572 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
7573 __m256i vacc0x01234567 = _mm256_permutevar8x32_epi32(vacc0x02461357, vpermute_mask);
7574 __m256i vacc1x01234567 = _mm256_permutevar8x32_epi32(vacc1x02461357, vpermute_mask);
7575 __m256i vacc2x01234567 = _mm256_permutevar8x32_epi32(vacc2x02461357, vpermute_mask);
7576
7577 __m256 vscaled0x01234567 = _mm256_cvtepi32_ps(vacc0x01234567);
7578 __m256 vscaled1x01234567 = _mm256_cvtepi32_ps(vacc1x01234567);
7579 __m256 vscaled2x01234567 = _mm256_cvtepi32_ps(vacc2x01234567);
7580
7581 const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
7582 vscaled0x01234567 = _mm256_mul_ps(vscaled0x01234567, vscale);
7583 vscaled1x01234567 = _mm256_mul_ps(vscaled1x01234567, vscale);
7584 vscaled2x01234567 = _mm256_mul_ps(vscaled2x01234567, vscale);
7585
7586 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->fp32_avx2.output_max_less_zero_point);
7587 vscaled0x01234567 = _mm256_min_ps(vscaled0x01234567, voutput_max_less_zero_point);
7588 vscaled1x01234567 = _mm256_min_ps(vscaled1x01234567, voutput_max_less_zero_point);
7589 vscaled2x01234567 = _mm256_min_ps(vscaled2x01234567, voutput_max_less_zero_point);
7590
7591 vacc0x01234567 = _mm256_cvtps_epi32(vscaled0x01234567);
7592 vacc1x01234567 = _mm256_cvtps_epi32(vscaled1x01234567);
7593 vacc2x01234567 = _mm256_cvtps_epi32(vscaled2x01234567);
7594
7595 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx2.output_zero_point);
7596 __m256i vacc01x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc0x01234567, vacc1x01234567), voutput_zero_point);
7597 __m256i vacc22x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc2x01234567, vacc2x01234567), voutput_zero_point);
7598
7599 vacc01x01234567 = _mm256_permute4x64_epi64(vacc01x01234567, _MM_SHUFFLE(3, 1, 2, 0));
7600 vacc22x01234567 = _mm256_permute4x64_epi64(vacc22x01234567, _MM_SHUFFLE(3, 1, 2, 0));
7601
7602 __m256i vout = _mm256_packus_epi16(vacc01x01234567, vacc22x01234567);
7603
7604 vout = _mm256_max_epu8(vout, _mm256_load_si256((const __m256i*) params->fp32_avx2.output_min));
7605
7606 __m128i vout_lo = _mm256_castsi256_si128(vout);
7607 __m128i vout_hi = _mm256_extracti128_si256(vout, 1);
7608
7609 if (nc >= 8) {
7610 _mm_storeh_pi((__m64*) c2, _mm_castsi128_ps(vout_lo));
7611 _mm_storel_epi64((__m128i*) c1, vout_hi);
7612 _mm_storel_epi64((__m128i*) c0, vout_lo);
7613
7614 c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride);
7615 c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
7616 c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
7617
7618 a = (const uint8_t**restrict) ((uintptr_t) a - ks);
7619
7620 nc -= 8;
7621 } else {
7622 if (nc & 4) {
7623 unaligned_store_u32(c2, (uint32_t) _mm_extract_epi32(vout_lo, 2));
7624 _mm_storeu_si32(c1, vout_hi);
7625 _mm_storeu_si32(c0, vout_lo);
7626
7627 c2 += 4;
7628 c1 += 4;
7629 c0 += 4;
7630
7631 vout_lo = _mm_srli_epi64(vout_lo, 32);
7632 vout_hi = _mm_srli_epi64(vout_hi, 32);
7633 }
7634 if (nc & 2) {
7635 unaligned_store_u16(c2, (uint16_t) _mm_extract_epi16(vout_lo, 4));
7636 unaligned_store_u16(c1, (uint16_t) _mm_extract_epi16(vout_hi, 0));
7637 unaligned_store_u16(c0, (uint16_t) _mm_extract_epi16(vout_lo, 0));
7638
7639 c2 += 2;
7640 c1 += 2;
7641 c0 += 2;
7642
7643 vout_lo = _mm_srli_epi32(vout_lo, 16);
7644 vout_hi = _mm_srli_epi32(vout_hi, 16);
7645 }
7646 if (nc & 1) {
7647 *c2 = (uint8_t) _mm_extract_epi8(vout_lo, 8);
7648 *c1 = (uint8_t) _mm_extract_epi8(vout_hi, 0);
7649 *c0 = (uint8_t) _mm_extract_epi8(vout_lo, 0);
7650 }
7651
7652 nc = 0;
7653 }
7654 } while (nc != 0);
7655 }
7656
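// Elementwise quantized addition. Per element, roughly:
//   acc = bias + a[i] * a_multiplier + b[i] * b_multiplier;        // int32
//   out = clamp((acc >> shift) + output_zero_point, output_min, output_max);
// The shift is arithmetic, and the zero-point addition saturates to int16 before
// the final pack to uint8.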
7657 void xnn_qu8_vadd_minmax_ukernel__avx2_mul32_ld64_x16(
7658 size_t n,
7659 const uint8_t* input_a,
7660 const uint8_t* input_b,
7661 uint8_t* output,
7662 const union xnn_qu8_add_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
7663 {
7664 const __m256i vbias = _mm256_load_si256((const __m256i*) params->avx2.bias);
7665 const __m256i va_multiplier = _mm256_load_si256((const __m256i*) params->avx2.a_multiplier);
7666 const __m256i vb_multiplier = _mm256_load_si256((const __m256i*) params->avx2.b_multiplier);
7667 const __m128i vshift = _mm_load_si128((const __m128i*) params->avx2.shift);
7668 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx2.output_zero_point);
7669 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->avx2.output_min);
7670 const __m128i voutput_max = _mm_load_si128((const __m128i*) params->avx2.output_max);
7671
7672 for (; n >= 16 * sizeof(uint8_t); n -= 16 * sizeof(uint8_t)) {
7673 const __m256i va01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) input_a));
7674 const __m256i vb01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) input_b));
7675 const __m256i va89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (input_a + 8)));
7676 const __m256i vb89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (input_b + 8)));
7677 input_a += 16;
7678 input_b += 16;
7679
7680 __m256i vacc01234567 = _mm256_add_epi32(vbias, _mm256_mullo_epi32(va01234567, va_multiplier));
7681 __m256i vacc89ABCDEF = _mm256_add_epi32(vbias, _mm256_mullo_epi32(va89ABCDEF, va_multiplier));
7682
7683 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vb01234567, vb_multiplier));
7684 vacc89ABCDEF = _mm256_add_epi32(vacc89ABCDEF, _mm256_mullo_epi32(vb89ABCDEF, vb_multiplier));
7685
7686 vacc01234567 = _mm256_sra_epi32(vacc01234567, vshift);
7687 vacc89ABCDEF = _mm256_sra_epi32(vacc89ABCDEF, vshift);
7688
7689 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(vacc01234567, vacc89ABCDEF), voutput_zero_point);
7690
7691 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packus_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
7692
7693 vout0123456789ABCDEF = _mm_max_epu8(vout0123456789ABCDEF, voutput_min);
7694
7695 vout0123456789ABCDEF = _mm_min_epu8(vout0123456789ABCDEF, voutput_max);
7696
7697 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
7698 output += 16;
7699 }
7700 if XNN_UNLIKELY(n != 0) {
7701 do {
7702 const __m256i va01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) input_a));
7703 const __m256i vb01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) input_b));
7704 input_a += 8;
7705 input_b += 8;
7706
7707 __m256i vacc01234567 = _mm256_add_epi32(vbias, _mm256_mullo_epi32(va01234567, va_multiplier));
7708
7709 vacc01234567 = _mm256_add_epi32(vacc01234567, _mm256_mullo_epi32(vb01234567, vb_multiplier));
7710
7711 vacc01234567 = _mm256_sra_epi32(vacc01234567, vshift);
7712
7713 __m128i vout01234567 = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc01234567), _mm256_extracti128_si256(vacc01234567, 1)), _mm256_castsi256_si128(voutput_zero_point));
7714 __m128i vout0123456701234567 = _mm_packus_epi16(vout01234567, vout01234567);
7715 vout0123456701234567 = _mm_max_epu8(vout0123456701234567, voutput_min);
7716 vout0123456701234567 = _mm_min_epu8(vout0123456701234567, voutput_max);
7717
7718 if XNN_LIKELY(n >= (8 * sizeof(uint8_t))) {
7719 _mm_storel_epi64((__m128i*) output, vout0123456701234567);
7720 output += 8;
7721 n -= 8 * sizeof(uint8_t);
7722 } else {
7723 if (n & (4 * sizeof(uint8_t))) {
7724 _mm_storeu_si32(output, vout0123456701234567);
7725 vout0123456701234567 = _mm_srli_epi64(vout0123456701234567, 32);
7726 output += 4;
7727 }
7728 if (n & (2 * sizeof(uint8_t))) {
7729 _mm_storeu_si16(output, vout0123456701234567);
7730 vout0123456701234567 = _mm_srli_epi32(vout0123456701234567, 16);
7731 output += 2;
7732 }
7733 if (n & (1 * sizeof(uint8_t))) {
7734 *output = (uint8_t) _mm_extract_epi8(vout0123456701234567, 0);
7735 }
7736 n = 0;
7737 }
7738 } while (n != 0);
7739 }
7740 }
7741
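// Same as the vadd kernel above, but the second operand is a single scalar, so its
// contribution (b_multiplier * *input_b) is folded into the bias once up front.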
7742 void xnn_qu8_vaddc_minmax_ukernel__avx2_mul32_ld64_x16(
7743 size_t n,
7744 const uint8_t* input_a,
7745 const uint8_t* input_b,
7746 uint8_t* output,
7747 const union xnn_qu8_add_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
7748 {
7749 const __m256i va_multiplier = _mm256_load_si256((const __m256i*) params->avx2.a_multiplier);
7750 const __m128i vshift = _mm_load_si128((const __m128i*) params->avx2.shift);
7751 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx2.output_zero_point);
7752 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->avx2.output_min);
7753 const __m128i voutput_max = _mm_load_si128((const __m128i*) params->avx2.output_max);
7754
7755 const __m256i vbias = _mm256_add_epi32(
7756 _mm256_broadcastd_epi32(_mm_cvtsi32_si128(params->avx2.b_multiplier[0] * (int32_t) *input_b)),
7757 _mm256_load_si256((const __m256i*) params->avx2.bias));
7758 for (; n >= 16 * sizeof(uint8_t); n -= 16 * sizeof(uint8_t)) {
7759 const __m256i va01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) input_a));
7760 const __m256i va89ABCDEF = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) (input_a + 8)));
7761 input_a += 16;
7762
7763 __m256i vacc01234567 = _mm256_add_epi32(vbias, _mm256_mullo_epi32(va01234567, va_multiplier));
7764 __m256i vacc89ABCDEF = _mm256_add_epi32(vbias, _mm256_mullo_epi32(va89ABCDEF, va_multiplier));
7765
7766 vacc01234567 = _mm256_sra_epi32(vacc01234567, vshift);
7767 vacc89ABCDEF = _mm256_sra_epi32(vacc89ABCDEF, vshift);
7768
7769 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(vacc01234567, vacc89ABCDEF), voutput_zero_point);
7770
7771 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packus_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
7772
7773 vout0123456789ABCDEF = _mm_max_epu8(vout0123456789ABCDEF, voutput_min);
7774
7775 vout0123456789ABCDEF = _mm_min_epu8(vout0123456789ABCDEF, voutput_max);
7776
7777 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
7778 output += 16;
7779 }
7780 if XNN_UNLIKELY(n != 0) {
7781 do {
7782 const __m256i va01234567 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*) input_a));
7783 input_a += 8;
7784
7785 __m256i vacc01234567 = _mm256_add_epi32(vbias, _mm256_mullo_epi32(va01234567, va_multiplier));
7786
7787 vacc01234567 = _mm256_sra_epi32(vacc01234567, vshift);
7788
7789 __m128i vout01234567 = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc01234567), _mm256_extracti128_si256(vacc01234567, 1)), _mm256_castsi256_si128(voutput_zero_point));
7790 __m128i vout0123456701234567 = _mm_packus_epi16(vout01234567, vout01234567);
7791 vout0123456701234567 = _mm_max_epu8(vout0123456701234567, voutput_min);
7792 vout0123456701234567 = _mm_min_epu8(vout0123456701234567, voutput_max);
7793
7794 if XNN_LIKELY(n >= (8 * sizeof(uint8_t))) {
7795 _mm_storel_epi64((__m128i*) output, vout0123456701234567);
7796 output += 8;
7797 n -= 8 * sizeof(uint8_t);
7798 } else {
7799 if (n & (4 * sizeof(uint8_t))) {
7800 _mm_storeu_si32(output, vout0123456701234567);
7801 vout0123456701234567 = _mm_srli_epi64(vout0123456701234567, 32);
7802 output += 4;
7803 }
7804 if (n & (2 * sizeof(uint8_t))) {
7805 _mm_storeu_si16(output, vout0123456701234567);
7806 vout0123456701234567 = _mm_srli_epi32(vout0123456701234567, 16);
7807 output += 2;
7808 }
7809 if (n & (1 * sizeof(uint8_t))) {
7810 *output = (uint8_t) _mm_extract_epi8(vout0123456701234567, 0);
7811 }
7812 n = 0;
7813 }
7814 } while (n != 0);
7815 }
7816 }
7817
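// Requantization-only op: each element is re-centered as (input_zero_point - x),
// shifted left by 7, rescaled with a rounding Q15 multiply (_mm256_mulhrs_epi16),
// and offset by the output zero point with saturation. The multiplier in params
// presumably folds in both the scale ratio and the sign of the re-centering.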
7818 void xnn_qu8_vcvt_ukernel__avx2_x32(
7819 size_t n,
7820 const uint8_t* x,
7821 uint8_t* y,
7822 const union xnn_qu8_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
7823 {
7824 assert(n != 0);
7825 assert(n % sizeof(uint8_t) == 0);
7826 assert(x != NULL);
7827 assert(y != NULL);
7828
7829 const __m256i vinput_zero_point = _mm256_load_si256((const __m256i*) params->avx2.input_zero_point);
7830 const __m256i vmultiplier = _mm256_load_si256((const __m256i*) params->avx2.multiplier);
7831 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx2.output_zero_point);
7832 for (; n >= 32 * sizeof(uint8_t); n -= 32 * sizeof(uint8_t)) {
7833 __m256i vacc0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i*) x));
7834 __m256i vacc1 = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i*) (x + 16)));
7835 x += 32;
7836
7837 vacc0 = _mm256_sub_epi16(vinput_zero_point, vacc0);
7838 vacc1 = _mm256_sub_epi16(vinput_zero_point, vacc1);
7839
7840 vacc0 = _mm256_slli_epi16(vacc0, 7);
7841 vacc1 = _mm256_slli_epi16(vacc1, 7);
7842
7843 vacc0 = _mm256_mulhrs_epi16(vacc0, vmultiplier);
7844 vacc1 = _mm256_mulhrs_epi16(vacc1, vmultiplier);
7845
7846 vacc0 = _mm256_adds_epi16(vacc0, voutput_zero_point);
7847 vacc1 = _mm256_adds_epi16(vacc1, voutput_zero_point);
7848
7849 __m256i vy0 = _mm256_packus_epi16(vacc0, vacc1);
7850
7851 vy0 = _mm256_permute4x64_epi64(vy0, _MM_SHUFFLE(3, 1, 2, 0));
7852
7853 _mm256_storeu_si256((__m256i*) y, vy0);
7854 y += 32;
7855 }
7856 for (; n >= 16 * sizeof(uint8_t); n -= 16 * sizeof(uint8_t)) {
7857 __m256i vacc = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i*) x));
7858 vacc = _mm256_sub_epi16(vinput_zero_point, vacc);
7859 vacc = _mm256_slli_epi16(vacc, 7);
7860 vacc = _mm256_mulhrs_epi16(vacc, vmultiplier);
7861 vacc = _mm256_adds_epi16(vacc, voutput_zero_point);
7862 x += 16;
7863
7864 const __m128i vacc_hi = _mm256_extracti128_si256(vacc, 1);
7865 const __m128i vy = _mm_packus_epi16(_mm256_castsi256_si128(vacc), vacc_hi);
7866 _mm_storeu_si128((__m128i*) y, vy);
7867 y += 16;
7868 }
7869 if XNN_UNLIKELY(n != 0) {
7870 assert(n >= 1 * sizeof(uint8_t));
7871 assert(n <= 15 * sizeof(uint8_t));
7872
7873 __m256i vacc = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i*) x));
7874 vacc = _mm256_sub_epi16(vinput_zero_point, vacc);
7875 vacc = _mm256_slli_epi16(vacc, 7);
7876 vacc = _mm256_mulhrs_epi16(vacc, vmultiplier);
7877 vacc = _mm256_adds_epi16(vacc, voutput_zero_point);
7878
7879 const __m128i vacc_hi = _mm256_extracti128_si256(vacc, 1);
7880 __m128i vy = _mm_packus_epi16(_mm256_castsi256_si128(vacc), vacc_hi);
7881 if (n & (8 * sizeof(uint8_t))) {
7882 _mm_storel_epi64((__m128i*) y, vy);
7883 vy = _mm_unpackhi_epi64(vy, vy);
7884 y += 8;
7885 }
7886 if (n & (4 * sizeof(uint8_t))) {
7887 _mm_storeu_si32(y, vy);
7888 vy = _mm_srli_epi64(vy, 32);
7889 y += 4;
7890 }
7891 if (n & (2 * sizeof(uint8_t))) {
7892 _mm_storeu_si16(y, vy);
7893 vy = _mm_srli_epi32(vy, 16);
7894 y += 2;
7895 }
7896 if (n & (1 * sizeof(uint8_t))) {
7897 *y = (uint8_t) _mm_extract_epi8(vy, 0);
7898 }
7899 }
7900 }
7901
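// Quantized leaky ReLU: each element is compared against the input zero point to
// select either the positive- or negative-slope multiplier, then rescaled through
// the same mulhrs-based path as the vcvt kernel above.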
7902 void xnn_qu8_vlrelu_ukernel__avx2_x32(
7903 size_t n,
7904 const uint8_t* x,
7905 uint8_t* y,
7906 const union xnn_qu8_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
7907 {
7908 assert(n != 0);
7909 assert(n % sizeof(uint8_t) == 0);
7910 assert(x != NULL);
7911 assert(y != NULL);
7912
7913 const __m256i vinput_zero_point = _mm256_load_si256((const __m256i*) params->avx2.input_zero_point);
7914 const __m256i vpositive_multiplier = _mm256_load_si256((const __m256i*) params->avx2.positive_multiplier);
7915 const __m256i vnegative_multiplier = _mm256_load_si256((const __m256i*) params->avx2.negative_multiplier);
7916 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx2.output_zero_point);
7917 for (; n >= 32 * sizeof(uint8_t); n -= 32 * sizeof(uint8_t)) {
7918 __m256i vacc0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i*) x));
7919 __m256i vacc1 = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i*) (x + 16)));
7920 x += 32;
7921
7922 __m256i vmultiplier0 = _mm256_cmpgt_epi16(vacc0, vinput_zero_point);
7923 vacc0 = _mm256_sub_epi16(vinput_zero_point, vacc0);
7924 __m256i vmultiplier1 = _mm256_cmpgt_epi16(vacc1, vinput_zero_point);
7925 vacc1 = _mm256_sub_epi16(vinput_zero_point, vacc1);
7926
7927 vmultiplier0 = _mm256_blendv_epi8(vnegative_multiplier, vpositive_multiplier, vmultiplier0);
7928 vacc0 = _mm256_slli_epi16(vacc0, 7);
7929 vmultiplier1 = _mm256_blendv_epi8(vnegative_multiplier, vpositive_multiplier, vmultiplier1);
7930 vacc1 = _mm256_slli_epi16(vacc1, 7);
7931
7932 vacc0 = _mm256_mulhrs_epi16(vacc0, vmultiplier0);
7933 vacc1 = _mm256_mulhrs_epi16(vacc1, vmultiplier1);
7934
7935 vacc0 = _mm256_adds_epi16(vacc0, voutput_zero_point);
7936 vacc1 = _mm256_adds_epi16(vacc1, voutput_zero_point);
7937
7938 __m256i vy0 = _mm256_packus_epi16(vacc0, vacc1);
7939
7940 vy0 = _mm256_permute4x64_epi64(vy0, _MM_SHUFFLE(3, 1, 2, 0));
7941
7942 _mm256_storeu_si256((__m256i*) y, vy0);
7943 y += 32;
7944 }
7945 for (; n >= 16 * sizeof(uint8_t); n -= 16 * sizeof(uint8_t)) {
7946 __m256i vacc = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i*) x));
7947 __m256i vmultiplier = _mm256_cmpgt_epi16(vacc, vinput_zero_point);
7948 vacc = _mm256_sub_epi16(vinput_zero_point, vacc);
7949 vmultiplier = _mm256_blendv_epi8(vnegative_multiplier, vpositive_multiplier, vmultiplier);
7950 vacc = _mm256_slli_epi16(vacc, 7);
7951 vacc = _mm256_mulhrs_epi16(vacc, vmultiplier);
7952 vacc = _mm256_adds_epi16(vacc, voutput_zero_point);
7953 x += 16;
7954
7955 const __m128i vacc_hi = _mm256_extracti128_si256(vacc, 1);
7956 const __m128i vy = _mm_packus_epi16(_mm256_castsi256_si128(vacc), vacc_hi);
7957 _mm_storeu_si128((__m128i*) y, vy);
7958 y += 16;
7959 }
7960 if XNN_UNLIKELY(n != 0) {
7961 assert(n >= 1 * sizeof(uint8_t));
7962 assert(n <= 15 * sizeof(uint8_t));
7963
7964 __m256i vacc = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i*) x));
7965 __m256i vmultiplier = _mm256_cmpgt_epi16(vacc, vinput_zero_point);
7966 vacc = _mm256_sub_epi16(vinput_zero_point, vacc);
7967 vmultiplier = _mm256_blendv_epi8(vnegative_multiplier, vpositive_multiplier, vmultiplier);
7968 vacc = _mm256_slli_epi16(vacc, 7);
7969 vacc = _mm256_mulhrs_epi16(vacc, vmultiplier);
7970 vacc = _mm256_adds_epi16(vacc, voutput_zero_point);
7971
7972 const __m128i vacc_hi = _mm256_extracti128_si256(vacc, 1);
7973 __m128i vy = _mm_packus_epi16(_mm256_castsi256_si128(vacc), vacc_hi);
7974 if (n & (8 * sizeof(uint8_t))) {
7975 _mm_storel_epi64((__m128i*) y, vy);
7976 vy = _mm_unpackhi_epi64(vy, vy);
7977 y += 8;
7978 }
7979 if (n & (4 * sizeof(uint8_t))) {
7980 _mm_storeu_si32(y, vy);
7981 vy = _mm_srli_epi64(vy, 32);
7982 y += 4;
7983 }
7984 if (n & (2 * sizeof(uint8_t))) {
7985 _mm_storeu_si16(y, vy);
7986 vy = _mm_srli_epi32(vy, 16);
7987 y += 2;
7988 }
7989 if (n & (1 * sizeof(uint8_t))) {
7990 *y = (uint8_t) _mm_extract_epi8(vy, 0);
7991 }
7992 }
7993 }
7994
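// 256-entry byte lookup. The table is split into sixteen 16-byte tiles that are
// pre-XORed with their neighbours; the loop XOR-accumulates _mm256_shuffle_epi8
// lookups while stepping the indices down by 16 each time (plain subtraction for
// the first eight steps, saturating afterwards so already-negative indices do not
// wrap). Out-of-range indices have the sign bit set and shuffle to zero, and the
// pre-XORed terms telescope so that only t[x] survives.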
7995 void xnn_x8_lut_ukernel__avx2_x128(
7996 size_t n,
7997 const uint8_t* x,
7998 uint8_t* y,
7999 const uint8_t t[restrict XNN_MIN_ELEMENTS(256)])
8000 {
8001 assert(n != 0);
8002 assert(x != NULL);
8003 assert(y != NULL);
8004
8005 const __m256i vt0 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) t));
8006 const __m256i vt1 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) (t + 16)));
8007 const __m256i vt2 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) (t + 32)));
8008 const __m256i vt3 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) (t + 48)));
8009 const __m256i vt4 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) (t + 64)));
8010 const __m256i vt5 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) (t + 80)));
8011 const __m256i vt6 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) (t + 96)));
8012 const __m256i vt7 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) (t + 112)));
8013 const __m256i vt8 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) (t + 128)));
8014 const __m256i vt9 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) (t + 144)));
8015 const __m256i vtA = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) (t + 160)));
8016 const __m256i vtB = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) (t + 176)));
8017 const __m256i vtC = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) (t + 192)));
8018 const __m256i vtD = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) (t + 208)));
8019 const __m256i vtE = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) (t + 224)));
8020 const __m256i vtF = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i*) (t + 240)));
8021
8022 const __m256i vtable0 = vt0;
8023 const __m256i vtable1 = _mm256_xor_si256(vt0, vt1);
8024 const __m256i vtable2 = _mm256_xor_si256(vt1, vt2);
8025 const __m256i vtable3 = _mm256_xor_si256(vt2, vt3);
8026 const __m256i vtable4 = _mm256_xor_si256(vt3, vt4);
8027 const __m256i vtable5 = _mm256_xor_si256(vt4, vt5);
8028 const __m256i vtable6 = _mm256_xor_si256(vt5, vt6);
8029 const __m256i vtable7 = _mm256_xor_si256(vt6, vt7);
8030 const __m256i vtable8 = _mm256_xor_si256(_mm256_xor_si256(vt7, vt8), vtable0);
8031 const __m256i vtable9 = _mm256_xor_si256(_mm256_xor_si256(vt8, vt9), vtable1);
8032 const __m256i vtableA = _mm256_xor_si256(_mm256_xor_si256(vt9, vtA), vtable2);
8033 const __m256i vtableB = _mm256_xor_si256(_mm256_xor_si256(vtA, vtB), vtable3);
8034 const __m256i vtableC = _mm256_xor_si256(_mm256_xor_si256(vtB, vtC), vtable4);
8035 const __m256i vtableD = _mm256_xor_si256(_mm256_xor_si256(vtC, vtD), vtable5);
8036 const __m256i vtableE = _mm256_xor_si256(_mm256_xor_si256(vtD, vtE), vtable6);
8037 const __m256i vtableF = _mm256_xor_si256(_mm256_xor_si256(vtE, vtF), vtable7);
8038
8039 const __m256i voffset = _mm256_set1_epi8(16);
8040 for (; n >= 128 * sizeof(uint8_t); n -= 128 * sizeof(uint8_t)) {
8041 __m256i vx0 = _mm256_loadu_si256((const __m256i*) x);
8042 __m256i vx1 = _mm256_loadu_si256((const __m256i*) (x + 32));
8043 __m256i vx2 = _mm256_loadu_si256((const __m256i*) (x + 64));
8044 __m256i vx3 = _mm256_loadu_si256((const __m256i*) (x + 96));
8045 x += 128;
8046
8047 __m256i vy0 = _mm256_shuffle_epi8(vtable0, vx0);
8048 __m256i vy1 = _mm256_shuffle_epi8(vtable0, vx1);
8049 __m256i vy2 = _mm256_shuffle_epi8(vtable0, vx2);
8050 __m256i vy3 = _mm256_shuffle_epi8(vtable0, vx3);
8051
8052 vx0 = _mm256_sub_epi8(vx0, voffset);
8053 vx1 = _mm256_sub_epi8(vx1, voffset);
8054 vx2 = _mm256_sub_epi8(vx2, voffset);
8055 vx3 = _mm256_sub_epi8(vx3, voffset);
8056 vy0 = _mm256_xor_si256(vy0, _mm256_shuffle_epi8(vtable1, vx0));
8057 vy1 = _mm256_xor_si256(vy1, _mm256_shuffle_epi8(vtable1, vx1));
8058 vy2 = _mm256_xor_si256(vy2, _mm256_shuffle_epi8(vtable1, vx2));
8059 vy3 = _mm256_xor_si256(vy3, _mm256_shuffle_epi8(vtable1, vx3));
8060 vx0 = _mm256_sub_epi8(vx0, voffset);
8061 vx1 = _mm256_sub_epi8(vx1, voffset);
8062 vx2 = _mm256_sub_epi8(vx2, voffset);
8063 vx3 = _mm256_sub_epi8(vx3, voffset);
8064 vy0 = _mm256_xor_si256(vy0, _mm256_shuffle_epi8(vtable2, vx0));
8065 vy1 = _mm256_xor_si256(vy1, _mm256_shuffle_epi8(vtable2, vx1));
8066 vy2 = _mm256_xor_si256(vy2, _mm256_shuffle_epi8(vtable2, vx2));
8067 vy3 = _mm256_xor_si256(vy3, _mm256_shuffle_epi8(vtable2, vx3));
8068 vx0 = _mm256_sub_epi8(vx0, voffset);
8069 vx1 = _mm256_sub_epi8(vx1, voffset);
8070 vx2 = _mm256_sub_epi8(vx2, voffset);
8071 vx3 = _mm256_sub_epi8(vx3, voffset);
8072 vy0 = _mm256_xor_si256(vy0, _mm256_shuffle_epi8(vtable3, vx0));
8073 vy1 = _mm256_xor_si256(vy1, _mm256_shuffle_epi8(vtable3, vx1));
8074 vy2 = _mm256_xor_si256(vy2, _mm256_shuffle_epi8(vtable3, vx2));
8075 vy3 = _mm256_xor_si256(vy3, _mm256_shuffle_epi8(vtable3, vx3));
8076 vx0 = _mm256_sub_epi8(vx0, voffset);
8077 vx1 = _mm256_sub_epi8(vx1, voffset);
8078 vx2 = _mm256_sub_epi8(vx2, voffset);
8079 vx3 = _mm256_sub_epi8(vx3, voffset);
8080 vy0 = _mm256_xor_si256(vy0, _mm256_shuffle_epi8(vtable4, vx0));
8081 vy1 = _mm256_xor_si256(vy1, _mm256_shuffle_epi8(vtable4, vx1));
8082 vy2 = _mm256_xor_si256(vy2, _mm256_shuffle_epi8(vtable4, vx2));
8083 vy3 = _mm256_xor_si256(vy3, _mm256_shuffle_epi8(vtable4, vx3));
8084 vx0 = _mm256_sub_epi8(vx0, voffset);
8085 vx1 = _mm256_sub_epi8(vx1, voffset);
8086 vx2 = _mm256_sub_epi8(vx2, voffset);
8087 vx3 = _mm256_sub_epi8(vx3, voffset);
8088 vy0 = _mm256_xor_si256(vy0, _mm256_shuffle_epi8(vtable5, vx0));
8089 vy1 = _mm256_xor_si256(vy1, _mm256_shuffle_epi8(vtable5, vx1));
8090 vy2 = _mm256_xor_si256(vy2, _mm256_shuffle_epi8(vtable5, vx2));
8091 vy3 = _mm256_xor_si256(vy3, _mm256_shuffle_epi8(vtable5, vx3));
8092 vx0 = _mm256_sub_epi8(vx0, voffset);
8093 vx1 = _mm256_sub_epi8(vx1, voffset);
8094 vx2 = _mm256_sub_epi8(vx2, voffset);
8095 vx3 = _mm256_sub_epi8(vx3, voffset);
8096 vy0 = _mm256_xor_si256(vy0, _mm256_shuffle_epi8(vtable6, vx0));
8097 vy1 = _mm256_xor_si256(vy1, _mm256_shuffle_epi8(vtable6, vx1));
8098 vy2 = _mm256_xor_si256(vy2, _mm256_shuffle_epi8(vtable6, vx2));
8099 vy3 = _mm256_xor_si256(vy3, _mm256_shuffle_epi8(vtable6, vx3));
8100 vx0 = _mm256_sub_epi8(vx0, voffset);
8101 vx1 = _mm256_sub_epi8(vx1, voffset);
8102 vx2 = _mm256_sub_epi8(vx2, voffset);
8103 vx3 = _mm256_sub_epi8(vx3, voffset);
8104 vy0 = _mm256_xor_si256(vy0, _mm256_shuffle_epi8(vtable7, vx0));
8105 vy1 = _mm256_xor_si256(vy1, _mm256_shuffle_epi8(vtable7, vx1));
8106 vy2 = _mm256_xor_si256(vy2, _mm256_shuffle_epi8(vtable7, vx2));
8107 vy3 = _mm256_xor_si256(vy3, _mm256_shuffle_epi8(vtable7, vx3));
8108 vx0 = _mm256_sub_epi8(vx0, voffset);
8109 vx1 = _mm256_sub_epi8(vx1, voffset);
8110 vx2 = _mm256_sub_epi8(vx2, voffset);
8111 vx3 = _mm256_sub_epi8(vx3, voffset);
8112 vy0 = _mm256_xor_si256(vy0, _mm256_shuffle_epi8(vtable8, vx0));
8113 vy1 = _mm256_xor_si256(vy1, _mm256_shuffle_epi8(vtable8, vx1));
8114 vy2 = _mm256_xor_si256(vy2, _mm256_shuffle_epi8(vtable8, vx2));
8115 vy3 = _mm256_xor_si256(vy3, _mm256_shuffle_epi8(vtable8, vx3));
8116
8117 vx0 = _mm256_subs_epi8(vx0, voffset);
8118 vx1 = _mm256_subs_epi8(vx1, voffset);
8119 vx2 = _mm256_subs_epi8(vx2, voffset);
8120 vx3 = _mm256_subs_epi8(vx3, voffset);
8121 vy0 = _mm256_xor_si256(vy0, _mm256_shuffle_epi8(vtable9, vx0));
8122 vy1 = _mm256_xor_si256(vy1, _mm256_shuffle_epi8(vtable9, vx1));
8123 vy2 = _mm256_xor_si256(vy2, _mm256_shuffle_epi8(vtable9, vx2));
8124 vy3 = _mm256_xor_si256(vy3, _mm256_shuffle_epi8(vtable9, vx3));
8125 vx0 = _mm256_subs_epi8(vx0, voffset);
8126 vx1 = _mm256_subs_epi8(vx1, voffset);
8127 vx2 = _mm256_subs_epi8(vx2, voffset);
8128 vx3 = _mm256_subs_epi8(vx3, voffset);
8129 vy0 = _mm256_xor_si256(vy0, _mm256_shuffle_epi8(vtableA, vx0));
8130 vy1 = _mm256_xor_si256(vy1, _mm256_shuffle_epi8(vtableA, vx1));
8131 vy2 = _mm256_xor_si256(vy2, _mm256_shuffle_epi8(vtableA, vx2));
8132 vy3 = _mm256_xor_si256(vy3, _mm256_shuffle_epi8(vtableA, vx3));
8133 vx0 = _mm256_subs_epi8(vx0, voffset);
8134 vx1 = _mm256_subs_epi8(vx1, voffset);
8135 vx2 = _mm256_subs_epi8(vx2, voffset);
8136 vx3 = _mm256_subs_epi8(vx3, voffset);
8137 vy0 = _mm256_xor_si256(vy0, _mm256_shuffle_epi8(vtableB, vx0));
8138 vy1 = _mm256_xor_si256(vy1, _mm256_shuffle_epi8(vtableB, vx1));
8139 vy2 = _mm256_xor_si256(vy2, _mm256_shuffle_epi8(vtableB, vx2));
8140 vy3 = _mm256_xor_si256(vy3, _mm256_shuffle_epi8(vtableB, vx3));
8141 vx0 = _mm256_subs_epi8(vx0, voffset);
8142 vx1 = _mm256_subs_epi8(vx1, voffset);
8143 vx2 = _mm256_subs_epi8(vx2, voffset);
8144 vx3 = _mm256_subs_epi8(vx3, voffset);
8145 vy0 = _mm256_xor_si256(vy0, _mm256_shuffle_epi8(vtableC, vx0));
8146 vy1 = _mm256_xor_si256(vy1, _mm256_shuffle_epi8(vtableC, vx1));
8147 vy2 = _mm256_xor_si256(vy2, _mm256_shuffle_epi8(vtableC, vx2));
8148 vy3 = _mm256_xor_si256(vy3, _mm256_shuffle_epi8(vtableC, vx3));
8149 vx0 = _mm256_subs_epi8(vx0, voffset);
8150 vx1 = _mm256_subs_epi8(vx1, voffset);
8151 vx2 = _mm256_subs_epi8(vx2, voffset);
8152 vx3 = _mm256_subs_epi8(vx3, voffset);
8153 vy0 = _mm256_xor_si256(vy0, _mm256_shuffle_epi8(vtableD, vx0));
8154 vy1 = _mm256_xor_si256(vy1, _mm256_shuffle_epi8(vtableD, vx1));
8155 vy2 = _mm256_xor_si256(vy2, _mm256_shuffle_epi8(vtableD, vx2));
8156 vy3 = _mm256_xor_si256(vy3, _mm256_shuffle_epi8(vtableD, vx3));
8157 vx0 = _mm256_subs_epi8(vx0, voffset);
8158 vx1 = _mm256_subs_epi8(vx1, voffset);
8159 vx2 = _mm256_subs_epi8(vx2, voffset);
8160 vx3 = _mm256_subs_epi8(vx3, voffset);
8161 vy0 = _mm256_xor_si256(vy0, _mm256_shuffle_epi8(vtableE, vx0));
8162 vy1 = _mm256_xor_si256(vy1, _mm256_shuffle_epi8(vtableE, vx1));
8163 vy2 = _mm256_xor_si256(vy2, _mm256_shuffle_epi8(vtableE, vx2));
8164 vy3 = _mm256_xor_si256(vy3, _mm256_shuffle_epi8(vtableE, vx3));
8165 vx0 = _mm256_subs_epi8(vx0, voffset);
8166 vx1 = _mm256_subs_epi8(vx1, voffset);
8167 vx2 = _mm256_subs_epi8(vx2, voffset);
8168 vx3 = _mm256_subs_epi8(vx3, voffset);
8169 vy0 = _mm256_xor_si256(vy0, _mm256_shuffle_epi8(vtableF, vx0));
8170 vy1 = _mm256_xor_si256(vy1, _mm256_shuffle_epi8(vtableF, vx1));
8171 vy2 = _mm256_xor_si256(vy2, _mm256_shuffle_epi8(vtableF, vx2));
8172 vy3 = _mm256_xor_si256(vy3, _mm256_shuffle_epi8(vtableF, vx3));
8173
8174 _mm256_storeu_si256((__m256i*) y, vy0);
8175 _mm256_storeu_si256((__m256i*) (y + 32), vy1);
8176 _mm256_storeu_si256((__m256i*) (y + 64), vy2);
8177 _mm256_storeu_si256((__m256i*) (y + 96), vy3);
8178 y += 128;
8179 }
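  // The 16-byte loop and the final partial tile below repeat the same
  // XOR-accumulated lookup chain on 128-bit registers.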
8180 for (; n >= 16 * sizeof(uint8_t); n -= 16 * sizeof(uint8_t)) {
8181 __m128i vx = _mm_loadu_si128((const __m128i*) x);
8182 x += 16;
8183
8184 __m128i vy = _mm_shuffle_epi8(_mm256_castsi256_si128(vtable0), vx);
8185
8186 vx = _mm_sub_epi8(vx, _mm256_castsi256_si128(voffset));
8187 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable1), vx));
8188 vx = _mm_sub_epi8(vx, _mm256_castsi256_si128(voffset));
8189 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable2), vx));
8190 vx = _mm_sub_epi8(vx, _mm256_castsi256_si128(voffset));
8191 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable3), vx));
8192 vx = _mm_sub_epi8(vx, _mm256_castsi256_si128(voffset));
8193 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable4), vx));
8194 vx = _mm_sub_epi8(vx, _mm256_castsi256_si128(voffset));
8195 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable5), vx));
8196 vx = _mm_sub_epi8(vx, _mm256_castsi256_si128(voffset));
8197 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable6), vx));
8198 vx = _mm_sub_epi8(vx, _mm256_castsi256_si128(voffset));
8199 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable7), vx));
8200 vx = _mm_sub_epi8(vx, _mm256_castsi256_si128(voffset));
8201 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable8), vx));
8202
8203 vx = _mm_subs_epi8(vx, _mm256_castsi256_si128(voffset));
8204 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable9), vx));
8205 vx = _mm_subs_epi8(vx, _mm256_castsi256_si128(voffset));
8206 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtableA), vx));
8207 vx = _mm_subs_epi8(vx, _mm256_castsi256_si128(voffset));
8208 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtableB), vx));
8209 vx = _mm_subs_epi8(vx, _mm256_castsi256_si128(voffset));
8210 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtableC), vx));
8211 vx = _mm_subs_epi8(vx, _mm256_castsi256_si128(voffset));
8212 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtableD), vx));
8213 vx = _mm_subs_epi8(vx, _mm256_castsi256_si128(voffset));
8214 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtableE), vx));
8215 vx = _mm_subs_epi8(vx, _mm256_castsi256_si128(voffset));
8216 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtableF), vx));
8217
8218 _mm_storeu_si128((__m128i*) y, vy);
8219 y += 16;
8220 }
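  // Handle the last 1 to 15 bytes: perform one more full 16-byte lookup, then store only the
  // requested number of bytes.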
8221 if XNN_UNLIKELY(n != 0) {
8222 __m128i vx = _mm_loadu_si128((const __m128i*) x);
8223
8224 __m128i vy = _mm_shuffle_epi8(_mm256_castsi256_si128(vtable0), vx);
8225
8226 vx = _mm_sub_epi8(vx, _mm256_castsi256_si128(voffset));
8227 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable1), vx));
8228 vx = _mm_sub_epi8(vx, _mm256_castsi256_si128(voffset));
8229 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable2), vx));
8230 vx = _mm_sub_epi8(vx, _mm256_castsi256_si128(voffset));
8231 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable3), vx));
8232 vx = _mm_sub_epi8(vx, _mm256_castsi256_si128(voffset));
8233 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable4), vx));
8234 vx = _mm_sub_epi8(vx, _mm256_castsi256_si128(voffset));
8235 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable5), vx));
8236 vx = _mm_sub_epi8(vx, _mm256_castsi256_si128(voffset));
8237 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable6), vx));
8238 vx = _mm_sub_epi8(vx, _mm256_castsi256_si128(voffset));
8239 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable7), vx));
8240 vx = _mm_sub_epi8(vx, _mm256_castsi256_si128(voffset));
8241 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable8), vx));
8242
8243 vx = _mm_subs_epi8(vx, _mm256_castsi256_si128(voffset));
8244 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtable9), vx));
8245 vx = _mm_subs_epi8(vx, _mm256_castsi256_si128(voffset));
8246 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtableA), vx));
8247 vx = _mm_subs_epi8(vx, _mm256_castsi256_si128(voffset));
8248 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtableB), vx));
8249 vx = _mm_subs_epi8(vx, _mm256_castsi256_si128(voffset));
8250 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtableC), vx));
8251 vx = _mm_subs_epi8(vx, _mm256_castsi256_si128(voffset));
8252 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtableD), vx));
8253 vx = _mm_subs_epi8(vx, _mm256_castsi256_si128(voffset));
8254 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtableE), vx));
8255 vx = _mm_subs_epi8(vx, _mm256_castsi256_si128(voffset));
8256 vy = _mm_xor_si128(vy, _mm_shuffle_epi8(_mm256_castsi256_si128(vtableF), vx));
8257
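    // Store the low 8, 4, 2, and final 1 byte(s) as selected by the low bits of n, moving the
    // not-yet-stored bytes down to the low end of vy after each partial store.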
8258 if (n & (8 * sizeof(uint8_t))) {
8259 _mm_storel_epi64((__m128i*) y, vy);
8260 vy = _mm_unpackhi_epi64(vy, vy);
8261 y += 8;
8262 }
8263 if (n & (4 * sizeof(uint8_t))) {
8264 _mm_storeu_si32(y, vy);
8265 vy = _mm_srli_epi64(vy, 32);
8266 y += 4;
8267 }
8268 if (n & (2 * sizeof(uint8_t))) {
8269 _mm_storeu_si16(y, vy);
8270 vy = _mm_srli_epi32(vy, 16);
8271 y += 2;
8272 }
8273 if (n & (1 * sizeof(uint8_t))) {
8274 *y = (uint8_t) _mm_extract_epi8(vy, 0);
8275 }
8276 }
8277 }
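// For reference, the shuffle-and-XOR cascade above implements a plain 256-entry byte lookup.
// A minimal scalar sketch of the same transform is shown below; it is illustrative only and
// not part of XNNPACK.  The parameter name `t` is an assumption here, standing for the kernel's
// flat 256-entry lookup-table argument from which the vtable registers are prepared.
//
//   static void x8_lut_scalar_reference(size_t n, const uint8_t* x, uint8_t* y,
//                                       const uint8_t t[256]) {
//     for (size_t i = 0; i < n; i++) {
//       y[i] = t[x[i]];  // one direct table lookup per byte
//     }
//   }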
8278