1 // Copyright 2021 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7
8 #include <immintrin.h>
9
10 #include <xnnpack/common.h>
11 #include <xnnpack/dwconv.h>
12 #include <xnnpack/gemm.h>
13 #include <xnnpack/igemm.h>
14 #include <xnnpack/intrinsics-polyfill.h>
15 #include <xnnpack/lut.h>
16 #include <xnnpack/math.h>
17 #include <xnnpack/vadd.h>
18 #include <xnnpack/vcvt.h>
19
20
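// Converts IEEE FP16 values (stored as uint16_t) to FP32, 16 elements per
// iteration via VCVTPH2PS, with a masked tail for the final 1-15 elements.
// A minimal usage sketch (hypothetical buffers; `n` is the input size in
// bytes, and this particular kernel never reads `params`):
//
//   uint16_t src[40];
//   float dst[40];
//   union xnn_f16_f32_cvt_params params;  // contents unused by this kernel
//   xnn_f16_f32_vcvt_ukernel__avx512skx_x16(40 * sizeof(uint16_t), src, dst, &params);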
21 void xnn_f16_f32_vcvt_ukernel__avx512skx_x16(
22 size_t n,
23 const void* input,
24 float* output,
25 const union xnn_f16_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
26 {
27 assert(n != 0);
28 assert(n % sizeof(uint16_t) == 0);
29 assert(input != NULL);
30 assert(output != NULL);
31
32 const uint16_t* i = (const uint16_t*) input;
33 for (; n >= 16 * sizeof(uint16_t); n -= 16 * sizeof(uint16_t)) {
34 const __m512 vacc = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i));
35 i += 16;
36
37 _mm512_storeu_ps(output, vacc);
38 output += 16;
39 }
40 if XNN_UNLIKELY(n != 0) {
41 assert(n >= 1 * sizeof(uint16_t));
42 assert(n <= 15 * sizeof(uint16_t));
43
44 // Prepare mask for valid 32-bit elements (depends on n).
45 n >>= 1 /* log2(sizeof(uint16_t)) */;
46 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
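// (UINT32_C(1) << n) - UINT32_C(1) sets exactly the low n mask bits, so the
// masked load and store below touch only the valid lanes and never run past
// the ends of the buffers.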
47
48 const __m512 vacc = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, i));
49
50 _mm512_mask_storeu_ps(output, vmask, vacc);
51 }
52 }
53
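// Converts FP32 to IEEE FP16, 16 elements per iteration via VCVTPS2PH.
// _MM_FROUND_NO_EXC suppresses floating-point exceptions; the rounding-
// control bits of the immediate are zero, selecting round-to-nearest-even.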
54 void xnn_f32_f16_vcvt_ukernel__avx512skx_x16(
55 size_t n,
56 const float* input,
57 void* output,
58 const union xnn_f32_f16_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
59 {
60 assert(n != 0);
61 assert(n % sizeof(float) == 0);
62 assert(input != NULL);
63 assert(output != NULL);
64
65 uint16_t* o = (uint16_t*) output;
66 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
67 const __m512 vf = _mm512_loadu_ps(input);
68 input += 16;
69
70 _mm256_storeu_si256((__m256i*) o, _mm512_cvtps_ph(vf, _MM_FROUND_NO_EXC));
71 o += 16;
72 }
73 if XNN_UNLIKELY(n != 0) {
74 assert(n >= 1 * sizeof(float));
75 assert(n <= 15 * sizeof(float));
76
77 // Prepare mask for valid elements (depends on n).
78 n >>= 2 /* log2(sizeof(float)) */;
79 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
80
81 const __m512 vf = _mm512_maskz_loadu_ps(vmask, input);
82 const __m256i vh = _mm512_cvtps_ph(vf, _MM_FROUND_NO_EXC);
83 _mm256_mask_storeu_epi16(o, vmask, vh);
84 }
85 }
86
87 void xnn_f32_qs8_vcvt_ukernel__avx512skx_x128(
88 size_t n,
89 const float* x,
90 int8_t* y,
91 const union xnn_f32_qs8_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
92 {
93 assert(n != 0);
94 assert(n % sizeof(float) == 0);
95 assert(x != NULL);
96 assert(y != NULL);
97
98 const __m512 vscale = _mm512_load_ps(params->avx512.scale);
99 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->avx512.output_max_less_zero_point);
100 const __m512i voutput_zero_point = _mm512_load_si512(params->avx512.output_zero_point);
101 const __m512i vshuffle512_mask = _mm512_load_si512(params->avx512.shuffle512_mask);
102 const __m512i voutput_min = _mm512_load_si512(params->avx512.output_min);
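// Quantizes y = clamp(round(x * scale) + zero_point, qmin, qmax).  The upper
// clamp is applied up front in the float domain as min(x * scale,
// qmax - zero_point); rounding happens in _mm512_cvtps_epi32 (nearest-even
// by default), and the lower clamp is the final max_epi8 against output_min.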
103 for (; n >= 128 * sizeof(float); n -= 128 * sizeof(float)) {
104 __m512 vx0123 = _mm512_loadu_ps(x);
105 __m512 vx4567 = _mm512_loadu_ps(x + 16);
106 __m512 vx89AB = _mm512_loadu_ps(x + 32);
107 __m512 vxCDEF = _mm512_loadu_ps(x + 48);
108 __m512 vxGHIJ = _mm512_loadu_ps(x + 64);
109 __m512 vxKLMN = _mm512_loadu_ps(x + 80);
110 __m512 vxOPQR = _mm512_loadu_ps(x + 96);
111 __m512 vxSTUV = _mm512_loadu_ps(x + 112);
112 x += 128;
113
114 vx0123 = _mm512_mul_ps(vx0123, vscale);
115 vx4567 = _mm512_mul_ps(vx4567, vscale);
116 vx89AB = _mm512_mul_ps(vx89AB, vscale);
117 vxCDEF = _mm512_mul_ps(vxCDEF, vscale);
118 vxGHIJ = _mm512_mul_ps(vxGHIJ, vscale);
119 vxKLMN = _mm512_mul_ps(vxKLMN, vscale);
120 vxOPQR = _mm512_mul_ps(vxOPQR, vscale);
121 vxSTUV = _mm512_mul_ps(vxSTUV, vscale);
122
123 vx0123 = _mm512_min_ps(vx0123, voutput_max_less_zero_point);
124 vx4567 = _mm512_min_ps(vx4567, voutput_max_less_zero_point);
125 vx89AB = _mm512_min_ps(vx89AB, voutput_max_less_zero_point);
126 vxCDEF = _mm512_min_ps(vxCDEF, voutput_max_less_zero_point);
127 vxGHIJ = _mm512_min_ps(vxGHIJ, voutput_max_less_zero_point);
128 vxKLMN = _mm512_min_ps(vxKLMN, voutput_max_less_zero_point);
129 vxOPQR = _mm512_min_ps(vxOPQR, voutput_max_less_zero_point);
130 vxSTUV = _mm512_min_ps(vxSTUV, voutput_max_less_zero_point);
131
132 const __m512i vacc0123 = _mm512_cvtps_epi32(vx0123);
133 const __m512i vacc4567 = _mm512_cvtps_epi32(vx4567);
134 const __m512i vacc89AB = _mm512_cvtps_epi32(vx89AB);
135 const __m512i vaccCDEF = _mm512_cvtps_epi32(vxCDEF);
136 const __m512i vaccGHIJ = _mm512_cvtps_epi32(vxGHIJ);
137 const __m512i vaccKLMN = _mm512_cvtps_epi32(vxKLMN);
138 const __m512i vaccOPQR = _mm512_cvtps_epi32(vxOPQR);
139 const __m512i vaccSTUV = _mm512_cvtps_epi32(vxSTUV);
140
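// _mm512_packs_epi32 narrows within each 128-bit lane, so the 16-bit
// results come out lane-interleaved; the digit soup in the variable names
// tracks that order.  The zero point is added with saturating 16-bit adds,
// and after the second (also per-lane) pack to int8, the permutexvar with
// shuffle512_mask restores linear element order.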
141 __m512i vacc04152637 = _mm512_packs_epi32(vacc0123, vacc4567);
142 __m512i vacc8C9DAEBF = _mm512_packs_epi32(vacc89AB, vaccCDEF);
143 __m512i vaccGKHLIMJN = _mm512_packs_epi32(vaccGHIJ, vaccKLMN);
144 __m512i vaccOSPTQURV = _mm512_packs_epi32(vaccOPQR, vaccSTUV);
145
146 vacc04152637 = _mm512_adds_epi16(vacc04152637, voutput_zero_point);
147 vacc8C9DAEBF = _mm512_adds_epi16(vacc8C9DAEBF, voutput_zero_point);
148 vaccGKHLIMJN = _mm512_adds_epi16(vaccGKHLIMJN, voutput_zero_point);
149 vaccOSPTQURV = _mm512_adds_epi16(vaccOSPTQURV, voutput_zero_point);
150
151 __m512i vy048C159D26AE37BF = _mm512_packs_epi16(vacc04152637, vacc8C9DAEBF);
152 __m512i vyGKOSHLPTIMQUJNRV = _mm512_packs_epi16(vaccGKHLIMJN, vaccOSPTQURV);
153
154 vy048C159D26AE37BF = _mm512_max_epi8(vy048C159D26AE37BF, voutput_min);
155 vyGKOSHLPTIMQUJNRV = _mm512_max_epi8(vyGKOSHLPTIMQUJNRV, voutput_min);
156
157 const __m512i vy0123456789ABCDEF = _mm512_permutexvar_epi32(vshuffle512_mask, vy048C159D26AE37BF);
158 const __m512i vyGHIJKLMNOPQRSTUV = _mm512_permutexvar_epi32(vshuffle512_mask, vyGKOSHLPTIMQUJNRV);
159
160 _mm512_storeu_si512(y, vy0123456789ABCDEF);
161 _mm512_storeu_si512(y + 64, vyGHIJKLMNOPQRSTUV);
162 y += 128;
163 }
164 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
165 __m512 vx0123 = _mm512_loadu_ps(x);
166 vx0123 = _mm512_mul_ps(vx0123, vscale);
167 vx0123 = _mm512_min_ps(vx0123, voutput_max_less_zero_point);
168 x += 16;
169
170 const __m512i vacc0123 = _mm512_cvtps_epi32(vx0123);
171
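// _mm256_packs_epi32 interleaves the two 128-bit halves (quarter order
// 0,2,1,3); after packing to bytes, _mm_shuffle_epi32 with
// _MM_SHUFFLE(3, 1, 2, 0) restores the linear order.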
172 __m256i vacc0213 = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0123), _mm512_extracti32x8_epi32(vacc0123, 1));
173 vacc0213 = _mm256_adds_epi16(vacc0213, _mm512_castsi512_si256(voutput_zero_point));
174 const __m128i vy0213 = _mm_packs_epi16(_mm256_castsi256_si128(vacc0213), _mm256_extracti128_si256(vacc0213, 1));
175 __m128i vy0123 = _mm_shuffle_epi32(vy0213, _MM_SHUFFLE(3, 1, 2, 0));
176 vy0123 = _mm_max_epi8(vy0123, _mm512_castsi512_si128(voutput_min));
177
178 _mm_storeu_si128((__m128i*) y, vy0123);
179 y += 16;
180 }
181 if XNN_UNLIKELY(n != 0) {
182 assert(n >= 1 * sizeof(float));
183 assert(n <= 15 * sizeof(float));
184
185 // Prepare mask for valid elements (depends on n).
186 n >>= 2 /* log2(sizeof(float)) */;
187 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
188
189 __m512 vx0123 = _mm512_maskz_loadu_ps(vmask, x);
190 vx0123 = _mm512_mul_ps(vx0123, vscale);
191 vx0123 = _mm512_min_ps(vx0123, voutput_max_less_zero_point);
192
193 const __m512i vacc0123 = _mm512_cvtps_epi32(vx0123);
194
195 __m256i vacc0213 = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0123), _mm512_extracti32x8_epi32(vacc0123, 1));
196 vacc0213 = _mm256_adds_epi16(vacc0213, _mm512_castsi512_si256(voutput_zero_point));
197 const __m128i vy0213 = _mm_packs_epi16(_mm256_castsi256_si128(vacc0213), _mm256_extracti128_si256(vacc0213, 1));
198 __m128i vy0123 = _mm_shuffle_epi32(vy0213, _MM_SHUFFLE(3, 1, 2, 0));
199 vy0123 = _mm_max_epi8(vy0123, _mm512_castsi512_si128(voutput_min));
200
201 _mm_mask_storeu_epi8(y, vmask, vy0123);
202 }
203 }
204
205 void xnn_f32_qu8_vcvt_ukernel__avx512skx_x128(
206 size_t n,
207 const float* x,
208 uint8_t* y,
209 const union xnn_f32_qu8_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
210 {
211 assert(n != 0);
212 assert(n % sizeof(float) == 0);
213 assert(x != NULL);
214 assert(y != NULL);
215
216 const __m512 vscale = _mm512_load_ps(params->avx512.scale);
217 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->avx512.output_max_less_zero_point);
218 const __m512i voutput_zero_point = _mm512_load_si512(params->avx512.output_zero_point);
219 const __m512i vshuffle512_mask = _mm512_load_si512(params->avx512.shuffle512_mask);
220 const __m512i voutput_min = _mm512_load_si512(params->avx512.output_min);
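// This kernel mirrors the QS8 variant above; the only differences are the
// final unsigned-saturating packs (_mm512_packus_epi16/_mm_packus_epi16)
// and the unsigned byte clamp (_mm512_max_epu8/_mm_max_epu8), since the
// output type is uint8_t.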
221 for (; n >= 128 * sizeof(float); n -= 128 * sizeof(float)) {
222 __m512 vx0123 = _mm512_loadu_ps(x);
223 __m512 vx4567 = _mm512_loadu_ps(x + 16);
224 __m512 vx89AB = _mm512_loadu_ps(x + 32);
225 __m512 vxCDEF = _mm512_loadu_ps(x + 48);
226 __m512 vxGHIJ = _mm512_loadu_ps(x + 64);
227 __m512 vxKLMN = _mm512_loadu_ps(x + 80);
228 __m512 vxOPQR = _mm512_loadu_ps(x + 96);
229 __m512 vxSTUV = _mm512_loadu_ps(x + 112);
230 x += 128;
231
232 vx0123 = _mm512_mul_ps(vx0123, vscale);
233 vx4567 = _mm512_mul_ps(vx4567, vscale);
234 vx89AB = _mm512_mul_ps(vx89AB, vscale);
235 vxCDEF = _mm512_mul_ps(vxCDEF, vscale);
236 vxGHIJ = _mm512_mul_ps(vxGHIJ, vscale);
237 vxKLMN = _mm512_mul_ps(vxKLMN, vscale);
238 vxOPQR = _mm512_mul_ps(vxOPQR, vscale);
239 vxSTUV = _mm512_mul_ps(vxSTUV, vscale);
240
241 vx0123 = _mm512_min_ps(vx0123, voutput_max_less_zero_point);
242 vx4567 = _mm512_min_ps(vx4567, voutput_max_less_zero_point);
243 vx89AB = _mm512_min_ps(vx89AB, voutput_max_less_zero_point);
244 vxCDEF = _mm512_min_ps(vxCDEF, voutput_max_less_zero_point);
245 vxGHIJ = _mm512_min_ps(vxGHIJ, voutput_max_less_zero_point);
246 vxKLMN = _mm512_min_ps(vxKLMN, voutput_max_less_zero_point);
247 vxOPQR = _mm512_min_ps(vxOPQR, voutput_max_less_zero_point);
248 vxSTUV = _mm512_min_ps(vxSTUV, voutput_max_less_zero_point);
249
250 const __m512i vacc0123 = _mm512_cvtps_epi32(vx0123);
251 const __m512i vacc4567 = _mm512_cvtps_epi32(vx4567);
252 const __m512i vacc89AB = _mm512_cvtps_epi32(vx89AB);
253 const __m512i vaccCDEF = _mm512_cvtps_epi32(vxCDEF);
254 const __m512i vaccGHIJ = _mm512_cvtps_epi32(vxGHIJ);
255 const __m512i vaccKLMN = _mm512_cvtps_epi32(vxKLMN);
256 const __m512i vaccOPQR = _mm512_cvtps_epi32(vxOPQR);
257 const __m512i vaccSTUV = _mm512_cvtps_epi32(vxSTUV);
258
259 __m512i vacc04152637 = _mm512_packs_epi32(vacc0123, vacc4567);
260 __m512i vacc8C9DAEBF = _mm512_packs_epi32(vacc89AB, vaccCDEF);
261 __m512i vaccGKHLIMJN = _mm512_packs_epi32(vaccGHIJ, vaccKLMN);
262 __m512i vaccOSPTQURV = _mm512_packs_epi32(vaccOPQR, vaccSTUV);
263
264 vacc04152637 = _mm512_adds_epi16(vacc04152637, voutput_zero_point);
265 vacc8C9DAEBF = _mm512_adds_epi16(vacc8C9DAEBF, voutput_zero_point);
266 vaccGKHLIMJN = _mm512_adds_epi16(vaccGKHLIMJN, voutput_zero_point);
267 vaccOSPTQURV = _mm512_adds_epi16(vaccOSPTQURV, voutput_zero_point);
268
269 __m512i vy048C159D26AE37BF = _mm512_packus_epi16(vacc04152637, vacc8C9DAEBF);
270 __m512i vyGKOSHLPTIMQUJNRV = _mm512_packus_epi16(vaccGKHLIMJN, vaccOSPTQURV);
271
272 vy048C159D26AE37BF = _mm512_max_epu8(vy048C159D26AE37BF, voutput_min);
273 vyGKOSHLPTIMQUJNRV = _mm512_max_epu8(vyGKOSHLPTIMQUJNRV, voutput_min);
274
275 const __m512i vy0123456789ABCDEF = _mm512_permutexvar_epi32(vshuffle512_mask, vy048C159D26AE37BF);
276 const __m512i vyGHIJKLMNOPQRSTUV = _mm512_permutexvar_epi32(vshuffle512_mask, vyGKOSHLPTIMQUJNRV);
277
278 _mm512_storeu_si512(y, vy0123456789ABCDEF);
279 _mm512_storeu_si512(y + 64, vyGHIJKLMNOPQRSTUV);
280 y += 128;
281 }
282 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
283 __m512 vx0123 = _mm512_loadu_ps(x);
284 vx0123 = _mm512_mul_ps(vx0123, vscale);
285 vx0123 = _mm512_min_ps(vx0123, voutput_max_less_zero_point);
286 x += 16;
287
288 const __m512i vacc0123 = _mm512_cvtps_epi32(vx0123);
289
290 __m256i vacc0213 = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0123), _mm512_extracti32x8_epi32(vacc0123, 1));
291 vacc0213 = _mm256_adds_epi16(vacc0213, _mm512_castsi512_si256(voutput_zero_point));
292 const __m128i vy0213 = _mm_packus_epi16(_mm256_castsi256_si128(vacc0213), _mm256_extracti128_si256(vacc0213, 1));
293 __m128i vy0123 = _mm_shuffle_epi32(vy0213, _MM_SHUFFLE(3, 1, 2, 0));
294 vy0123 = _mm_max_epu8(vy0123, _mm512_castsi512_si128(voutput_min));
295
296 _mm_storeu_si128((__m128i*) y, vy0123);
297 y += 16;
298 }
299 if XNN_UNLIKELY(n != 0) {
300 assert(n >= 1 * sizeof(float));
301 assert(n <= 15 * sizeof(float));
302
303 // Prepare mask for valid elements (depends on n).
304 n >>= 2 /* log2(sizeof(float)) */;
305 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
306
307 __m512 vx0123 = _mm512_maskz_loadu_ps(vmask, x);
308 vx0123 = _mm512_mul_ps(vx0123, vscale);
309 vx0123 = _mm512_min_ps(vx0123, voutput_max_less_zero_point);
310
311 const __m512i vacc0123 = _mm512_cvtps_epi32(vx0123);
312
313 __m256i vacc0213 = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0123), _mm512_extracti32x8_epi32(vacc0123, 1));
314 vacc0213 = _mm256_adds_epi16(vacc0213, _mm512_castsi512_si256(voutput_zero_point));
315 const __m128i vy0213 = _mm_packus_epi16(_mm256_castsi256_si128(vacc0213), _mm256_extracti128_si256(vacc0213, 1));
316 __m128i vy0123 = _mm_shuffle_epi32(vy0213, _MM_SHUFFLE(3, 1, 2, 0));
317 vy0123 = _mm_max_epu8(vy0123, _mm512_castsi512_si128(voutput_min));
318
319 _mm_mask_storeu_epi8(y, vmask, vy0123);
320 }
321 }
322
323 void xnn_qc8_dwconv_minmax_fp32_ukernel_up32x25__avx512skx_mul32(
324 size_t channels,
325 size_t output_width,
326 const int8_t** input,
327 const void* weights,
328 int8_t* output,
329 size_t input_stride,
330 size_t output_increment,
331 size_t input_offset,
332 const int8_t* zero,
333 const union xnn_qc8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_MSAN
334 {
335 assert(channels != 0);
336 assert(output_width != 0);
337
338 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
339 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
340 const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_min);
341 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 5, 1, 6, 2, 4, 0);
342
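// 25-tap depthwise convolution (e.g. a 5x5 window), 32 channels per inner
// iteration.  The packed weights at `w` are laid out as 32 int32 biases,
// then 25 groups of 32 int8 taps (800 bytes), then 32 float per-channel
// scales ("qc8" = per-channel-quantized).  Taps whose input row falls
// outside the image are fed the `zero` buffer, which must not be offset;
// hence the XNN_UNPREDICTABLE pointer fix-ups below.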
343 do {
344 const int8_t* i0 = input[0];
345 assert(i0 != NULL);
346 if XNN_UNPREDICTABLE(i0 != zero) {
347 i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
348 }
349 const int8_t* i1 = input[1];
350 assert(i1 != NULL);
351 if XNN_UNPREDICTABLE(i1 != zero) {
352 i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
353 }
354 const int8_t* i2 = input[2];
355 assert(i2 != NULL);
356 if XNN_UNPREDICTABLE(i2 != zero) {
357 i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
358 }
359 const int8_t* i3 = input[3];
360 assert(i3 != NULL);
361 if XNN_UNPREDICTABLE(i3 != zero) {
362 i3 = (const int8_t*) ((uintptr_t) i3 + input_offset);
363 }
364 const int8_t* i4 = input[4];
365 assert(i4 != NULL);
366 if XNN_UNPREDICTABLE(i4 != zero) {
367 i4 = (const int8_t*) ((uintptr_t) i4 + input_offset);
368 }
369 const int8_t* i5 = input[5];
370 assert(i5 != NULL);
371 if XNN_UNPREDICTABLE(i5 != zero) {
372 i5 = (const int8_t*) ((uintptr_t) i5 + input_offset);
373 }
374 const int8_t* i6 = input[6];
375 assert(i6 != NULL);
376 if XNN_UNPREDICTABLE(i6 != zero) {
377 i6 = (const int8_t*) ((uintptr_t) i6 + input_offset);
378 }
379 const int8_t* i7 = input[7];
380 assert(i7 != NULL);
381 if XNN_UNPREDICTABLE(i7 != zero) {
382 i7 = (const int8_t*) ((uintptr_t) i7 + input_offset);
383 }
384 const int8_t* i8 = input[8];
385 assert(i8 != NULL);
386 if XNN_UNPREDICTABLE(i8 != zero) {
387 i8 = (const int8_t*) ((uintptr_t) i8 + input_offset);
388 }
389 const int8_t* i9 = input[9];
390 assert(i9 != NULL);
391 if XNN_UNPREDICTABLE(i9 != zero) {
392 i9 = (const int8_t*) ((uintptr_t) i9 + input_offset);
393 }
394 const int8_t* i10 = input[10];
395 assert(i10 != NULL);
396 if XNN_UNPREDICTABLE(i10 != zero) {
397 i10 = (const int8_t*) ((uintptr_t) i10 + input_offset);
398 }
399 const int8_t* i11 = input[11];
400 assert(i11 != NULL);
401 if XNN_UNPREDICTABLE(i11 != zero) {
402 i11 = (const int8_t*) ((uintptr_t) i11 + input_offset);
403 }
404 const int8_t* i12 = input[12];
405 assert(i12 != NULL);
406 if XNN_UNPREDICTABLE(i12 != zero) {
407 i12 = (const int8_t*) ((uintptr_t) i12 + input_offset);
408 }
409 const int8_t* i13 = input[13];
410 assert(i13 != NULL);
411 if XNN_UNPREDICTABLE(i13 != zero) {
412 i13 = (const int8_t*) ((uintptr_t) i13 + input_offset);
413 }
414 const int8_t* i14 = input[14];
415 assert(i14 != NULL);
416 if XNN_UNPREDICTABLE(i14 != zero) {
417 i14 = (const int8_t*) ((uintptr_t) i14 + input_offset);
418 }
419 const int8_t* i15 = input[15];
420 assert(i15 != NULL);
421 if XNN_UNPREDICTABLE(i15 != zero) {
422 i15 = (const int8_t*) ((uintptr_t) i15 + input_offset);
423 }
424 const int8_t* i16 = input[16];
425 assert(i16 != NULL);
426 if XNN_UNPREDICTABLE(i16 != zero) {
427 i16 = (const int8_t*) ((uintptr_t) i16 + input_offset);
428 }
429 const int8_t* i17 = input[17];
430 assert(i17 != NULL);
431 if XNN_UNPREDICTABLE(i17 != zero) {
432 i17 = (const int8_t*) ((uintptr_t) i17 + input_offset);
433 }
434 const int8_t* i18 = input[18];
435 assert(i18 != NULL);
436 if XNN_UNPREDICTABLE(i18 != zero) {
437 i18 = (const int8_t*) ((uintptr_t) i18 + input_offset);
438 }
439 const int8_t* i19 = input[19];
440 assert(i19 != NULL);
441 if XNN_UNPREDICTABLE(i19 != zero) {
442 i19 = (const int8_t*) ((uintptr_t) i19 + input_offset);
443 }
444 const int8_t* i20 = input[20];
445 assert(i20 != NULL);
446 if XNN_UNPREDICTABLE(i20 != zero) {
447 i20 = (const int8_t*) ((uintptr_t) i20 + input_offset);
448 }
449 const int8_t* i21 = input[21];
450 assert(i21 != NULL);
451 if XNN_UNPREDICTABLE(i21 != zero) {
452 i21 = (const int8_t*) ((uintptr_t) i21 + input_offset);
453 }
454 const int8_t* i22 = input[22];
455 assert(i22 != NULL);
456 if XNN_UNPREDICTABLE(i22 != zero) {
457 i22 = (const int8_t*) ((uintptr_t) i22 + input_offset);
458 }
459 const int8_t* i23 = input[23];
460 assert(i23 != NULL);
461 if XNN_UNPREDICTABLE(i23 != zero) {
462 i23 = (const int8_t*) ((uintptr_t) i23 + input_offset);
463 }
464 const int8_t* i24 = input[24];
465 assert(i24 != NULL);
466 if XNN_UNPREDICTABLE(i24 != zero) {
467 i24 = (const int8_t*) ((uintptr_t) i24 + input_offset);
468 }
469 input = (const int8_t**) ((uintptr_t) input + input_stride);
470
471 size_t c = channels;
472 const void* w = weights;
473 for (; c >= 32; c -= 32) {
474 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
475 __m512i vaccGHIJKLMNOPQRSTUV = _mm512_loadu_si512((const void*) ((uintptr_t) w + 16 * sizeof(int32_t)));
476
477
478 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
479 const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 0 * sizeof(int8_t))));
480 const __m512i vi0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i0 + 16)));
481 const __m512i vk0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 16 * sizeof(int8_t))));
482 i0 += 32;
483
484 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
485 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi0xGHIJKLMNOPQRSTUV, vk0xGHIJKLMNOPQRSTUV));
486
487 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
488 const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 32 * sizeof(int8_t))));
489 const __m512i vi1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i1 + 16)));
490 const __m512i vk1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 48 * sizeof(int8_t))));
491 i1 += 32;
492
493 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
494 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi1xGHIJKLMNOPQRSTUV, vk1xGHIJKLMNOPQRSTUV));
495
496 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
497 const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 64 * sizeof(int8_t))));
498 const __m512i vi2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i2 + 16)));
499 const __m512i vk2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 80 * sizeof(int8_t))));
500 i2 += 32;
501
502 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
503 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi2xGHIJKLMNOPQRSTUV, vk2xGHIJKLMNOPQRSTUV));
504
505 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i3));
506 const __m512i vk3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 96 * sizeof(int8_t))));
507 const __m512i vi3xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i3 + 16)));
508 const __m512i vk3xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 112 * sizeof(int8_t))));
509 i3 += 32;
510
511 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
512 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi3xGHIJKLMNOPQRSTUV, vk3xGHIJKLMNOPQRSTUV));
513
514 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i4));
515 const __m512i vk4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 128 * sizeof(int8_t))));
516 const __m512i vi4xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i4 + 16)));
517 const __m512i vk4xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 144 * sizeof(int8_t))));
518 i4 += 32;
519
520 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
521 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi4xGHIJKLMNOPQRSTUV, vk4xGHIJKLMNOPQRSTUV));
522
523 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i5));
524 const __m512i vk5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 160 * sizeof(int8_t))));
525 const __m512i vi5xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i5 + 16)));
526 const __m512i vk5xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 176 * sizeof(int8_t))));
527 i5 += 32;
528
529 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
530 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi5xGHIJKLMNOPQRSTUV, vk5xGHIJKLMNOPQRSTUV));
531
532 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i6));
533 const __m512i vk6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 192 * sizeof(int8_t))));
534 const __m512i vi6xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i6 + 16)));
535 const __m512i vk6xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 208 * sizeof(int8_t))));
536 i6 += 32;
537
538 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
539 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi6xGHIJKLMNOPQRSTUV, vk6xGHIJKLMNOPQRSTUV));
540
541 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i7));
542 const __m512i vk7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 224 * sizeof(int8_t))));
543 const __m512i vi7xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i7 + 16)));
544 const __m512i vk7xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 240 * sizeof(int8_t))));
545 i7 += 32;
546
547 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
548 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi7xGHIJKLMNOPQRSTUV, vk7xGHIJKLMNOPQRSTUV));
549
550 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i8));
551 const __m512i vk8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 256 * sizeof(int8_t))));
552 const __m512i vi8xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i8 + 16)));
553 const __m512i vk8xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 272 * sizeof(int8_t))));
554 i8 += 32;
555
556 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
557 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi8xGHIJKLMNOPQRSTUV, vk8xGHIJKLMNOPQRSTUV));
558
559 const __m512i vi9x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i9));
560 const __m512i vk9x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 288 * sizeof(int8_t))));
561 const __m512i vi9xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i9 + 16)));
562 const __m512i vk9xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 304 * sizeof(int8_t))));
563 i9 += 32;
564
565 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF));
566 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi9xGHIJKLMNOPQRSTUV, vk9xGHIJKLMNOPQRSTUV));
567
568 const __m512i vi10x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i10));
569 const __m512i vk10x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 320 * sizeof(int8_t))));
570 const __m512i vi10xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i10 + 16)));
571 const __m512i vk10xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 336 * sizeof(int8_t))));
572 i10 += 32;
573
574 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF));
575 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi10xGHIJKLMNOPQRSTUV, vk10xGHIJKLMNOPQRSTUV));
576
577 const __m512i vi11x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i11));
578 const __m512i vk11x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 352 * sizeof(int8_t))));
579 const __m512i vi11xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i11 + 16)));
580 const __m512i vk11xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 368 * sizeof(int8_t))));
581 i11 += 32;
582
583 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF));
584 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi11xGHIJKLMNOPQRSTUV, vk11xGHIJKLMNOPQRSTUV));
585
586 const __m512i vi12x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i12));
587 const __m512i vk12x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 384 * sizeof(int8_t))));
588 const __m512i vi12xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i12 + 16)));
589 const __m512i vk12xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 400 * sizeof(int8_t))));
590 i12 += 32;
591
592 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF));
593 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi12xGHIJKLMNOPQRSTUV, vk12xGHIJKLMNOPQRSTUV));
594
595 const __m512i vi13x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i13));
596 const __m512i vk13x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 416 * sizeof(int8_t))));
597 const __m512i vi13xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i13 + 16)));
598 const __m512i vk13xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 432 * sizeof(int8_t))));
599 i13 += 32;
600
601 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF));
602 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi13xGHIJKLMNOPQRSTUV, vk13xGHIJKLMNOPQRSTUV));
603
604 const __m512i vi14x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i14));
605 const __m512i vk14x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 448 * sizeof(int8_t))));
606 const __m512i vi14xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i14 + 16)));
607 const __m512i vk14xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 464 * sizeof(int8_t))));
608 i14 += 32;
609
610 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF));
611 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi14xGHIJKLMNOPQRSTUV, vk14xGHIJKLMNOPQRSTUV));
612
613 const __m512i vi15x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i15));
614 const __m512i vk15x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 480 * sizeof(int8_t))));
615 const __m512i vi15xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i15 + 16)));
616 const __m512i vk15xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 496 * sizeof(int8_t))));
617 i15 += 32;
618
619 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF));
620 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi15xGHIJKLMNOPQRSTUV, vk15xGHIJKLMNOPQRSTUV));
621
622 const __m512i vi16x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i16));
623 const __m512i vk16x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 512 * sizeof(int8_t))));
624 const __m512i vi16xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i16 + 16)));
625 const __m512i vk16xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 528 * sizeof(int8_t))));
626 i16 += 32;
627
628 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF));
629 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi16xGHIJKLMNOPQRSTUV, vk16xGHIJKLMNOPQRSTUV));
630
631 const __m512i vi17x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i17));
632 const __m512i vk17x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 544 * sizeof(int8_t))));
633 const __m512i vi17xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i17 + 16)));
634 const __m512i vk17xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 560 * sizeof(int8_t))));
635 i17 += 32;
636
637 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF));
638 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi17xGHIJKLMNOPQRSTUV, vk17xGHIJKLMNOPQRSTUV));
639
640 const __m512i vi18x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i18));
641 const __m512i vk18x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 576 * sizeof(int8_t))));
642 const __m512i vi18xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i18 + 16)));
643 const __m512i vk18xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 592 * sizeof(int8_t))));
644 i18 += 32;
645
646 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF));
647 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi18xGHIJKLMNOPQRSTUV, vk18xGHIJKLMNOPQRSTUV));
648
649 const __m512i vi19x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i19));
650 const __m512i vk19x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 608 * sizeof(int8_t))));
651 const __m512i vi19xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i19 + 16)));
652 const __m512i vk19xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 624 * sizeof(int8_t))));
653 i19 += 32;
654
655 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF));
656 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi19xGHIJKLMNOPQRSTUV, vk19xGHIJKLMNOPQRSTUV));
657
658 const __m512i vi20x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i20));
659 const __m512i vk20x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 640 * sizeof(int8_t))));
660 const __m512i vi20xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i20 + 16)));
661 const __m512i vk20xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 656 * sizeof(int8_t))));
662 i20 += 32;
663
664 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF));
665 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi20xGHIJKLMNOPQRSTUV, vk20xGHIJKLMNOPQRSTUV));
666
667 const __m512i vi21x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i21));
668 const __m512i vk21x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 672 * sizeof(int8_t))));
669 const __m512i vi21xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i21 + 16)));
670 const __m512i vk21xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 688 * sizeof(int8_t))));
671 i21 += 32;
672
673 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF));
674 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi21xGHIJKLMNOPQRSTUV, vk21xGHIJKLMNOPQRSTUV));
675
676 const __m512i vi22x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i22));
677 const __m512i vk22x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 704 * sizeof(int8_t))));
678 const __m512i vi22xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i22 + 16)));
679 const __m512i vk22xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 720 * sizeof(int8_t))));
680 i22 += 32;
681
682 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF));
683 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi22xGHIJKLMNOPQRSTUV, vk22xGHIJKLMNOPQRSTUV));
684
685 const __m512i vi23x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i23));
686 const __m512i vk23x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 736 * sizeof(int8_t))));
687 const __m512i vi23xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i23 + 16)));
688 const __m512i vk23xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 752 * sizeof(int8_t))));
689 i23 += 32;
690
691 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF));
692 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi23xGHIJKLMNOPQRSTUV, vk23xGHIJKLMNOPQRSTUV));
693
694 const __m512i vi24x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i24));
695 const __m512i vk24x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 768 * sizeof(int8_t))));
696 const __m512i vi24xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i24 + 16)));
697 const __m512i vk24xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 784 * sizeof(int8_t))));
698 i24 += 32;
699
700 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF));
701 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi24xGHIJKLMNOPQRSTUV, vk24xGHIJKLMNOPQRSTUV));
702
703 w = (const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 800 * sizeof(int8_t));
704
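// Requantization: widen the int32 accumulators to float, multiply by the
// per-channel scales stored after the int8 taps, clamp against the output
// max while still in float, and convert back with _mm512_cvtps_epi32
// (round-to-nearest-even under the default MXCSR mode).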
705 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
706 __m512 vscaledGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vaccGHIJKLMNOPQRSTUV);
707
708 const __m512 vscale0123456789ABCDEF = _mm512_loadu_ps(w);
709 const __m512 vscaleGHIJKLMNOPQRSTUV = _mm512_loadu_ps((const void*) ((uintptr_t) w + 16 * sizeof(float)));
710 w = (const void*) ((uintptr_t) w + 32 * sizeof(float));
711 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale0123456789ABCDEF);
712 vscaledGHIJKLMNOPQRSTUV = _mm512_mul_ps(vscaledGHIJKLMNOPQRSTUV, vscaleGHIJKLMNOPQRSTUV);
713
714 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
715 vscaledGHIJKLMNOPQRSTUV = _mm512_min_ps(vscaledGHIJKLMNOPQRSTUV, voutput_max_less_zero_point);
716
717 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
718 vaccGHIJKLMNOPQRSTUV = _mm512_cvtps_epi32(vscaledGHIJKLMNOPQRSTUV);
719
720 __m512i vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV = _mm512_adds_epi16(_mm512_packs_epi32(vacc0123456789ABCDEF, vaccGHIJKLMNOPQRSTUV), voutput_zero_point);
721 __m256i voutGHIJOPQRKLMNSTUV = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vaccGHIJKLMNOPQRSTUV), _mm512_extracti32x8_epi32(vaccGHIJKLMNOPQRSTUV, 1)), _mm512_castsi512_si256(voutput_zero_point));
722
723 const __m256i vout0123GHIJ4567KLMN = _mm512_castsi512_si256(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV);
724 const __m256i vout89ABOPQRCDEFSTUV = _mm512_extracti32x8_epi32(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV, 1);
725 const __m256i vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV = _mm256_packs_epi16(vout0123GHIJ4567KLMN, vout89ABOPQRCDEFSTUV);
726 __m256i vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_permutevar8x32_epi32(vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV, vpermute_mask);
727 const __m128i voutGHIJOPQR = _mm256_castsi256_si128(voutGHIJOPQRKLMNSTUV);
728 const __m128i voutKLMNSTUV = _mm256_extracti128_si256(voutGHIJOPQRKLMNSTUV, 1);
729 __m128i voutGHIJKLMNOPQRSTUV = _mm_shuffle_epi32(_mm_packs_epi16(voutGHIJOPQR, voutKLMNSTUV), _MM_SHUFFLE(3, 1, 2, 0));
730
731 vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_max_epi8(vout0123456789ABCDEFGHIJKLMNOPQRSTUV, voutput_min);
732 voutGHIJKLMNOPQRSTUV = _mm_max_epi8(voutGHIJKLMNOPQRSTUV, _mm256_castsi256_si128(voutput_min));
733
734 _mm256_storeu_si256((__m256i*) output, vout0123456789ABCDEFGHIJKLMNOPQRSTUV);
735 _mm_storeu_si128((__m128i*) (output + 16), voutGHIJKLMNOPQRSTUV);
736 output += 32;
737 }
738 if XNN_UNLIKELY(c != 0) {
739 // Prepare mask for valid 8-bit elements (depends on c).
740 const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << (c & 15)) - UINT32_C(1)));
741 const int8_t* k = (const int8_t*) ((uintptr_t) w + 32 * sizeof(int32_t));
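// Channel tail: handle the remaining 1-31 channels 16 at a time, reading
// the int8 taps through `k` with unaligned loads; the byte mask above
// covers the final partial store.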
742 do {
743 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
744
745
746 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
747 const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) k));
748 i0 += 16;
749
750 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
751
752 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
753 const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 32)));
754 i1 += 16;
755
756 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
757
758 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
759 const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 64)));
760 i2 += 16;
761
762 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
763
764 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i3));
765 const __m512i vk3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 96)));
766 i3 += 16;
767
768 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
769
770 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i4));
771 const __m512i vk4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 128)));
772 i4 += 16;
773
774 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
775
776 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i5));
777 const __m512i vk5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 160)));
778 i5 += 16;
779
780 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
781
782 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i6));
783 const __m512i vk6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 192)));
784 i6 += 16;
785
786 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
787
788 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i7));
789 const __m512i vk7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 224)));
790 i7 += 16;
791
792 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
793
794 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i8));
795 const __m512i vk8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 256)));
796 i8 += 16;
797
798 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
799
800 const __m512i vi9x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i9));
801 const __m512i vk9x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 288)));
802 i9 += 16;
803
804 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF));
805
806 const __m512i vi10x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i10));
807 const __m512i vk10x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 320)));
808 i10 += 16;
809
810 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF));
811
812 const __m512i vi11x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i11));
813 const __m512i vk11x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 352)));
814 i11 += 16;
815
816 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF));
817
818 const __m512i vi12x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i12));
819 const __m512i vk12x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 384)));
820 i12 += 16;
821
822 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF));
823
824 const __m512i vi13x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i13));
825 const __m512i vk13x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 416)));
826 i13 += 16;
827
828 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF));
829
830 const __m512i vi14x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i14));
831 const __m512i vk14x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 448)));
832 i14 += 16;
833
834 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF));
835
836 const __m512i vi15x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i15));
837 const __m512i vk15x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 480)));
838 i15 += 16;
839
840 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF));
841
842 const __m512i vi16x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i16));
843 const __m512i vk16x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 512)));
844 i16 += 16;
845
846 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF));
847
848 const __m512i vi17x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i17));
849 const __m512i vk17x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 544)));
850 i17 += 16;
851
852 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF));
853
854 const __m512i vi18x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i18));
855 const __m512i vk18x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 576)));
856 i18 += 16;
857
858 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF));
859
860 const __m512i vi19x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i19));
861 const __m512i vk19x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 608)));
862 i19 += 16;
863
864 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF));
865
866 const __m512i vi20x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i20));
867 const __m512i vk20x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 640)));
868 i20 += 16;
869
870 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF));
871
872 const __m512i vi21x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i21));
873 const __m512i vk21x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 672)));
874 i21 += 16;
875
876 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF));
877
878 const __m512i vi22x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i22));
879 const __m512i vk22x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 704)));
880 i22 += 16;
881
882 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF));
883
884 const __m512i vi23x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i23));
885 const __m512i vk23x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 736)));
886 i23 += 16;
887
888 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF));
889
890 const __m512i vi24x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i24));
891 const __m512i vk24x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 768)));
892 i24 += 16;
893
894 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF));
895
896 k += 16;
897
898 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
899 const __m512 vscale0123456789ABCDEF = _mm512_loadu_ps((const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 800 * sizeof(int8_t)));
900 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale0123456789ABCDEF);
901 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
902 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
903
904 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
905
906 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), _mm512_castsi512_si256(voutput_zero_point));
907
908 const __m128i vout012389AB = _mm256_castsi256_si128(vout012389AB4567CDEF);
909 const __m128i vout4567CDEF = _mm256_extracti128_si256(vout012389AB4567CDEF, 1);
910 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(vout012389AB, vout4567CDEF), _MM_SHUFFLE(3, 1, 2, 0));
911 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, _mm256_castsi256_si128(voutput_min));
912
913 if XNN_LIKELY(c >= 16) {
914 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
915 output += 16;
916 c -= 16;
917 } else {
918 _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
919 output = (int8_t*) ((uintptr_t) output + c);
920 c = 0;
921 }
922 } while (c != 0);
923 }
924
925 output = (int8_t*) ((uintptr_t) output + output_increment);
926 } while (--output_width != 0);
927 }
928
929 void xnn_qc8_dwconv_minmax_fp32_ukernel_up32x3__avx512skx_mul32(
930 size_t channels,
931 size_t output_width,
932 const int8_t** input,
933 const void* weights,
934 int8_t* output,
935 size_t input_stride,
936 size_t output_increment,
937 size_t input_offset,
938 const int8_t* zero,
939 const union xnn_qc8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_MSAN
940 {
941 assert(channels != 0);
942 assert(output_width != 0);
943
944 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
945 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
946 const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_min);
947 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 5, 1, 6, 2, 4, 0);
948
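// 3-tap variant (e.g. a 3x1 window) of the depthwise microkernel above,
// with the same requantization pipeline; weights are 32 int32 biases,
// 3 groups of 32 int8 taps, then 32 float per-channel scales.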
949 do {
950 const int8_t* i0 = input[0];
951 assert(i0 != NULL);
952 if XNN_UNPREDICTABLE(i0 != zero) {
953 i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
954 }
955 const int8_t* i1 = input[1];
956 assert(i1 != NULL);
957 if XNN_UNPREDICTABLE(i1 != zero) {
958 i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
959 }
960 const int8_t* i2 = input[2];
961 assert(i2 != NULL);
962 if XNN_UNPREDICTABLE(i2 != zero) {
963 i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
964 }
965 input = (const int8_t**) ((uintptr_t) input + input_stride);
966
967 size_t c = channels;
968 const void* w = weights;
969 for (; c >= 32; c -= 32) {
970 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
971 __m512i vaccGHIJKLMNOPQRSTUV = _mm512_loadu_si512((const void*) ((uintptr_t) w + 16 * sizeof(int32_t)));
972
973
974 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
975 const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 0 * sizeof(int8_t))));
976 const __m512i vi0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i0 + 16)));
977 const __m512i vk0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 16 * sizeof(int8_t))));
978 i0 += 32;
979
980 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
981 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi0xGHIJKLMNOPQRSTUV, vk0xGHIJKLMNOPQRSTUV));
982
983 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
984 const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 32 * sizeof(int8_t))));
985 const __m512i vi1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i1 + 16)));
986 const __m512i vk1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 48 * sizeof(int8_t))));
987 i1 += 32;
988
989 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
990 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi1xGHIJKLMNOPQRSTUV, vk1xGHIJKLMNOPQRSTUV));
991
992 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
993 const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 64 * sizeof(int8_t))));
994 const __m512i vi2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i2 + 16)));
995 const __m512i vk2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 80 * sizeof(int8_t))));
996 i2 += 32;
997
998 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
999 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi2xGHIJKLMNOPQRSTUV, vk2xGHIJKLMNOPQRSTUV));
1000
1001 w = (const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 96 * sizeof(int8_t));
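      // The packed-weight layout per 32-channel group, as implied by the offsets
      // above: 32 int32 biases, then 3 taps x 32 int8 kernel values, then the
      // 32 float per-channel scales that are loaded next.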

      __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
      __m512 vscaledGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vaccGHIJKLMNOPQRSTUV);

      const __m512 vscale0123456789ABCDEF = _mm512_loadu_ps(w);
      const __m512 vscaleGHIJKLMNOPQRSTUV = _mm512_loadu_ps((const void*) ((uintptr_t) w + 16 * sizeof(float)));
      w = (const void*) ((uintptr_t) w + 32 * sizeof(float));
      vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale0123456789ABCDEF);
      vscaledGHIJKLMNOPQRSTUV = _mm512_mul_ps(vscaledGHIJKLMNOPQRSTUV, vscaleGHIJKLMNOPQRSTUV);

      vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
      vscaledGHIJKLMNOPQRSTUV = _mm512_min_ps(vscaledGHIJKLMNOPQRSTUV, voutput_max_less_zero_point);

      vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
      vaccGHIJKLMNOPQRSTUV = _mm512_cvtps_epi32(vscaledGHIJKLMNOPQRSTUV);
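      // fp32 requantization: widen to float, apply the per-channel scale, clamp
      // the upper bound in float (the lower bound is applied after packing, via
      // _mm256_max_epi8 below), then round back to int32; _mm512_cvtps_epi32
      // uses the current rounding mode, round-to-nearest-even by default.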

      __m512i vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV = _mm512_adds_epi16(_mm512_packs_epi32(vacc0123456789ABCDEF, vaccGHIJKLMNOPQRSTUV), voutput_zero_point);

      const __m256i vout0123GHIJ4567KLMN = _mm512_castsi512_si256(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV);
      const __m256i vout89ABOPQRCDEFSTUV = _mm512_extracti32x8_epi32(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV, 1);
      const __m256i vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV = _mm256_packs_epi16(vout0123GHIJ4567KLMN, vout89ABOPQRCDEFSTUV);
      __m256i vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_permutevar8x32_epi32(vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV, vpermute_mask);

      vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_max_epi8(vout0123456789ABCDEFGHIJKLMNOPQRSTUV, voutput_min);

      _mm256_storeu_si256((__m256i*) output, vout0123456789ABCDEFGHIJKLMNOPQRSTUV);
      output += 32;
    }
    if XNN_UNLIKELY(c != 0) {
      // Prepare mask for valid 8-bit elements (depends on c).
      const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << (c & 15)) - UINT32_C(1)));
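      // Here 0 < c < 32, and c & 15 is the length of the final partial store:
      // e.g. c == 20 stores 16 full lanes on the first pass, then uses
      // vmask = (1 << 4) - 1 = 0x000F for the remaining 4 channels.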
      const int8_t* k = (const int8_t*) ((uintptr_t) w + 32 * sizeof(int32_t));
      do {
        __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);


        const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
        const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) k));
        i0 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));

        const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
        const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 32)));
        i1 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));

        const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
        const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 64)));
        i2 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));

        k += 16;

        __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
        const __m512 vscale0123456789ABCDEF = _mm512_loadu_ps((const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 96 * sizeof(int8_t)));
        vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale0123456789ABCDEF);
        vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
        vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);

        w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));

        __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), _mm512_castsi512_si256(voutput_zero_point));

        const __m128i vout012389AB = _mm256_castsi256_si128(vout012389AB4567CDEF);
        const __m128i vout4567CDEF = _mm256_extracti128_si256(vout012389AB4567CDEF, 1);
        __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(vout012389AB, vout4567CDEF), _MM_SHUFFLE(3, 1, 2, 0));
        vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, _mm256_castsi256_si128(voutput_min));

        if XNN_LIKELY(c >= 16) {
          _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
          output += 16;
          c -= 16;
        } else {
          _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
          output = (int8_t*) ((uintptr_t) output + c);
          c = 0;
        }
      } while (c != 0);
    }

    output = (int8_t*) ((uintptr_t) output + output_increment);
  } while (--output_width != 0);
}

void xnn_qc8_dwconv_minmax_fp32_ukernel_up32x9__avx512skx_mul32(
    size_t channels,
    size_t output_width,
    const int8_t** input,
    const void* weights,
    int8_t* output,
    size_t input_stride,
    size_t output_increment,
    size_t input_offset,
    const int8_t* zero,
    const union xnn_qc8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_MSAN
{
  assert(channels != 0);
  assert(output_width != 0);

  const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
  const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
  const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_min);
  const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 5, 1, 6, 2, 4, 0);

  do {
    const int8_t* i0 = input[0];
    assert(i0 != NULL);
    if XNN_UNPREDICTABLE(i0 != zero) {
      i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
    }
    const int8_t* i1 = input[1];
    assert(i1 != NULL);
    if XNN_UNPREDICTABLE(i1 != zero) {
      i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
    }
    const int8_t* i2 = input[2];
    assert(i2 != NULL);
    if XNN_UNPREDICTABLE(i2 != zero) {
      i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
    }
    const int8_t* i3 = input[3];
    assert(i3 != NULL);
    if XNN_UNPREDICTABLE(i3 != zero) {
      i3 = (const int8_t*) ((uintptr_t) i3 + input_offset);
    }
    const int8_t* i4 = input[4];
    assert(i4 != NULL);
    if XNN_UNPREDICTABLE(i4 != zero) {
      i4 = (const int8_t*) ((uintptr_t) i4 + input_offset);
    }
    const int8_t* i5 = input[5];
    assert(i5 != NULL);
    if XNN_UNPREDICTABLE(i5 != zero) {
      i5 = (const int8_t*) ((uintptr_t) i5 + input_offset);
    }
    const int8_t* i6 = input[6];
    assert(i6 != NULL);
    if XNN_UNPREDICTABLE(i6 != zero) {
      i6 = (const int8_t*) ((uintptr_t) i6 + input_offset);
    }
    const int8_t* i7 = input[7];
    assert(i7 != NULL);
    if XNN_UNPREDICTABLE(i7 != zero) {
      i7 = (const int8_t*) ((uintptr_t) i7 + input_offset);
    }
    const int8_t* i8 = input[8];
    assert(i8 != NULL);
    if XNN_UNPREDICTABLE(i8 != zero) {
      i8 = (const int8_t*) ((uintptr_t) i8 + input_offset);
    }
    input = (const int8_t**) ((uintptr_t) input + input_stride);

    size_t c = channels;
    const void* w = weights;
    for (; c >= 32; c -= 32) {
      __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
      __m512i vaccGHIJKLMNOPQRSTUV = _mm512_loadu_si512((const void*) ((uintptr_t) w + 16 * sizeof(int32_t)));


      const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
      const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 0 * sizeof(int8_t))));
      const __m512i vi0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i0 + 16)));
      const __m512i vk0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 16 * sizeof(int8_t))));
      i0 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi0xGHIJKLMNOPQRSTUV, vk0xGHIJKLMNOPQRSTUV));

      const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
      const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 32 * sizeof(int8_t))));
      const __m512i vi1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i1 + 16)));
      const __m512i vk1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 48 * sizeof(int8_t))));
      i1 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi1xGHIJKLMNOPQRSTUV, vk1xGHIJKLMNOPQRSTUV));

      const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
      const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 64 * sizeof(int8_t))));
      const __m512i vi2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i2 + 16)));
      const __m512i vk2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 80 * sizeof(int8_t))));
      i2 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi2xGHIJKLMNOPQRSTUV, vk2xGHIJKLMNOPQRSTUV));

      const __m512i vi3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i3));
      const __m512i vk3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 96 * sizeof(int8_t))));
      const __m512i vi3xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i3 + 16)));
      const __m512i vk3xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 112 * sizeof(int8_t))));
      i3 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi3xGHIJKLMNOPQRSTUV, vk3xGHIJKLMNOPQRSTUV));

      const __m512i vi4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i4));
      const __m512i vk4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 128 * sizeof(int8_t))));
      const __m512i vi4xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i4 + 16)));
      const __m512i vk4xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 144 * sizeof(int8_t))));
      i4 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi4xGHIJKLMNOPQRSTUV, vk4xGHIJKLMNOPQRSTUV));

      const __m512i vi5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i5));
      const __m512i vk5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 160 * sizeof(int8_t))));
      const __m512i vi5xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i5 + 16)));
      const __m512i vk5xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 176 * sizeof(int8_t))));
      i5 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi5xGHIJKLMNOPQRSTUV, vk5xGHIJKLMNOPQRSTUV));

      const __m512i vi6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i6));
      const __m512i vk6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 192 * sizeof(int8_t))));
      const __m512i vi6xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i6 + 16)));
      const __m512i vk6xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 208 * sizeof(int8_t))));
      i6 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi6xGHIJKLMNOPQRSTUV, vk6xGHIJKLMNOPQRSTUV));

      const __m512i vi7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i7));
      const __m512i vk7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 224 * sizeof(int8_t))));
      const __m512i vi7xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i7 + 16)));
      const __m512i vk7xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 240 * sizeof(int8_t))));
      i7 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi7xGHIJKLMNOPQRSTUV, vk7xGHIJKLMNOPQRSTUV));

      const __m512i vi8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i8));
      const __m512i vk8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 256 * sizeof(int8_t))));
      const __m512i vi8xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i8 + 16)));
      const __m512i vk8xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 272 * sizeof(int8_t))));
      i8 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi8xGHIJKLMNOPQRSTUV, vk8xGHIJKLMNOPQRSTUV));

      w = (const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 288 * sizeof(int8_t));

      __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
      __m512 vscaledGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vaccGHIJKLMNOPQRSTUV);

      const __m512 vscale0123456789ABCDEF = _mm512_loadu_ps(w);
      const __m512 vscaleGHIJKLMNOPQRSTUV = _mm512_loadu_ps((const void*) ((uintptr_t) w + 16 * sizeof(float)));
      w = (const void*) ((uintptr_t) w + 32 * sizeof(float));
      vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale0123456789ABCDEF);
      vscaledGHIJKLMNOPQRSTUV = _mm512_mul_ps(vscaledGHIJKLMNOPQRSTUV, vscaleGHIJKLMNOPQRSTUV);

      vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
      vscaledGHIJKLMNOPQRSTUV = _mm512_min_ps(vscaledGHIJKLMNOPQRSTUV, voutput_max_less_zero_point);

      vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
      vaccGHIJKLMNOPQRSTUV = _mm512_cvtps_epi32(vscaledGHIJKLMNOPQRSTUV);

      __m512i vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV = _mm512_adds_epi16(_mm512_packs_epi32(vacc0123456789ABCDEF, vaccGHIJKLMNOPQRSTUV), voutput_zero_point);

      const __m256i vout0123GHIJ4567KLMN = _mm512_castsi512_si256(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV);
      const __m256i vout89ABOPQRCDEFSTUV = _mm512_extracti32x8_epi32(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV, 1);
      const __m256i vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV = _mm256_packs_epi16(vout0123GHIJ4567KLMN, vout89ABOPQRCDEFSTUV);
      __m256i vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_permutevar8x32_epi32(vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV, vpermute_mask);

      vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_max_epi8(vout0123456789ABCDEFGHIJKLMNOPQRSTUV, voutput_min);

      _mm256_storeu_si256((__m256i*) output, vout0123456789ABCDEFGHIJKLMNOPQRSTUV);
      output += 32;
    }
    if XNN_UNLIKELY(c != 0) {
      // Prepare mask for valid 8-bit elements (depends on c).
      const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << (c & 15)) - UINT32_C(1)));
      const int8_t* k = (const int8_t*) ((uintptr_t) w + 32 * sizeof(int32_t));
      do {
        __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);


        const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
        const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) k));
        i0 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));

        const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
        const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 32)));
        i1 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));

        const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
        const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 64)));
        i2 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));

        const __m512i vi3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i3));
        const __m512i vk3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 96)));
        i3 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));

        const __m512i vi4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i4));
        const __m512i vk4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 128)));
        i4 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));

        const __m512i vi5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i5));
        const __m512i vk5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 160)));
        i5 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));

        const __m512i vi6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i6));
        const __m512i vk6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 192)));
        i6 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));

        const __m512i vi7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i7));
        const __m512i vk7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 224)));
        i7 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));

        const __m512i vi8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i8));
        const __m512i vk8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 256)));
        i8 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));

        k += 16;

        __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
        const __m512 vscale0123456789ABCDEF = _mm512_loadu_ps((const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 288 * sizeof(int8_t)));
        vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale0123456789ABCDEF);
        vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
        vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);

        w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));

        __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), _mm512_castsi512_si256(voutput_zero_point));

        const __m128i vout012389AB = _mm256_castsi256_si128(vout012389AB4567CDEF);
        const __m128i vout4567CDEF = _mm256_extracti128_si256(vout012389AB4567CDEF, 1);
        __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(vout012389AB, vout4567CDEF), _MM_SHUFFLE(3, 1, 2, 0));
        vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, _mm256_castsi256_si128(voutput_min));

        if XNN_LIKELY(c >= 16) {
          _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
          output += 16;
          c -= 16;
        } else {
          _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
          output = (int8_t*) ((uintptr_t) output + c);
          c = 0;
        }
      } while (c != 0);
    }

    output = (int8_t*) ((uintptr_t) output + output_increment);
  } while (--output_width != 0);
}

void xnn_qc8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    int8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_qc8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= 1);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  kc = round_up_po2(kc, 8);
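  // Each k-step below consumes 8 bytes of the activation row, so kc is rounded
  // up to keep the trip count exact. The trailing loads can read past the true
  // kc (hence, presumably, the XNN_OOB_READS annotation); the packed weights
  // are expected to be zero-padded there, so the extra products contribute zero.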
  const int8_t* a0 = a;
  int8_t* c0 = c;

  const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
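  // 0x1111 sets mask bits 0, 4, 8 and 12, so each masked expand-load below
  // scatters 4 consecutive int32 bias values into element 0 of each 128-bit
  // lane of a zmm register, zeroing the other elements.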
  const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
  const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point);
  const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min);
  do {
    __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
    __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
    __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
    __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
    w = (const void*) ((const int32_t*) w + 16);

    size_t k = 0;
    while (k < kc) {
      const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
      a0 += 8;

      const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));

      vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
      const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));

      vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
      const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));

      vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
      const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));

      vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));

      w = (const void*) ((const int8_t*) w + 128);
      k += 8 * sizeof(int8_t);
    }

    const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
    const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));

    __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
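    // Each vacc0x* register holds per-lane partial dot products. Two rounds of
    // unpacklo/unpackhi + add reduce them transpose-style into one register
    // that still carries all 16 channel sums, but in the interleaved order
    // 0,8,4,C, 1,9,5,D, 2,A,6,E, 3,B,7,F encoded in the variable name.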

    __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);

    const __m512 vscale0123456789ABCDEF = _mm512_load_ps(w);
    w = (const void*) ((const float*) w + 16);
    const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale0123456789ABCDEF);
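    // The per-channel scales sit in plain 0..F order in the weight stream, so
    // permute them into the accumulator's 084C195D2A6E3B7F order instead of
    // reordering the accumulator itself.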
    vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);

    vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);

    vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);

    const __m256i vacc0x084C2A6E195D3B7F = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0x084C195D2A6E3B7F), _mm512_extracti32x8_epi32(vacc0x084C195D2A6E3B7F, 1)), voutput_zero_point);

    const __m128i vout0x084C2A6E195D3B7F = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
    __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x084C2A6E195D3B7F, _mm_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
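    // _mm_set_epi8 lists bytes from most to least significant; this shuffle
    // maps the 084C2A6E195D3B7F byte order produced by the two pack steps back
    // to ascending channels 0..F.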
    vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min);

    if (nc >= 16) {
      _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF);

      a0 = (const int8_t*) ((uintptr_t) a0 - k);

      c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);

      nc -= 16;
    } else {
      // Prepare mask for valid 8-bit elements (depends on nc).
      const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
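      // e.g. nc == 5 yields vmask = 0b11111, so only c0[0]..c0[4] are written.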

      _mm_mask_storeu_epi8(c0, vmask, vout0x0123456789ABCDEF);

      nc = 0;
    }
  } while (nc != 0);
}

void xnn_qc8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    int8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_qc8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= 4);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  kc = round_up_po2(kc, 8);
  const int8_t* a0 = a;
  int8_t* c0 = c;
  const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
  int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
  if XNN_UNPREDICTABLE(mr < 2) {
    a1 = a0;
    c1 = c0;
  }
  const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride);
  int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
  if XNN_UNPREDICTABLE(mr <= 2) {
    a2 = a1;
    c2 = c1;
  }
  const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride);
  int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride);
  if XNN_UNPREDICTABLE(mr != 4) {
    a3 = a2;
    c3 = c2;
  }

  const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
  const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
  const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
  const __m512i voutput_min = _mm512_load_si512(params->fp32_avx512.output_min);
  do {
    __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
    __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
    __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
    __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
    __m512i vacc1x0123 = vacc0x0123;
    __m512i vacc1x4567 = vacc0x4567;
    __m512i vacc1x89AB = vacc0x89AB;
    __m512i vacc1xCDEF = vacc0xCDEF;
    __m512i vacc2x0123 = vacc0x0123;
    __m512i vacc2x4567 = vacc0x4567;
    __m512i vacc2x89AB = vacc0x89AB;
    __m512i vacc2xCDEF = vacc0xCDEF;
    __m512i vacc3x0123 = vacc0x0123;
    __m512i vacc3x4567 = vacc0x4567;
    __m512i vacc3x89AB = vacc0x89AB;
    __m512i vacc3xCDEF = vacc0xCDEF;
    w = (const void*) ((const int32_t*) w + 16);

    size_t k = 0;
    while (k < kc) {
      const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
      a0 += 8;
      const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
      a1 += 8;
      const __m512i va2 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a2)));
      a2 += 8;
      const __m512i va3 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a3)));
      a3 += 8;

      const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));

      vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
      vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
      vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
      vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
      const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));

      vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
      vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
      vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
      vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
      const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));

      vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
      vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
      vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
      vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
      const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));

      vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
      vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
      vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
      vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));

      w = (const void*) ((const int8_t*) w + 128);
      k += 8 * sizeof(int8_t);
    }

    const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
    const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
    const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
    const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
    const __m512i vacc2x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x0123, vacc2x4567), _mm512_unpackhi_epi32(vacc2x0123, vacc2x4567));
    const __m512i vacc2x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x89AB, vacc2xCDEF), _mm512_unpackhi_epi32(vacc2x89AB, vacc2xCDEF));
    const __m512i vacc3x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x0123, vacc3x4567), _mm512_unpackhi_epi32(vacc3x0123, vacc3x4567));
    const __m512i vacc3x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x89AB, vacc3xCDEF), _mm512_unpackhi_epi32(vacc3x89AB, vacc3xCDEF));

    __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
    __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
    __m512i vacc2x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x04152637, vacc2x8C9DAEBF), _mm512_unpackhi_epi32(vacc2x04152637, vacc2x8C9DAEBF));
    __m512i vacc3x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x04152637, vacc3x8C9DAEBF), _mm512_unpackhi_epi32(vacc3x04152637, vacc3x8C9DAEBF));

    __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
    __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
    __m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
    __m512 vscaled3x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc3x084C195D2A6E3B7F);

    const __m512 vscale0123456789ABCDEF = _mm512_load_ps(w);
    w = (const void*) ((const float*) w + 16);
    const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale0123456789ABCDEF);
    vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
    vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
    vscaled2x084C195D2A6E3B7F = _mm512_mul_ps(vscaled2x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
    vscaled3x084C195D2A6E3B7F = _mm512_mul_ps(vscaled3x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);

    vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
    vscaled1x084C195D2A6E3B7F = _mm512_min_ps(vscaled1x084C195D2A6E3B7F, voutput_max_less_zero_point);
    vscaled2x084C195D2A6E3B7F = _mm512_min_ps(vscaled2x084C195D2A6E3B7F, voutput_max_less_zero_point);
    vscaled3x084C195D2A6E3B7F = _mm512_min_ps(vscaled3x084C195D2A6E3B7F, voutput_max_less_zero_point);

    vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
    vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
    vacc2x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled2x084C195D2A6E3B7F);
    vacc3x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled3x084C195D2A6E3B7F);

    const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
    const __m512i vacc23x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc2x084C195D2A6E3B7F, vacc3x084C195D2A6E3B7F), voutput_zero_point);

    __m512i vout0123x084Cx195Dx2A6Ex3B7F = _mm512_packs_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc23x084Cx195Dx2A6Ex3B7F);
    vout0123x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout0123x084Cx195Dx2A6Ex3B7F);
    __m512i vout0123x0123456789ABCDEF = _mm512_shuffle_epi8(vout0123x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
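    // After the two pack steps each 128-bit lane holds one dword from every
    // row; the dword permute regroups them so that lane r contains row r, and
    // the byte shuffle restores channel order 0..F within each lane.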
    vout0123x0123456789ABCDEF = _mm512_max_epi8(vout0123x0123456789ABCDEF, voutput_min);

    if (nc >= 16) {
      _mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout0123x0123456789ABCDEF));
      _mm_storeu_si128((__m128i*) c1, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 1));
      _mm_storeu_si128((__m128i*) c2, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 2));
      _mm_storeu_si128((__m128i*) c3, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 3));

      a0 = (const int8_t*) ((uintptr_t) a0 - k);
      a1 = (const int8_t*) ((uintptr_t) a1 - k);
      a2 = (const int8_t*) ((uintptr_t) a2 - k);
      a3 = (const int8_t*) ((uintptr_t) a3 - k);

      c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
      c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
      c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
      c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);

      nc -= 16;
    } else {
      // Prepare mask for valid 8-bit elements (depends on nc).
      __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
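      // Row r of the output tile lives in 128-bit lane r of
      // vout0123x0123456789ABCDEF. Shifting the nc-bit mask left by 16 per row
      // while stepping the row pointer back by 16 bytes aims each masked store
      // at lane r while writing through cr.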

      _mm512_mask_storeu_epi8(c0, vmask, vout0123x0123456789ABCDEF);
      vmask = _kshiftli_mask64(vmask, 16);
      _mm512_mask_storeu_epi8(c1 - 16, vmask, vout0123x0123456789ABCDEF);
      vmask = _kshiftli_mask64(vmask, 16);
      _mm512_mask_storeu_epi8(c2 - 32, vmask, vout0123x0123456789ABCDEF);
      vmask = _kshiftli_mask64(vmask, 16);
      _mm512_mask_storeu_epi8(c3 - 48, vmask, vout0123x0123456789ABCDEF);

      nc = 0;
    }
  } while (nc != 0);
}

void xnn_qc8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx(
    size_t mr,
    size_t nc,
    size_t kc,
    size_t ks,
    const int8_t** restrict a,
    const void* restrict w,
    int8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    size_t a_offset,
    const int8_t* zero,
    const union xnn_qc8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= 1);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  kc = round_up_po2(kc, 8);
  int8_t* c0 = c;

  const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
  const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
  const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point);
  const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min);
  do {
    __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
    __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
    __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
    __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
    w = (const void*) ((const int32_t*) w + 16);

    size_t p = ks;
    do {
      const int8_t* restrict a0 = a[0];
      if XNN_UNPREDICTABLE(a0 != zero) {
        a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
      }
      a += 1;

      size_t k = 0;
      while (k < kc) {
        const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
        a0 += 8;

        const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));

        vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
        const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));

        vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
        const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));

        vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
        const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));

        vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));

        w = (const void*) ((const int8_t*) w + 128);
        k += 8 * sizeof(int8_t);
      }
      p -= 1 * sizeof(void*);
    } while (p != 0);
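    // a[] is an indirection buffer of ks row pointers; every pass through the
    // loop above consumes one pointer (and kc bytes behind it), so p counts
    // down in sizeof(void*) steps.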

    const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
    const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));

    __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));

    __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);

    const __m512 vscale0123456789ABCDEF = _mm512_load_ps(w);
    w = (const void*) ((const float*) w + 16);
    const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale0123456789ABCDEF);
    vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);

    vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);

    vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);

    const __m256i vacc0x084C2A6E195D3B7F = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0x084C195D2A6E3B7F), _mm512_extracti32x8_epi32(vacc0x084C195D2A6E3B7F, 1)), voutput_zero_point);

    const __m128i vout0x084C2A6E195D3B7F = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
    __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x084C2A6E195D3B7F, _mm_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
    vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min);

    if (nc >= 16) {
      _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF);

      c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);

      a = (const int8_t**restrict) ((uintptr_t) a - ks);

      nc -= 16;
    } else {
      // Prepare mask for valid 8-bit elements (depends on nc).
      const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));

      _mm_mask_storeu_epi8(c0, vmask, vout0x0123456789ABCDEF);

      nc = 0;
    }
  } while (nc != 0);
}

void xnn_qc8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx(
    size_t mr,
    size_t nc,
    size_t kc,
    size_t ks,
    const int8_t** restrict a,
    const void* restrict w,
    int8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    size_t a_offset,
    const int8_t* zero,
    const union xnn_qc8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= 4);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  kc = round_up_po2(kc, 8);
  int8_t* c0 = c;
  int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
  if XNN_UNPREDICTABLE(mr < 2) {
    c1 = c0;
  }
  int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
  if XNN_UNPREDICTABLE(mr <= 2) {
    c2 = c1;
  }
  int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride);
  if XNN_UNPREDICTABLE(mr != 4) {
    c3 = c2;
  }

  const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
  const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
  const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
  const __m512i voutput_min = _mm512_load_si512(params->fp32_avx512.output_min);
  do {
    __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
    __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
    __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
    __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
    __m512i vacc1x0123 = vacc0x0123;
    __m512i vacc1x4567 = vacc0x4567;
    __m512i vacc1x89AB = vacc0x89AB;
    __m512i vacc1xCDEF = vacc0xCDEF;
    __m512i vacc2x0123 = vacc0x0123;
    __m512i vacc2x4567 = vacc0x4567;
    __m512i vacc2x89AB = vacc0x89AB;
    __m512i vacc2xCDEF = vacc0xCDEF;
    __m512i vacc3x0123 = vacc0x0123;
    __m512i vacc3x4567 = vacc0x4567;
    __m512i vacc3x89AB = vacc0x89AB;
    __m512i vacc3xCDEF = vacc0xCDEF;
    w = (const void*) ((const int32_t*) w + 16);

    size_t p = ks;
    do {
      const int8_t* restrict a0 = a[0];
      if XNN_UNPREDICTABLE(a0 != zero) {
        a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
      }
      const int8_t* restrict a1 = a[1];
      if XNN_UNPREDICTABLE(a1 != zero) {
        a1 = (const int8_t*) ((uintptr_t) a1 + a_offset);
      }
      const int8_t* restrict a2 = a[2];
      if XNN_UNPREDICTABLE(a2 != zero) {
        a2 = (const int8_t*) ((uintptr_t) a2 + a_offset);
      }
      const int8_t* restrict a3 = a[3];
      if XNN_UNPREDICTABLE(a3 != zero) {
        a3 = (const int8_t*) ((uintptr_t) a3 + a_offset);
      }
      a += 4;

      size_t k = 0;
      while (k < kc) {
        const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
        a0 += 8;
        const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
        a1 += 8;
        const __m512i va2 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a2)));
        a2 += 8;
        const __m512i va3 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a3)));
        a3 += 8;

        const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));

        vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
        vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
        vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
        vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
        const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));

        vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
        vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
        vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
        vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
        const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));

        vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
        vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
        vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
        vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
        const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));

        vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
        vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
        vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
        vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));

        w = (const void*) ((const int8_t*) w + 128);
        k += 8 * sizeof(int8_t);
      }
      p -= 4 * sizeof(void*);
    } while (p != 0);

    const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
    const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
    const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
    const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
    const __m512i vacc2x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x0123, vacc2x4567), _mm512_unpackhi_epi32(vacc2x0123, vacc2x4567));
    const __m512i vacc2x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x89AB, vacc2xCDEF), _mm512_unpackhi_epi32(vacc2x89AB, vacc2xCDEF));
    const __m512i vacc3x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x0123, vacc3x4567), _mm512_unpackhi_epi32(vacc3x0123, vacc3x4567));
    const __m512i vacc3x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x89AB, vacc3xCDEF), _mm512_unpackhi_epi32(vacc3x89AB, vacc3xCDEF));

    __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
    __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
    __m512i vacc2x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x04152637, vacc2x8C9DAEBF), _mm512_unpackhi_epi32(vacc2x04152637, vacc2x8C9DAEBF));
    __m512i vacc3x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x04152637, vacc3x8C9DAEBF), _mm512_unpackhi_epi32(vacc3x04152637, vacc3x8C9DAEBF));

    __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
    __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
    __m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
    __m512 vscaled3x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc3x084C195D2A6E3B7F);

    const __m512 vscale0123456789ABCDEF = _mm512_load_ps(w);
    w = (const void*) ((const float*) w + 16);
    const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale0123456789ABCDEF);
    vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
    vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
    vscaled2x084C195D2A6E3B7F = _mm512_mul_ps(vscaled2x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
    vscaled3x084C195D2A6E3B7F = _mm512_mul_ps(vscaled3x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);

    vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
    vscaled1x084C195D2A6E3B7F = _mm512_min_ps(vscaled1x084C195D2A6E3B7F, voutput_max_less_zero_point);
    vscaled2x084C195D2A6E3B7F = _mm512_min_ps(vscaled2x084C195D2A6E3B7F, voutput_max_less_zero_point);
    vscaled3x084C195D2A6E3B7F = _mm512_min_ps(vscaled3x084C195D2A6E3B7F, voutput_max_less_zero_point);

    vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
    vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
    vacc2x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled2x084C195D2A6E3B7F);
    vacc3x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled3x084C195D2A6E3B7F);

    const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
    const __m512i vacc23x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc2x084C195D2A6E3B7F, vacc3x084C195D2A6E3B7F), voutput_zero_point);

    __m512i vout0123x084Cx195Dx2A6Ex3B7F = _mm512_packs_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc23x084Cx195Dx2A6Ex3B7F);
    vout0123x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout0123x084Cx195Dx2A6Ex3B7F);
    __m512i vout0123x0123456789ABCDEF = _mm512_shuffle_epi8(vout0123x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
    vout0123x0123456789ABCDEF = _mm512_max_epi8(vout0123x0123456789ABCDEF, voutput_min);

    if (nc >= 16) {
      _mm_storeu_si128((__m128i*) c3, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 3));
      _mm_storeu_si128((__m128i*) c2, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 2));
      _mm_storeu_si128((__m128i*) c1, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 1));
      _mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout0123x0123456789ABCDEF));

      c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
      c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
      c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
      c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);

      a = (const int8_t**restrict) ((uintptr_t) a - ks);

      nc -= 16;
    } else {
      // Prepare mask for valid 8-bit elements (depends on nc).
      __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT64_C(1) << (nc + 48)) - (UINT64_C(1) << 48)));
1960
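      // The mask initially selects bytes 48..(48+nc-1), i.e. row 3's lane of the
      // packed vector; each 16-bit right shift moves it down to the next row.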
1961 _mm512_mask_storeu_epi8(c3 - 48, vmask, vout0123x0123456789ABCDEF);
1962 vmask = _kshiftri_mask64(vmask, 16);
1963 _mm512_mask_storeu_epi8(c2 - 32, vmask, vout0123x0123456789ABCDEF);
1964 vmask = _kshiftri_mask64(vmask, 16);
1965 _mm512_mask_storeu_epi8(c1 - 16, vmask, vout0123x0123456789ABCDEF);
1966 vmask = _kshiftri_mask64(vmask, 16);
1967 _mm512_mask_storeu_epi8(c0, vmask, vout0123x0123456789ABCDEF);
1968
1969 nc = 0;
1970 }
1971 } while (nc != 0);
1972 }
1973
1974 void xnn_qs8_dwconv_minmax_fp32_ukernel_up32x25__avx512skx_mul32(
1975 size_t channels,
1976 size_t output_width,
1977 const int8_t** input,
1978 const void* weights,
1979 int8_t* output,
1980 size_t input_stride,
1981 size_t output_increment,
1982 size_t input_offset,
1983 const int8_t* zero,
1984 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_MSAN
1985 {
1986 assert(channels != 0);
1987 assert(output_width != 0);
1988
1989 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
1990 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
1991 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
1992 const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_min);
1993 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 5, 1, 6, 2, 4, 0);
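  // Undoes the lane interleaving introduced by the 256-bit packs in the main loop.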
1994
1995 do {
1996 const int8_t* i0 = input[0];
1997 assert(i0 != NULL);
1998 if XNN_UNPREDICTABLE(i0 != zero) {
1999 i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
2000 }
2001 const int8_t* i1 = input[1];
2002 assert(i1 != NULL);
2003 if XNN_UNPREDICTABLE(i1 != zero) {
2004 i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
2005 }
2006 const int8_t* i2 = input[2];
2007 assert(i2 != NULL);
2008 if XNN_UNPREDICTABLE(i2 != zero) {
2009 i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
2010 }
2011 const int8_t* i3 = input[3];
2012 assert(i3 != NULL);
2013 if XNN_UNPREDICTABLE(i3 != zero) {
2014 i3 = (const int8_t*) ((uintptr_t) i3 + input_offset);
2015 }
2016 const int8_t* i4 = input[4];
2017 assert(i4 != NULL);
2018 if XNN_UNPREDICTABLE(i4 != zero) {
2019 i4 = (const int8_t*) ((uintptr_t) i4 + input_offset);
2020 }
2021 const int8_t* i5 = input[5];
2022 assert(i5 != NULL);
2023 if XNN_UNPREDICTABLE(i5 != zero) {
2024 i5 = (const int8_t*) ((uintptr_t) i5 + input_offset);
2025 }
2026 const int8_t* i6 = input[6];
2027 assert(i6 != NULL);
2028 if XNN_UNPREDICTABLE(i6 != zero) {
2029 i6 = (const int8_t*) ((uintptr_t) i6 + input_offset);
2030 }
2031 const int8_t* i7 = input[7];
2032 assert(i7 != NULL);
2033 if XNN_UNPREDICTABLE(i7 != zero) {
2034 i7 = (const int8_t*) ((uintptr_t) i7 + input_offset);
2035 }
2036 const int8_t* i8 = input[8];
2037 assert(i8 != NULL);
2038 if XNN_UNPREDICTABLE(i8 != zero) {
2039 i8 = (const int8_t*) ((uintptr_t) i8 + input_offset);
2040 }
2041 const int8_t* i9 = input[9];
2042 assert(i9 != NULL);
2043 if XNN_UNPREDICTABLE(i9 != zero) {
2044 i9 = (const int8_t*) ((uintptr_t) i9 + input_offset);
2045 }
2046 const int8_t* i10 = input[10];
2047 assert(i10 != NULL);
2048 if XNN_UNPREDICTABLE(i10 != zero) {
2049 i10 = (const int8_t*) ((uintptr_t) i10 + input_offset);
2050 }
2051 const int8_t* i11 = input[11];
2052 assert(i11 != NULL);
2053 if XNN_UNPREDICTABLE(i11 != zero) {
2054 i11 = (const int8_t*) ((uintptr_t) i11 + input_offset);
2055 }
2056 const int8_t* i12 = input[12];
2057 assert(i12 != NULL);
2058 if XNN_UNPREDICTABLE(i12 != zero) {
2059 i12 = (const int8_t*) ((uintptr_t) i12 + input_offset);
2060 }
2061 const int8_t* i13 = input[13];
2062 assert(i13 != NULL);
2063 if XNN_UNPREDICTABLE(i13 != zero) {
2064 i13 = (const int8_t*) ((uintptr_t) i13 + input_offset);
2065 }
2066 const int8_t* i14 = input[14];
2067 assert(i14 != NULL);
2068 if XNN_UNPREDICTABLE(i14 != zero) {
2069 i14 = (const int8_t*) ((uintptr_t) i14 + input_offset);
2070 }
2071 const int8_t* i15 = input[15];
2072 assert(i15 != NULL);
2073 if XNN_UNPREDICTABLE(i15 != zero) {
2074 i15 = (const int8_t*) ((uintptr_t) i15 + input_offset);
2075 }
2076 const int8_t* i16 = input[16];
2077 assert(i16 != NULL);
2078 if XNN_UNPREDICTABLE(i16 != zero) {
2079 i16 = (const int8_t*) ((uintptr_t) i16 + input_offset);
2080 }
2081 const int8_t* i17 = input[17];
2082 assert(i17 != NULL);
2083 if XNN_UNPREDICTABLE(i17 != zero) {
2084 i17 = (const int8_t*) ((uintptr_t) i17 + input_offset);
2085 }
2086 const int8_t* i18 = input[18];
2087 assert(i18 != NULL);
2088 if XNN_UNPREDICTABLE(i18 != zero) {
2089 i18 = (const int8_t*) ((uintptr_t) i18 + input_offset);
2090 }
2091 const int8_t* i19 = input[19];
2092 assert(i19 != NULL);
2093 if XNN_UNPREDICTABLE(i19 != zero) {
2094 i19 = (const int8_t*) ((uintptr_t) i19 + input_offset);
2095 }
2096 const int8_t* i20 = input[20];
2097 assert(i20 != NULL);
2098 if XNN_UNPREDICTABLE(i20 != zero) {
2099 i20 = (const int8_t*) ((uintptr_t) i20 + input_offset);
2100 }
2101 const int8_t* i21 = input[21];
2102 assert(i21 != NULL);
2103 if XNN_UNPREDICTABLE(i21 != zero) {
2104 i21 = (const int8_t*) ((uintptr_t) i21 + input_offset);
2105 }
2106 const int8_t* i22 = input[22];
2107 assert(i22 != NULL);
2108 if XNN_UNPREDICTABLE(i22 != zero) {
2109 i22 = (const int8_t*) ((uintptr_t) i22 + input_offset);
2110 }
2111 const int8_t* i23 = input[23];
2112 assert(i23 != NULL);
2113 if XNN_UNPREDICTABLE(i23 != zero) {
2114 i23 = (const int8_t*) ((uintptr_t) i23 + input_offset);
2115 }
2116 const int8_t* i24 = input[24];
2117 assert(i24 != NULL);
2118 if XNN_UNPREDICTABLE(i24 != zero) {
2119 i24 = (const int8_t*) ((uintptr_t) i24 + input_offset);
2120 }
2121 input = (const int8_t**) ((uintptr_t) input + input_stride);
2122
2123 size_t c = channels;
2124 const void* w = weights;
2125 for (; c >= 32; c -= 32) {
2126 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
2127 __m512i vaccGHIJKLMNOPQRSTUV = _mm512_loadu_si512((const void*) ((uintptr_t) w + 16 * sizeof(int32_t)));
2128
2129
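      // 25 taps, fully unrolled: for each tap, sign-extend 16 int8 inputs and 16
      // int8 weights to int32, multiply, and accumulate. The int8 tap weights
      // follow the 32 int32 bias values in the packed weights w.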
2130 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
2131 const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 0 * sizeof(int8_t))));
2132 const __m512i vi0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i0 + 16)));
2133 const __m512i vk0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 16 * sizeof(int8_t))));
2134 i0 += 32;
2135
2136 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
2137 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi0xGHIJKLMNOPQRSTUV, vk0xGHIJKLMNOPQRSTUV));
2138
2139 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
2140 const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 32 * sizeof(int8_t))));
2141 const __m512i vi1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i1 + 16)));
2142 const __m512i vk1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 48 * sizeof(int8_t))));
2143 i1 += 32;
2144
2145 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
2146 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi1xGHIJKLMNOPQRSTUV, vk1xGHIJKLMNOPQRSTUV));
2147
2148 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
2149 const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 64 * sizeof(int8_t))));
2150 const __m512i vi2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i2 + 16)));
2151 const __m512i vk2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 80 * sizeof(int8_t))));
2152 i2 += 32;
2153
2154 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
2155 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi2xGHIJKLMNOPQRSTUV, vk2xGHIJKLMNOPQRSTUV));
2156
2157 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i3));
2158 const __m512i vk3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 96 * sizeof(int8_t))));
2159 const __m512i vi3xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i3 + 16)));
2160 const __m512i vk3xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 112 * sizeof(int8_t))));
2161 i3 += 32;
2162
2163 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
2164 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi3xGHIJKLMNOPQRSTUV, vk3xGHIJKLMNOPQRSTUV));
2165
2166 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i4));
2167 const __m512i vk4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 128 * sizeof(int8_t))));
2168 const __m512i vi4xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i4 + 16)));
2169 const __m512i vk4xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 144 * sizeof(int8_t))));
2170 i4 += 32;
2171
2172 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
2173 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi4xGHIJKLMNOPQRSTUV, vk4xGHIJKLMNOPQRSTUV));
2174
2175 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i5));
2176 const __m512i vk5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 160 * sizeof(int8_t))));
2177 const __m512i vi5xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i5 + 16)));
2178 const __m512i vk5xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 176 * sizeof(int8_t))));
2179 i5 += 32;
2180
2181 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
2182 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi5xGHIJKLMNOPQRSTUV, vk5xGHIJKLMNOPQRSTUV));
2183
2184 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i6));
2185 const __m512i vk6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 192 * sizeof(int8_t))));
2186 const __m512i vi6xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i6 + 16)));
2187 const __m512i vk6xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 208 * sizeof(int8_t))));
2188 i6 += 32;
2189
2190 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
2191 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi6xGHIJKLMNOPQRSTUV, vk6xGHIJKLMNOPQRSTUV));
2192
2193 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i7));
2194 const __m512i vk7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 224 * sizeof(int8_t))));
2195 const __m512i vi7xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i7 + 16)));
2196 const __m512i vk7xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 240 * sizeof(int8_t))));
2197 i7 += 32;
2198
2199 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
2200 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi7xGHIJKLMNOPQRSTUV, vk7xGHIJKLMNOPQRSTUV));
2201
2202 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i8));
2203 const __m512i vk8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 256 * sizeof(int8_t))));
2204 const __m512i vi8xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i8 + 16)));
2205 const __m512i vk8xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 272 * sizeof(int8_t))));
2206 i8 += 32;
2207
2208 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
2209 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi8xGHIJKLMNOPQRSTUV, vk8xGHIJKLMNOPQRSTUV));
2210
2211 const __m512i vi9x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i9));
2212 const __m512i vk9x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 288 * sizeof(int8_t))));
2213 const __m512i vi9xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i9 + 16)));
2214 const __m512i vk9xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 304 * sizeof(int8_t))));
2215 i9 += 32;
2216
2217 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF));
2218 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi9xGHIJKLMNOPQRSTUV, vk9xGHIJKLMNOPQRSTUV));
2219
2220 const __m512i vi10x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i10));
2221 const __m512i vk10x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 320 * sizeof(int8_t))));
2222 const __m512i vi10xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i10 + 16)));
2223 const __m512i vk10xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 336 * sizeof(int8_t))));
2224 i10 += 32;
2225
2226 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF));
2227 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi10xGHIJKLMNOPQRSTUV, vk10xGHIJKLMNOPQRSTUV));
2228
2229 const __m512i vi11x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i11));
2230 const __m512i vk11x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 352 * sizeof(int8_t))));
2231 const __m512i vi11xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i11 + 16)));
2232 const __m512i vk11xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 368 * sizeof(int8_t))));
2233 i11 += 32;
2234
2235 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF));
2236 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi11xGHIJKLMNOPQRSTUV, vk11xGHIJKLMNOPQRSTUV));
2237
2238 const __m512i vi12x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i12));
2239 const __m512i vk12x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 384 * sizeof(int8_t))));
2240 const __m512i vi12xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i12 + 16)));
2241 const __m512i vk12xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 400 * sizeof(int8_t))));
2242 i12 += 32;
2243
2244 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF));
2245 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi12xGHIJKLMNOPQRSTUV, vk12xGHIJKLMNOPQRSTUV));
2246
2247 const __m512i vi13x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i13));
2248 const __m512i vk13x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 416 * sizeof(int8_t))));
2249 const __m512i vi13xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i13 + 16)));
2250 const __m512i vk13xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 432 * sizeof(int8_t))));
2251 i13 += 32;
2252
2253 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF));
2254 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi13xGHIJKLMNOPQRSTUV, vk13xGHIJKLMNOPQRSTUV));
2255
2256 const __m512i vi14x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i14));
2257 const __m512i vk14x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 448 * sizeof(int8_t))));
2258 const __m512i vi14xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i14 + 16)));
2259 const __m512i vk14xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 464 * sizeof(int8_t))));
2260 i14 += 32;
2261
2262 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF));
2263 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi14xGHIJKLMNOPQRSTUV, vk14xGHIJKLMNOPQRSTUV));
2264
2265 const __m512i vi15x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i15));
2266 const __m512i vk15x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 480 * sizeof(int8_t))));
2267 const __m512i vi15xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i15 + 16)));
2268 const __m512i vk15xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 496 * sizeof(int8_t))));
2269 i15 += 32;
2270
2271 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF));
2272 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi15xGHIJKLMNOPQRSTUV, vk15xGHIJKLMNOPQRSTUV));
2273
2274 const __m512i vi16x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i16));
2275 const __m512i vk16x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 512 * sizeof(int8_t))));
2276 const __m512i vi16xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i16 + 16)));
2277 const __m512i vk16xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 528 * sizeof(int8_t))));
2278 i16 += 32;
2279
2280 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF));
2281 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi16xGHIJKLMNOPQRSTUV, vk16xGHIJKLMNOPQRSTUV));
2282
2283 const __m512i vi17x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i17));
2284 const __m512i vk17x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 544 * sizeof(int8_t))));
2285 const __m512i vi17xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i17 + 16)));
2286 const __m512i vk17xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 560 * sizeof(int8_t))));
2287 i17 += 32;
2288
2289 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF));
2290 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi17xGHIJKLMNOPQRSTUV, vk17xGHIJKLMNOPQRSTUV));
2291
2292 const __m512i vi18x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i18));
2293 const __m512i vk18x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 576 * sizeof(int8_t))));
2294 const __m512i vi18xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i18 + 16)));
2295 const __m512i vk18xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 592 * sizeof(int8_t))));
2296 i18 += 32;
2297
2298 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF));
2299 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi18xGHIJKLMNOPQRSTUV, vk18xGHIJKLMNOPQRSTUV));
2300
2301 const __m512i vi19x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i19));
2302 const __m512i vk19x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 608 * sizeof(int8_t))));
2303 const __m512i vi19xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i19 + 16)));
2304 const __m512i vk19xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 624 * sizeof(int8_t))));
2305 i19 += 32;
2306
2307 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF));
2308 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi19xGHIJKLMNOPQRSTUV, vk19xGHIJKLMNOPQRSTUV));
2309
2310 const __m512i vi20x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i20));
2311 const __m512i vk20x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 640 * sizeof(int8_t))));
2312 const __m512i vi20xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i20 + 16)));
2313 const __m512i vk20xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 656 * sizeof(int8_t))));
2314 i20 += 32;
2315
2316 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF));
2317 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi20xGHIJKLMNOPQRSTUV, vk20xGHIJKLMNOPQRSTUV));
2318
2319 const __m512i vi21x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i21));
2320 const __m512i vk21x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 672 * sizeof(int8_t))));
2321 const __m512i vi21xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i21 + 16)));
2322 const __m512i vk21xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 688 * sizeof(int8_t))));
2323 i21 += 32;
2324
2325 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF));
2326 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi21xGHIJKLMNOPQRSTUV, vk21xGHIJKLMNOPQRSTUV));
2327
2328 const __m512i vi22x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i22));
2329 const __m512i vk22x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 704 * sizeof(int8_t))));
2330 const __m512i vi22xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i22 + 16)));
2331 const __m512i vk22xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 720 * sizeof(int8_t))));
2332 i22 += 32;
2333
2334 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF));
2335 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi22xGHIJKLMNOPQRSTUV, vk22xGHIJKLMNOPQRSTUV));
2336
2337 const __m512i vi23x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i23));
2338 const __m512i vk23x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 736 * sizeof(int8_t))));
2339 const __m512i vi23xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i23 + 16)));
2340 const __m512i vk23xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 752 * sizeof(int8_t))));
2341 i23 += 32;
2342
2343 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF));
2344 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi23xGHIJKLMNOPQRSTUV, vk23xGHIJKLMNOPQRSTUV));
2345
2346 const __m512i vi24x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i24));
2347 const __m512i vk24x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 768 * sizeof(int8_t))));
2348 const __m512i vi24xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i24 + 16)));
2349 const __m512i vk24xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 784 * sizeof(int8_t))));
2350 i24 += 32;
2351
2352 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF));
2353 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi24xGHIJKLMNOPQRSTUV, vk24xGHIJKLMNOPQRSTUV));
2354
2355 w = (const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 800 * sizeof(int8_t));
2356
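      // Requantize in fp32: scale, clamp against the upper output bound, and
      // convert back to int32 with rounding; the lower bound is applied after
      // packing, via _mm256_max_epi8.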
2357 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
2358 __m512 vscaledGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vaccGHIJKLMNOPQRSTUV);
2359
2360 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale);
2361 vscaledGHIJKLMNOPQRSTUV = _mm512_mul_ps(vscaledGHIJKLMNOPQRSTUV, vscale);
2362
2363 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
2364 vscaledGHIJKLMNOPQRSTUV = _mm512_min_ps(vscaledGHIJKLMNOPQRSTUV, voutput_max_less_zero_point);
2365
2366 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
2367 vaccGHIJKLMNOPQRSTUV = _mm512_cvtps_epi32(vscaledGHIJKLMNOPQRSTUV);
2368
2369 __m512i vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV = _mm512_adds_epi16(_mm512_packs_epi32(vacc0123456789ABCDEF, vaccGHIJKLMNOPQRSTUV), voutput_zero_point);
2370 __m256i voutGHIJOPQRKLMNSTUV = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vaccGHIJKLMNOPQRSTUV), _mm512_extracti32x8_epi32(vaccGHIJKLMNOPQRSTUV, 1)), _mm512_castsi512_si256(voutput_zero_point));
2371
2372 const __m256i vout0123GHIJ4567KLMN = _mm512_castsi512_si256(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV);
2373 const __m256i vout89ABOPQRCDEFSTUV = _mm512_extracti32x8_epi32(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV, 1);
2374 const __m256i vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV = _mm256_packs_epi16(vout0123GHIJ4567KLMN, vout89ABOPQRCDEFSTUV);
2375 __m256i vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_permutevar8x32_epi32(vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV, vpermute_mask);
2376 const __m128i voutGHIJOPQR = _mm256_castsi256_si128(voutGHIJOPQRKLMNSTUV);
2377 const __m128i voutKLMNSTUV = _mm256_extracti128_si256(voutGHIJOPQRKLMNSTUV, 1);
2378 __m128i voutGHIJKLMNOPQRSTUV = _mm_shuffle_epi32(_mm_packs_epi16(voutGHIJOPQR, voutKLMNSTUV), _MM_SHUFFLE(3, 1, 2, 0));
2379
2380 vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_max_epi8(vout0123456789ABCDEFGHIJKLMNOPQRSTUV, voutput_min);
2381 voutGHIJKLMNOPQRSTUV = _mm_max_epi8(voutGHIJKLMNOPQRSTUV, _mm256_castsi256_si128(voutput_min));
2382
2383 _mm256_storeu_si256((__m256i*) output, vout0123456789ABCDEFGHIJKLMNOPQRSTUV);
2384 _mm_storeu_si128((__m128i*) (output + 16), voutGHIJKLMNOPQRSTUV);
2385 output += 32;
2386 }
2387 if XNN_UNLIKELY(c != 0) {
2388       // Prepare mask for valid 8-bit elements (depends on c).
2389 const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << (c & 15)) - UINT32_C(1)));
2390 const int8_t* k = (const int8_t*) ((uintptr_t) w + 32 * sizeof(int32_t));
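      // Remainder loop, 16 channels at a time. k points at the int8 tap weights;
      // consecutive taps for this group are 32 bytes apart because the weights
      // were packed for 32-channel tiles.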
2391 do {
2392 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
2393
2394
2395 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
2396 const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) k));
2397 i0 += 16;
2398
2399 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
2400
2401 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
2402 const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 32)));
2403 i1 += 16;
2404
2405 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
2406
2407 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
2408 const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 64)));
2409 i2 += 16;
2410
2411 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
2412
2413 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i3));
2414 const __m512i vk3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 96)));
2415 i3 += 16;
2416
2417 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
2418
2419 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i4));
2420 const __m512i vk4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 128)));
2421 i4 += 16;
2422
2423 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
2424
2425 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i5));
2426 const __m512i vk5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 160)));
2427 i5 += 16;
2428
2429 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
2430
2431 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i6));
2432 const __m512i vk6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 192)));
2433 i6 += 16;
2434
2435 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
2436
2437 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i7));
2438 const __m512i vk7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 224)));
2439 i7 += 16;
2440
2441 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
2442
2443 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i8));
2444 const __m512i vk8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 256)));
2445 i8 += 16;
2446
2447 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
2448
2449 const __m512i vi9x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i9));
2450 const __m512i vk9x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 288)));
2451 i9 += 16;
2452
2453 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF));
2454
2455 const __m512i vi10x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i10));
2456 const __m512i vk10x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 320)));
2457 i10 += 16;
2458
2459 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF));
2460
2461 const __m512i vi11x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i11));
2462 const __m512i vk11x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 352)));
2463 i11 += 16;
2464
2465 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF));
2466
2467 const __m512i vi12x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i12));
2468 const __m512i vk12x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 384)));
2469 i12 += 16;
2470
2471 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF));
2472
2473 const __m512i vi13x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i13));
2474 const __m512i vk13x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 416)));
2475 i13 += 16;
2476
2477 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF));
2478
2479 const __m512i vi14x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i14));
2480 const __m512i vk14x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 448)));
2481 i14 += 16;
2482
2483 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF));
2484
2485 const __m512i vi15x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i15));
2486 const __m512i vk15x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 480)));
2487 i15 += 16;
2488
2489 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF));
2490
2491 const __m512i vi16x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i16));
2492 const __m512i vk16x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 512)));
2493 i16 += 16;
2494
2495 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF));
2496
2497 const __m512i vi17x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i17));
2498 const __m512i vk17x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 544)));
2499 i17 += 16;
2500
2501 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF));
2502
2503 const __m512i vi18x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i18));
2504 const __m512i vk18x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 576)));
2505 i18 += 16;
2506
2507 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF));
2508
2509 const __m512i vi19x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i19));
2510 const __m512i vk19x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 608)));
2511 i19 += 16;
2512
2513 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF));
2514
2515 const __m512i vi20x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i20));
2516 const __m512i vk20x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 640)));
2517 i20 += 16;
2518
2519 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF));
2520
2521 const __m512i vi21x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i21));
2522 const __m512i vk21x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 672)));
2523 i21 += 16;
2524
2525 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF));
2526
2527 const __m512i vi22x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i22));
2528 const __m512i vk22x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 704)));
2529 i22 += 16;
2530
2531 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF));
2532
2533 const __m512i vi23x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i23));
2534 const __m512i vk23x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 736)));
2535 i23 += 16;
2536
2537 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF));
2538
2539 const __m512i vi24x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i24));
2540 const __m512i vk24x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 768)));
2541 i24 += 16;
2542
2543 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF));
2544
2545 k += 16;
2546
2547 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
2548 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale);
2549 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
2550 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
2551
2552 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
2553
2554 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), _mm512_castsi512_si256(voutput_zero_point));
2555
2556 const __m128i vout012389AB = _mm256_castsi256_si128(vout012389AB4567CDEF);
2557 const __m128i vout4567CDEF = _mm256_extracti128_si256(vout012389AB4567CDEF, 1);
2558 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(vout012389AB, vout4567CDEF), _MM_SHUFFLE(3, 1, 2, 0));
2559 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, _mm256_castsi256_si128(voutput_min));
2560
2561 if XNN_LIKELY(c >= 16) {
2562 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
2563 output += 16;
2564 c -= 16;
2565 } else {
2566 _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
2567 output = (int8_t*) ((uintptr_t) output + c);
2568 c = 0;
2569 }
2570 } while (c != 0);
2571 }
2572
2573 output = (int8_t*) ((uintptr_t) output + output_increment);
2574 } while (--output_width != 0);
2575 }
2576
2577 void xnn_qs8_dwconv_minmax_fp32_ukernel_up32x9__avx512skx_mul32(
2578 size_t channels,
2579 size_t output_width,
2580 const int8_t** input,
2581 const void* weights,
2582 int8_t* output,
2583 size_t input_stride,
2584 size_t output_increment,
2585 size_t input_offset,
2586 const int8_t* zero,
2587 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_MSAN
2588 {
2589 assert(channels != 0);
2590 assert(output_width != 0);
2591
2592 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
2593 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
2594 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
2595 const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_min);
2596 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 5, 1, 6, 2, 4, 0);
2597
2598 do {
2599 const int8_t* i0 = input[0];
2600 assert(i0 != NULL);
2601 if XNN_UNPREDICTABLE(i0 != zero) {
2602 i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
2603 }
2604 const int8_t* i1 = input[1];
2605 assert(i1 != NULL);
2606 if XNN_UNPREDICTABLE(i1 != zero) {
2607 i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
2608 }
2609 const int8_t* i2 = input[2];
2610 assert(i2 != NULL);
2611 if XNN_UNPREDICTABLE(i2 != zero) {
2612 i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
2613 }
2614 const int8_t* i3 = input[3];
2615 assert(i3 != NULL);
2616 if XNN_UNPREDICTABLE(i3 != zero) {
2617 i3 = (const int8_t*) ((uintptr_t) i3 + input_offset);
2618 }
2619 const int8_t* i4 = input[4];
2620 assert(i4 != NULL);
2621 if XNN_UNPREDICTABLE(i4 != zero) {
2622 i4 = (const int8_t*) ((uintptr_t) i4 + input_offset);
2623 }
2624 const int8_t* i5 = input[5];
2625 assert(i5 != NULL);
2626 if XNN_UNPREDICTABLE(i5 != zero) {
2627 i5 = (const int8_t*) ((uintptr_t) i5 + input_offset);
2628 }
2629 const int8_t* i6 = input[6];
2630 assert(i6 != NULL);
2631 if XNN_UNPREDICTABLE(i6 != zero) {
2632 i6 = (const int8_t*) ((uintptr_t) i6 + input_offset);
2633 }
2634 const int8_t* i7 = input[7];
2635 assert(i7 != NULL);
2636 if XNN_UNPREDICTABLE(i7 != zero) {
2637 i7 = (const int8_t*) ((uintptr_t) i7 + input_offset);
2638 }
2639 const int8_t* i8 = input[8];
2640 assert(i8 != NULL);
2641 if XNN_UNPREDICTABLE(i8 != zero) {
2642 i8 = (const int8_t*) ((uintptr_t) i8 + input_offset);
2643 }
2644 input = (const int8_t**) ((uintptr_t) input + input_stride);
2645
2646 size_t c = channels;
2647 const void* w = weights;
2648 for (; c >= 32; c -= 32) {
2649 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
2650 __m512i vaccGHIJKLMNOPQRSTUV = _mm512_loadu_si512((const void*) ((uintptr_t) w + 16 * sizeof(int32_t)));
2651
2652
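      // Same per-tap multiply-accumulate pattern as the 25-tap kernel above,
      // unrolled here for 9 taps.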
2653 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
2654 const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 0 * sizeof(int8_t))));
2655 const __m512i vi0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i0 + 16)));
2656 const __m512i vk0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 16 * sizeof(int8_t))));
2657 i0 += 32;
2658
2659 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
2660 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi0xGHIJKLMNOPQRSTUV, vk0xGHIJKLMNOPQRSTUV));
2661
2662 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
2663 const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 32 * sizeof(int8_t))));
2664 const __m512i vi1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i1 + 16)));
2665 const __m512i vk1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 48 * sizeof(int8_t))));
2666 i1 += 32;
2667
2668 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
2669 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi1xGHIJKLMNOPQRSTUV, vk1xGHIJKLMNOPQRSTUV));
2670
2671 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
2672 const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 64 * sizeof(int8_t))));
2673 const __m512i vi2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i2 + 16)));
2674 const __m512i vk2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 80 * sizeof(int8_t))));
2675 i2 += 32;
2676
2677 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
2678 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi2xGHIJKLMNOPQRSTUV, vk2xGHIJKLMNOPQRSTUV));
2679
2680 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i3));
2681 const __m512i vk3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 96 * sizeof(int8_t))));
2682 const __m512i vi3xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i3 + 16)));
2683 const __m512i vk3xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 112 * sizeof(int8_t))));
2684 i3 += 32;
2685
2686 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
2687 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi3xGHIJKLMNOPQRSTUV, vk3xGHIJKLMNOPQRSTUV));
2688
2689 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i4));
2690 const __m512i vk4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 128 * sizeof(int8_t))));
2691 const __m512i vi4xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i4 + 16)));
2692 const __m512i vk4xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 144 * sizeof(int8_t))));
2693 i4 += 32;
2694
2695 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
2696 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi4xGHIJKLMNOPQRSTUV, vk4xGHIJKLMNOPQRSTUV));
2697
2698 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i5));
2699 const __m512i vk5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 160 * sizeof(int8_t))));
2700 const __m512i vi5xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i5 + 16)));
2701 const __m512i vk5xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 176 * sizeof(int8_t))));
2702 i5 += 32;
2703
2704 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
2705 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi5xGHIJKLMNOPQRSTUV, vk5xGHIJKLMNOPQRSTUV));
2706
2707 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i6));
2708 const __m512i vk6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 192 * sizeof(int8_t))));
2709 const __m512i vi6xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i6 + 16)));
2710 const __m512i vk6xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 208 * sizeof(int8_t))));
2711 i6 += 32;
2712
2713 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
2714 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi6xGHIJKLMNOPQRSTUV, vk6xGHIJKLMNOPQRSTUV));
2715
2716 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i7));
2717 const __m512i vk7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 224 * sizeof(int8_t))));
2718 const __m512i vi7xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i7 + 16)));
2719 const __m512i vk7xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 240 * sizeof(int8_t))));
2720 i7 += 32;
2721
2722 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
2723 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi7xGHIJKLMNOPQRSTUV, vk7xGHIJKLMNOPQRSTUV));
2724
2725 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i8));
2726 const __m512i vk8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 256 * sizeof(int8_t))));
2727 const __m512i vi8xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i8 + 16)));
2728 const __m512i vk8xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 272 * sizeof(int8_t))));
2729 i8 += 32;
2730
2731 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
2732 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi8xGHIJKLMNOPQRSTUV, vk8xGHIJKLMNOPQRSTUV));
2733
2734 w = (const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 288 * sizeof(int8_t));
2735
2736 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
2737 __m512 vscaledGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vaccGHIJKLMNOPQRSTUV);
2738
2739 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale);
2740 vscaledGHIJKLMNOPQRSTUV = _mm512_mul_ps(vscaledGHIJKLMNOPQRSTUV, vscale);
2741
2742 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
2743 vscaledGHIJKLMNOPQRSTUV = _mm512_min_ps(vscaledGHIJKLMNOPQRSTUV, voutput_max_less_zero_point);
2744
2745 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
2746 vaccGHIJKLMNOPQRSTUV = _mm512_cvtps_epi32(vscaledGHIJKLMNOPQRSTUV);
2747
2748 __m512i vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV = _mm512_adds_epi16(_mm512_packs_epi32(vacc0123456789ABCDEF, vaccGHIJKLMNOPQRSTUV), voutput_zero_point);
2749 __m256i voutGHIJOPQRKLMNSTUV = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vaccGHIJKLMNOPQRSTUV), _mm512_extracti32x8_epi32(vaccGHIJKLMNOPQRSTUV, 1)), _mm512_castsi512_si256(voutput_zero_point));
2750
2751 const __m256i vout0123GHIJ4567KLMN = _mm512_castsi512_si256(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV);
2752 const __m256i vout89ABOPQRCDEFSTUV = _mm512_extracti32x8_epi32(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV, 1);
2753 const __m256i vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV = _mm256_packs_epi16(vout0123GHIJ4567KLMN, vout89ABOPQRCDEFSTUV);
2754 __m256i vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_permutevar8x32_epi32(vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV, vpermute_mask);
2755 const __m128i voutGHIJOPQR = _mm256_castsi256_si128(voutGHIJOPQRKLMNSTUV);
2756 const __m128i voutKLMNSTUV = _mm256_extracti128_si256(voutGHIJOPQRKLMNSTUV, 1);
2757 __m128i voutGHIJKLMNOPQRSTUV = _mm_shuffle_epi32(_mm_packs_epi16(voutGHIJOPQR, voutKLMNSTUV), _MM_SHUFFLE(3, 1, 2, 0));
2758
2759 vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_max_epi8(vout0123456789ABCDEFGHIJKLMNOPQRSTUV, voutput_min);
2760 voutGHIJKLMNOPQRSTUV = _mm_max_epi8(voutGHIJKLMNOPQRSTUV, _mm256_castsi256_si128(voutput_min));
2761
2762 _mm256_storeu_si256((__m256i*) output, vout0123456789ABCDEFGHIJKLMNOPQRSTUV);
2763 _mm_storeu_si128((__m128i*) (output + 16), voutGHIJKLMNOPQRSTUV);
2764 output += 32;
2765 }
2766 if XNN_UNLIKELY(c != 0) {
2767       // Prepare mask for valid 8-bit elements (depends on c).
2768 const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << (c & 15)) - UINT32_C(1)));
2769 const int8_t* k = (const int8_t*) ((uintptr_t) w + 32 * sizeof(int32_t));
2770 do {
2771 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
2772
2773
2774 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
2775 const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) k));
2776 i0 += 16;
2777
2778 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
2779
2780 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
2781 const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 32)));
2782 i1 += 16;
2783
2784 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
2785
2786 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
2787 const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 64)));
2788 i2 += 16;
2789
2790 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
2791
2792 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i3));
2793 const __m512i vk3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 96)));
2794 i3 += 16;
2795
2796 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
2797
2798 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i4));
2799 const __m512i vk4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 128)));
2800 i4 += 16;
2801
2802 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
2803
2804 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i5));
2805 const __m512i vk5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 160)));
2806 i5 += 16;
2807
2808 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
2809
2810 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i6));
2811 const __m512i vk6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 192)));
2812 i6 += 16;
2813
2814 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
2815
2816 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i7));
2817 const __m512i vk7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 224)));
2818 i7 += 16;
2819
2820 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
2821
2822 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i8));
2823 const __m512i vk8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 256)));
2824 i8 += 16;
2825
2826 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
2827
2828 k += 16;
2829
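        // Requantization: widen to float, apply the per-tensor scale, clamp to
        // the quantized output max (the output min is applied later, after the
        // pack to int8), then round back to int32 with _mm512_cvtps_epi32.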
2830 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
2831 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale);
2832 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
2833 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
2834
2835 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
2836
2837 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), _mm512_castsi512_si256(voutput_zero_point));
2838
2839 const __m128i vout012389AB = _mm256_castsi256_si128(vout012389AB4567CDEF);
2840 const __m128i vout4567CDEF = _mm256_extracti128_si256(vout012389AB4567CDEF, 1);
2841 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(vout012389AB, vout4567CDEF), _MM_SHUFFLE(3, 1, 2, 0));
2842 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, _mm256_castsi256_si128(voutput_min));
2843
2844 if XNN_LIKELY(c >= 16) {
2845 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
2846 output += 16;
2847 c -= 16;
2848 } else {
2849 _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
2850 output = (int8_t*) ((uintptr_t) output + c);
2851 c = 0;
2852 }
2853 } while (c != 0);
2854 }
2855
2856 output = (int8_t*) ((uintptr_t) output + output_increment);
2857 } while (--output_width != 0);
2858 }
2859
2860 void xnn_qs8_f32_vcvt_ukernel__avx512skx_x32(
2861 size_t n,
2862 const int8_t* x,
2863 float* y,
2864 const union xnn_qs8_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
2865 {
2866 assert(n != 0);
2867 assert(n % sizeof(int8_t) == 0);
2868 assert(x != NULL);
2869 assert(y != NULL);
2870
2871 const __m512i vminus_zero_point = _mm512_load_si512(params->avx512.minus_zero_point);
2872 const __m512 vscale = _mm512_load_ps(params->avx512.scale);
2873 for (; n >= 32 * sizeof(int8_t); n -= 32 * sizeof(int8_t)) {
2874 __m512i vx0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) x));
2875 __m512i vxGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (x + 16)));
2876 x += 32;
2877
2878 vx0123456789ABCDEF = _mm512_add_epi32(vx0123456789ABCDEF, vminus_zero_point);
2879 vxGHIJKLMNOPQRSTUV = _mm512_add_epi32(vxGHIJKLMNOPQRSTUV, vminus_zero_point);
2880
2881 __m512 vy0123456789ABCDEF = _mm512_cvtepi32_ps(vx0123456789ABCDEF);
2882 __m512 vyGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vxGHIJKLMNOPQRSTUV);
2883
2884 vy0123456789ABCDEF = _mm512_mul_ps(vy0123456789ABCDEF, vscale);
2885 vyGHIJKLMNOPQRSTUV = _mm512_mul_ps(vyGHIJKLMNOPQRSTUV, vscale);
2886
2887 _mm512_storeu_ps(y, vy0123456789ABCDEF);
2888 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
2889 y += 32;
2890 }
2891 for (; n >= 16 * sizeof(int8_t); n -= 16 * sizeof(int8_t)) {
2892 __m512i vx = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) x));
2893 vx = _mm512_add_epi32(vx, vminus_zero_point);
2894 x += 16;
2895
2896 __m512 vy = _mm512_cvtepi32_ps(vx);
2897 vy = _mm512_mul_ps(vy, vscale);
2898
2899 _mm512_storeu_ps(y, vy);
2900 y += 16;
2901 }
2902 if XNN_UNLIKELY(n != 0) {
2903 assert(n >= 1 * sizeof(int8_t));
2904 assert(n <= 15 * sizeof(int8_t));
2905
2906 // Prepare mask for valid elements (depends on n).
2907 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
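    // e.g. n == 3 gives vmask == 0x0007, so only x[0..2] are loaded and only
    // y[0..2] are written.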
2908
2909 __m512i vx = _mm512_cvtepi8_epi32(_mm_maskz_loadu_epi8(vmask, x));
2910 vx = _mm512_add_epi32(vx, vminus_zero_point);
2911
2912 __m512 vy = _mm512_cvtepi32_ps(vx);
2913 vy = _mm512_mul_ps(vy, vscale);
2914
2915 _mm512_mask_storeu_ps(y, vmask, vy);
2916 }
2917 }
2918
2919 void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx(
2920 size_t mr,
2921 size_t nc,
2922 size_t kc,
2923 const int8_t* restrict a,
2924 size_t a_stride,
2925 const void* restrict w,
2926 int8_t* restrict c,
2927 size_t cm_stride,
2928 size_t cn_stride,
2929 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
2930 {
2931 assert(mr != 0);
2932 assert(mr <= 1);
2933 assert(nc != 0);
2934 assert(kc != 0);
2935 assert(kc % sizeof(int8_t) == 0);
2936 assert(a != NULL);
2937 assert(w != NULL);
2938 assert(c != NULL);
2939
2940 kc = round_up_po2(kc, 8);
2941 const int8_t* a0 = a;
2942 int8_t* c0 = c;
2943
2944 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
2945 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
2946 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
2947 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point);
2948 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min);
2949 do {
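    // The mask 0x1111 has one bit set per 128-bit group, so each expand-load
    // places 4 consecutive int32 biases into lanes 0, 4, 8 and 12: one bias in
    // the low dword of each group, matching the per-group dot products below.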
2950 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
2951 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
2952 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
2953 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
2954 w = (const void*) ((const int32_t*) w + 16);
2955
2956 size_t k = 0;
2957 while (k < kc) {
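      // Each step covers 8 values of k for all 16 output channels: the 8 int8
      // inputs are sign-extended to int16 and broadcast into every 128-bit
      // group of va0, while each vb vector carries the 8 weights of 4 channels
      // (one channel per group) for _mm512_madd_epi16 to pair-multiply-add.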
2958 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
2959 a0 += 8;
2960
2961 const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
2962
2963 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
2964 const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
2965
2966 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
2967 const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
2968
2969 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
2970 const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
2971
2972 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
2973
2974 w = (const void*) ((const int8_t*) w + 128);
2975 k += 8 * sizeof(int8_t);
2976 }
2977
2978 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
2979 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
2980
2981 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
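    // The unpack/add tree above sums the four partial dwords of every channel,
    // leaving one int32 per channel in the interleaved lane order
    // 0,8,4,C,1,9,5,D,2,A,6,E,3,B,7,F that the pack/shuffle stage below undoes.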
2982
2983 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
2984
2985 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
2986
2987 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
2988
2989 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
2990
2991 const __m256i vacc0x084C2A6E195D3B7F = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0x084C195D2A6E3B7F), _mm512_extracti32x8_epi32(vacc0x084C195D2A6E3B7F, 1)), voutput_zero_point);
2992
2993 const __m128i vout0x084C2A6E195D3B7F = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
2994 __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x084C2A6E195D3B7F, _mm_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
2995 vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min);
2996
2997 if (nc >= 16) {
2998 _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF);
2999
3000 a0 = (const int8_t*) ((uintptr_t) a0 - k);
3001
3002 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
3003
3004 nc -= 16;
3005 } else {
3006 // Prepare mask for valid 8-bit elements (depends on nc).
3007 const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
3008
3009 _mm_mask_storeu_epi8(c0, vmask, vout0x0123456789ABCDEF);
3010
3011 nc = 0;
3012 }
3013 } while (nc != 0);
3014 }
3015
3016 void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx(
3017 size_t mr,
3018 size_t nc,
3019 size_t kc,
3020 const int8_t* restrict a,
3021 size_t a_stride,
3022 const void* restrict w,
3023 int8_t* restrict c,
3024 size_t cm_stride,
3025 size_t cn_stride,
3026 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
3027 {
3028 assert(mr != 0);
3029 assert(mr <= 4);
3030 assert(nc != 0);
3031 assert(kc != 0);
3032 assert(kc % sizeof(int8_t) == 0);
3033 assert(a != NULL);
3034 assert(w != NULL);
3035 assert(c != NULL);
3036
3037 kc = round_up_po2(kc, 8);
3038 const int8_t* a0 = a;
3039 int8_t* c0 = c;
3040 const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
3041 int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
3042 if XNN_UNPREDICTABLE(mr < 2) {
3043 a1 = a0;
3044 c1 = c0;
3045 }
3046 const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride);
3047 int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
3048 if XNN_UNPREDICTABLE(mr <= 2) {
3049 a2 = a1;
3050 c2 = c1;
3051 }
3052 const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride);
3053 int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride);
3054 if XNN_UNPREDICTABLE(mr != 4) {
3055 a3 = a2;
3056 c3 = c2;
3057 }
3058
3059 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
3060 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
3061 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
3062 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
3063 const __m512i voutput_min = _mm512_load_si512(params->fp32_avx512.output_min);
3064 do {
3065 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
3066 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
3067 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
3068 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
3069 __m512i vacc1x0123 = vacc0x0123;
3070 __m512i vacc1x4567 = vacc0x4567;
3071 __m512i vacc1x89AB = vacc0x89AB;
3072 __m512i vacc1xCDEF = vacc0xCDEF;
3073 __m512i vacc2x0123 = vacc0x0123;
3074 __m512i vacc2x4567 = vacc0x4567;
3075 __m512i vacc2x89AB = vacc0x89AB;
3076 __m512i vacc2xCDEF = vacc0xCDEF;
3077 __m512i vacc3x0123 = vacc0x0123;
3078 __m512i vacc3x4567 = vacc0x4567;
3079 __m512i vacc3x89AB = vacc0x89AB;
3080 __m512i vacc3xCDEF = vacc0xCDEF;
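    // Every row starts from the same 16 per-channel biases, so the row 1-3
    // accumulators are initialized as copies of row 0.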
3081 w = (const void*) ((const int32_t*) w + 16);
3082
3083 size_t k = 0;
3084 while (k < kc) {
3085 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
3086 a0 += 8;
3087 const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
3088 a1 += 8;
3089 const __m512i va2 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a2)));
3090 a2 += 8;
3091 const __m512i va3 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a3)));
3092 a3 += 8;
3093
3094 const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
3095
3096 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
3097 vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
3098 vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
3099 vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
3100 const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
3101
3102 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
3103 vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
3104 vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
3105 vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
3106 const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
3107
3108 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
3109 vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
3110 vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
3111 vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
3112 const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
3113
3114 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
3115 vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
3116 vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
3117 vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
3118
3119 w = (const void*) ((const int8_t*) w + 128);
3120 k += 8 * sizeof(int8_t);
3121 }
3122
3123 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
3124 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
3125 const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
3126 const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
3127 const __m512i vacc2x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x0123, vacc2x4567), _mm512_unpackhi_epi32(vacc2x0123, vacc2x4567));
3128 const __m512i vacc2x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x89AB, vacc2xCDEF), _mm512_unpackhi_epi32(vacc2x89AB, vacc2xCDEF));
3129 const __m512i vacc3x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x0123, vacc3x4567), _mm512_unpackhi_epi32(vacc3x0123, vacc3x4567));
3130 const __m512i vacc3x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x89AB, vacc3xCDEF), _mm512_unpackhi_epi32(vacc3x89AB, vacc3xCDEF));
3131
3132 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
3133 __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
3134 __m512i vacc2x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x04152637, vacc2x8C9DAEBF), _mm512_unpackhi_epi32(vacc2x04152637, vacc2x8C9DAEBF));
3135 __m512i vacc3x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x04152637, vacc3x8C9DAEBF), _mm512_unpackhi_epi32(vacc3x04152637, vacc3x8C9DAEBF));
3136
3137 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
3138 __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
3139 __m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
3140 __m512 vscaled3x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc3x084C195D2A6E3B7F);
3141
3142 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
3143 vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale);
3144 vscaled2x084C195D2A6E3B7F = _mm512_mul_ps(vscaled2x084C195D2A6E3B7F, vscale);
3145 vscaled3x084C195D2A6E3B7F = _mm512_mul_ps(vscaled3x084C195D2A6E3B7F, vscale);
3146
3147 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
3148 vscaled1x084C195D2A6E3B7F = _mm512_min_ps(vscaled1x084C195D2A6E3B7F, voutput_max_less_zero_point);
3149 vscaled2x084C195D2A6E3B7F = _mm512_min_ps(vscaled2x084C195D2A6E3B7F, voutput_max_less_zero_point);
3150 vscaled3x084C195D2A6E3B7F = _mm512_min_ps(vscaled3x084C195D2A6E3B7F, voutput_max_less_zero_point);
3151
3152 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
3153 vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
3154 vacc2x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled2x084C195D2A6E3B7F);
3155 vacc3x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled3x084C195D2A6E3B7F);
3156
3157 const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
3158 const __m512i vacc23x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc2x084C195D2A6E3B7F, vacc3x084C195D2A6E3B7F), voutput_zero_point);
3159
3160 __m512i vout0123x084Cx195Dx2A6Ex3B7F = _mm512_packs_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc23x084Cx195Dx2A6Ex3B7F);
3161 vout0123x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout0123x084Cx195Dx2A6Ex3B7F);
3162 __m512i vout0123x0123456789ABCDEF = _mm512_shuffle_epi8(vout0123x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
3163 vout0123x0123456789ABCDEF = _mm512_max_epi8(vout0123x0123456789ABCDEF, voutput_min);
3164
3165 if (nc >= 16) {
3166 _mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout0123x0123456789ABCDEF));
3167 _mm_storeu_si128((__m128i*) c1, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 1));
3168 _mm_storeu_si128((__m128i*) c2, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 2));
3169 _mm_storeu_si128((__m128i*) c3, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 3));
3170
3171 a0 = (const int8_t*) ((uintptr_t) a0 - k);
3172 a1 = (const int8_t*) ((uintptr_t) a1 - k);
3173 a2 = (const int8_t*) ((uintptr_t) a2 - k);
3174 a3 = (const int8_t*) ((uintptr_t) a3 - k);
3175
3176 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
3177 c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
3178 c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
3179 c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
3180
3181 nc -= 16;
3182 } else {
3183 // Prepare mask for valid 8-bit elements (depends on nc).
3184 __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
3185
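      // The four rows live in the four 128-bit groups of the zmm output, so
      // the 16-bit store mask is shifted up by 16 per row while the row
      // pointer is biased down (c1 - 16, c2 - 32, c3 - 48) to compensate.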
3186 _mm512_mask_storeu_epi8(c0, vmask, vout0123x0123456789ABCDEF);
3187 vmask = _kshiftli_mask64(vmask, 16);
3188 _mm512_mask_storeu_epi8(c1 - 16, vmask, vout0123x0123456789ABCDEF);
3189 vmask = _kshiftli_mask64(vmask, 16);
3190 _mm512_mask_storeu_epi8(c2 - 32, vmask, vout0123x0123456789ABCDEF);
3191 vmask = _kshiftli_mask64(vmask, 16);
3192 _mm512_mask_storeu_epi8(c3 - 48, vmask, vout0123x0123456789ABCDEF);
3193
3194 nc = 0;
3195 }
3196 } while (nc != 0);
3197 }
3198
3199 void xnn_qs8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx(
3200 size_t mr,
3201 size_t nc,
3202 size_t kc,
3203 size_t ks,
3204 const int8_t** restrict a,
3205 const void* restrict w,
3206 int8_t* restrict c,
3207 size_t cm_stride,
3208 size_t cn_stride,
3209 size_t a_offset,
3210 const int8_t* zero,
3211 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
3212 {
3213 assert(mr != 0);
3214 assert(mr <= 1);
3215 assert(nc != 0);
3216 assert(kc != 0);
3217 assert(kc % sizeof(int8_t) == 0);
3218 assert(a != NULL);
3219 assert(w != NULL);
3220 assert(c != NULL);
3221
3222 kc = round_up_po2(kc, 8);
3223 int8_t* c0 = c;
3224
3225 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
3226 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
3227 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
3228 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point);
3229 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min);
3230 do {
3231 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
3232 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
3233 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
3234 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
3235 w = (const void*) ((const int32_t*) w + 16);
3236
3237 size_t p = ks;
3238 do {
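      // a[] is an indirection buffer of ks / sizeof(void*) row pointers per
      // output pixel; entries equal to the zero buffer stand in for padding
      // and are not rebased by a_offset.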
3239 const int8_t* restrict a0 = a[0];
3240 if XNN_UNPREDICTABLE(a0 != zero) {
3241 a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
3242 }
3243 a += 1;
3244
3245 size_t k = 0;
3246 while (k < kc) {
3247 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
3248 a0 += 8;
3249
3250 const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
3251
3252 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
3253 const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
3254
3255 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
3256 const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
3257
3258 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
3259 const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
3260
3261 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
3262
3263 w = (const void*) ((const int8_t*) w + 128);
3264 k += 8 * sizeof(int8_t);
3265 }
3266 p -= 1 * sizeof(void*);
3267 } while (p != 0);
3268
3269 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
3270 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
3271
3272 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
3273
3274 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
3275
3276 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
3277
3278 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
3279
3280 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
3281
3282 const __m256i vacc0x084C2A6E195D3B7F = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0x084C195D2A6E3B7F), _mm512_extracti32x8_epi32(vacc0x084C195D2A6E3B7F, 1)), voutput_zero_point);
3283
3284 const __m128i vout0x084C2A6E195D3B7F = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
3285 __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x084C2A6E195D3B7F, _mm_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
3286 vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min);
3287
3288 if (nc >= 16) {
3289 _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF);
3290
3291 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
3292
3293 a = (const int8_t**restrict) ((uintptr_t) a - ks);
3294
3295 nc -= 16;
3296 } else {
3297 // Prepare mask for valid 8-bit elements (depends on nc).
3298 const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
3299
3300 _mm_mask_storeu_epi8(c0, vmask, vout0x0123456789ABCDEF);
3301
3302 nc = 0;
3303 }
3304 } while (nc != 0);
3305 }
3306
3307 void xnn_qs8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx(
3308 size_t mr,
3309 size_t nc,
3310 size_t kc,
3311 size_t ks,
3312 const int8_t** restrict a,
3313 const void* restrict w,
3314 int8_t* restrict c,
3315 size_t cm_stride,
3316 size_t cn_stride,
3317 size_t a_offset,
3318 const int8_t* zero,
3319 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
3320 {
3321 assert(mr != 0);
3322 assert(mr <= 4);
3323 assert(nc != 0);
3324 assert(kc != 0);
3325 assert(kc % sizeof(int8_t) == 0);
3326 assert(a != NULL);
3327 assert(w != NULL);
3328 assert(c != NULL);
3329
3330 kc = round_up_po2(kc, 8);
3331 int8_t* c0 = c;
3332 int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
3333 if XNN_UNPREDICTABLE(mr < 2) {
3334 c1 = c0;
3335 }
3336 int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
3337 if XNN_UNPREDICTABLE(mr <= 2) {
3338 c2 = c1;
3339 }
3340 int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride);
3341 if XNN_UNPREDICTABLE(mr != 4) {
3342 c3 = c2;
3343 }
3344
3345 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
3346 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
3347 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
3348 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
3349 const __m512i voutput_min = _mm512_load_si512(params->fp32_avx512.output_min);
3350 do {
3351 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
3352 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
3353 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
3354 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
3355 __m512i vacc1x0123 = vacc0x0123;
3356 __m512i vacc1x4567 = vacc0x4567;
3357 __m512i vacc1x89AB = vacc0x89AB;
3358 __m512i vacc1xCDEF = vacc0xCDEF;
3359 __m512i vacc2x0123 = vacc0x0123;
3360 __m512i vacc2x4567 = vacc0x4567;
3361 __m512i vacc2x89AB = vacc0x89AB;
3362 __m512i vacc2xCDEF = vacc0xCDEF;
3363 __m512i vacc3x0123 = vacc0x0123;
3364 __m512i vacc3x4567 = vacc0x4567;
3365 __m512i vacc3x89AB = vacc0x89AB;
3366 __m512i vacc3xCDEF = vacc0xCDEF;
3367 w = (const void*) ((const int32_t*) w + 16);
3368
3369 size_t p = ks;
3370 do {
3371 const int8_t* restrict a0 = a[0];
3372 if XNN_UNPREDICTABLE(a0 != zero) {
3373 a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
3374 }
3375 const int8_t* restrict a1 = a[1];
3376 if XNN_UNPREDICTABLE(a1 != zero) {
3377 a1 = (const int8_t*) ((uintptr_t) a1 + a_offset);
3378 }
3379 const int8_t* restrict a2 = a[2];
3380 if XNN_UNPREDICTABLE(a2 != zero) {
3381 a2 = (const int8_t*) ((uintptr_t) a2 + a_offset);
3382 }
3383 const int8_t* restrict a3 = a[3];
3384 if XNN_UNPREDICTABLE(a3 != zero) {
3385 a3 = (const int8_t*) ((uintptr_t) a3 + a_offset);
3386 }
3387 a += 4;
3388
3389 size_t k = 0;
3390 while (k < kc) {
3391 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
3392 a0 += 8;
3393 const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
3394 a1 += 8;
3395 const __m512i va2 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a2)));
3396 a2 += 8;
3397 const __m512i va3 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a3)));
3398 a3 += 8;
3399
3400 const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
3401
3402 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
3403 vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
3404 vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
3405 vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
3406 const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
3407
3408 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
3409 vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
3410 vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
3411 vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
3412 const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
3413
3414 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
3415 vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
3416 vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
3417 vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
3418 const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
3419
3420 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
3421 vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
3422 vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
3423 vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
3424
3425 w = (const void*) ((const int8_t*) w + 128);
3426 k += 8 * sizeof(int8_t);
3427 }
3428 p -= 4 * sizeof(void*);
3429 } while (p != 0);
3430
3431 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
3432 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
3433 const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
3434 const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
3435 const __m512i vacc2x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x0123, vacc2x4567), _mm512_unpackhi_epi32(vacc2x0123, vacc2x4567));
3436 const __m512i vacc2x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x89AB, vacc2xCDEF), _mm512_unpackhi_epi32(vacc2x89AB, vacc2xCDEF));
3437 const __m512i vacc3x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x0123, vacc3x4567), _mm512_unpackhi_epi32(vacc3x0123, vacc3x4567));
3438 const __m512i vacc3x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x89AB, vacc3xCDEF), _mm512_unpackhi_epi32(vacc3x89AB, vacc3xCDEF));
3439
3440 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
3441 __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
3442 __m512i vacc2x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x04152637, vacc2x8C9DAEBF), _mm512_unpackhi_epi32(vacc2x04152637, vacc2x8C9DAEBF));
3443 __m512i vacc3x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x04152637, vacc3x8C9DAEBF), _mm512_unpackhi_epi32(vacc3x04152637, vacc3x8C9DAEBF));
3444
3445 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
3446 __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
3447 __m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
3448 __m512 vscaled3x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc3x084C195D2A6E3B7F);
3449
3450 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
3451 vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale);
3452 vscaled2x084C195D2A6E3B7F = _mm512_mul_ps(vscaled2x084C195D2A6E3B7F, vscale);
3453 vscaled3x084C195D2A6E3B7F = _mm512_mul_ps(vscaled3x084C195D2A6E3B7F, vscale);
3454
3455 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
3456 vscaled1x084C195D2A6E3B7F = _mm512_min_ps(vscaled1x084C195D2A6E3B7F, voutput_max_less_zero_point);
3457 vscaled2x084C195D2A6E3B7F = _mm512_min_ps(vscaled2x084C195D2A6E3B7F, voutput_max_less_zero_point);
3458 vscaled3x084C195D2A6E3B7F = _mm512_min_ps(vscaled3x084C195D2A6E3B7F, voutput_max_less_zero_point);
3459
3460 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
3461 vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
3462 vacc2x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled2x084C195D2A6E3B7F);
3463 vacc3x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled3x084C195D2A6E3B7F);
3464
3465 const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
3466 const __m512i vacc23x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc2x084C195D2A6E3B7F, vacc3x084C195D2A6E3B7F), voutput_zero_point);
3467
3468 __m512i vout0123x084Cx195Dx2A6Ex3B7F = _mm512_packs_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc23x084Cx195Dx2A6Ex3B7F);
3469 vout0123x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout0123x084Cx195Dx2A6Ex3B7F);
3470 __m512i vout0123x0123456789ABCDEF = _mm512_shuffle_epi8(vout0123x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
3471 vout0123x0123456789ABCDEF = _mm512_max_epi8(vout0123x0123456789ABCDEF, voutput_min);
3472
3473 if (nc >= 16) {
3474 _mm_storeu_si128((__m128i*) c3, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 3));
3475 _mm_storeu_si128((__m128i*) c2, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 2));
3476 _mm_storeu_si128((__m128i*) c1, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 1));
3477 _mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout0123x0123456789ABCDEF));
3478
3479 c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
3480 c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
3481 c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
3482 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
3483
3484 a = (const int8_t**restrict) ((uintptr_t) a - ks);
3485
3486 nc -= 16;
3487 } else {
3488 // Prepare mask for valid 8-bit elements (depends on nc).
3489 __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT64_C(1) << (nc + 48)) - (UINT64_C(1) << 48)));
3490
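      // Mirror of the GEMM tail: the mask is built in bits 48..63 (row 3's
      // group) and shifted right by 16 per store, writing rows c3 down to c0.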
3491 _mm512_mask_storeu_epi8(c3 - 48, vmask, vout0123x0123456789ABCDEF);
3492 vmask = _kshiftri_mask64(vmask, 16);
3493 _mm512_mask_storeu_epi8(c2 - 32, vmask, vout0123x0123456789ABCDEF);
3494 vmask = _kshiftri_mask64(vmask, 16);
3495 _mm512_mask_storeu_epi8(c1 - 16, vmask, vout0123x0123456789ABCDEF);
3496 vmask = _kshiftri_mask64(vmask, 16);
3497 _mm512_mask_storeu_epi8(c0, vmask, vout0123x0123456789ABCDEF);
3498
3499 nc = 0;
3500 }
3501 } while (nc != 0);
3502 }
3503
3504 void xnn_qs8_vadd_minmax_ukernel__avx512skx_mul32_ld128_x16(
3505 size_t n,
3506 const int8_t* input_a,
3507 const int8_t* input_b,
3508 int8_t* output,
3509 const union xnn_qs8_add_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
3510 {
3511 const __m512i vbias = _mm512_load_si512(params->avx512.bias);
3512 const __m512i va_multiplier = _mm512_load_si512(params->avx512.a_multiplier);
3513 const __m512i vb_multiplier = _mm512_load_si512(params->avx512.b_multiplier);
3514 const __m128i vshift = _mm_load_si128((const __m128i*) params->avx512.shift);
3515 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx512.output_zero_point);
3516 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->avx512.output_min);
3517 const __m128i voutput_max = _mm_load_si128((const __m128i*) params->avx512.output_max);
3518
3519 for (; n >= 16 * sizeof(int8_t); n -= 16 * sizeof(int8_t)) {
3520 const __m512i va0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) input_a));
3521 const __m512i vb0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) input_b));
3522 input_a += 16;
3523 input_b += 16;
3524
3525 __m512i vacc0123456789ABCDEF = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va0123456789ABCDEF, va_multiplier));
3526
3527 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vb0123456789ABCDEF, vb_multiplier));
3528
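    // Fixed-point requantization: the inputs were scaled by integer
    // multipliers, and the arithmetic right shift brings the sum back to the
    // int8 output scale before the output zero point is added.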
3529 vacc0123456789ABCDEF = _mm512_sra_epi32(vacc0123456789ABCDEF, vshift);
3530
3531 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), voutput_zero_point);
3532
3533 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
3534
3535 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, voutput_min);
3536
3537 vout0123456789ABCDEF = _mm_min_epi8(vout0123456789ABCDEF, voutput_max);
3538
3539 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
3540 output += 16;
3541 }
3542 if XNN_UNLIKELY(n != 0) {
3543 {
3544 const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << n) - UINT32_C(1)));
3545 const __m512i va0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_maskz_loadu_epi8(vmask, input_a));
3546 const __m512i vb0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_maskz_loadu_epi8(vmask, input_b));
3547
3548 __m512i vacc0123456789ABCDEF = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va0123456789ABCDEF, va_multiplier));
3549
3550 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vb0123456789ABCDEF, vb_multiplier));
3551
3552 vacc0123456789ABCDEF = _mm512_sra_epi32(vacc0123456789ABCDEF, vshift);
3553
3554 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), voutput_zero_point);
3555 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
3556 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, voutput_min);
3557 vout0123456789ABCDEF = _mm_min_epi8(vout0123456789ABCDEF, voutput_max);
3558
3559 _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
3560 }
3561 }
3562 }
3563
3564 void xnn_qs8_vaddc_minmax_ukernel__avx512skx_mul32_ld128_x16(
3565 size_t n,
3566 const int8_t* input_a,
3567 const int8_t* input_b,
3568 int8_t* output,
3569 const union xnn_qs8_add_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
3570 {
3571 const __m512i va_multiplier = _mm512_load_si512(params->avx512.a_multiplier);
3572 const __m128i vshift = _mm_load_si128((const __m128i*) params->avx512.shift);
3573 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx512.output_zero_point);
3574 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->avx512.output_min);
3575 const __m128i voutput_max = _mm_load_si128((const __m128i*) params->avx512.output_max);
3576
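  // With a single broadcast addend, b * b_multiplier is a constant and is
  // folded into the bias once, outside the element loop.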
3577 const __m512i vbias = _mm512_add_epi32(
3578 _mm512_broadcastd_epi32(_mm_cvtsi32_si128(params->avx512.b_multiplier[0] * (int32_t) *input_b)),
3579 _mm512_load_si512(params->avx512.bias));
3580 for (; n >= 16 * sizeof(int8_t); n -= 16 * sizeof(int8_t)) {
3581 const __m512i va0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) input_a));
3582 input_a += 16;
3583
3584 __m512i vacc0123456789ABCDEF = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va0123456789ABCDEF, va_multiplier));
3585
3586 vacc0123456789ABCDEF = _mm512_sra_epi32(vacc0123456789ABCDEF, vshift);
3587
3588 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), voutput_zero_point);
3589
3590 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
3591
3592 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, voutput_min);
3593
3594 vout0123456789ABCDEF = _mm_min_epi8(vout0123456789ABCDEF, voutput_max);
3595
3596 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
3597 output += 16;
3598 }
3599 if XNN_UNLIKELY(n != 0) {
3600 {
3601 const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << n) - UINT32_C(1)));
3602 const __m512i va0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_maskz_loadu_epi8(vmask, input_a));
3603
3604 __m512i vacc0123456789ABCDEF = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va0123456789ABCDEF, va_multiplier));
3605
3606 vacc0123456789ABCDEF = _mm512_sra_epi32(vacc0123456789ABCDEF, vshift);
3607
3608 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), voutput_zero_point);
3609 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
3610 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, voutput_min);
3611 vout0123456789ABCDEF = _mm_min_epi8(vout0123456789ABCDEF, voutput_max);
3612
3613 _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
3614 }
3615 }
3616 }
3617
3618 void xnn_qu8_dwconv_minmax_fp32_ukernel_up32x25__avx512skx_mul32(
3619 size_t channels,
3620 size_t output_width,
3621 const uint8_t** input,
3622 const void* weights,
3623 uint8_t* output,
3624 size_t input_stride,
3625 size_t output_increment,
3626 size_t input_offset,
3627 const uint8_t* zero,
3628 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_MSAN
3629 {
3630 assert(channels != 0);
3631 assert(output_width != 0);
3632
3633 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
3634 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
3635 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
3636 const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_min);
3637 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 5, 1, 6, 2, 4, 0);
3638
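  // QU8 weights are stored unsigned; the kernel zero point is subtracted from
  // each weight as it is widened to int32 in the loops below.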
3639 const __m512i vk_zero_point = _mm512_cvtepu16_epi32(_mm256_load_si256((const __m256i*) params->fp32_avx512.kernel_zero_point));
3640 do {
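    // input[] carries 25 tap pointers (e.g. a 5x5 kernel) for this output
    // pixel. Pointers equal to the zero buffer are used as-is; all others are
    // adjusted by input_offset.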
3641 const uint8_t* i0 = input[0];
3642 assert(i0 != NULL);
3643 if XNN_UNPREDICTABLE(i0 != zero) {
3644 i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
3645 }
3646 const uint8_t* i1 = input[1];
3647 assert(i1 != NULL);
3648 if XNN_UNPREDICTABLE(i1 != zero) {
3649 i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
3650 }
3651 const uint8_t* i2 = input[2];
3652 assert(i2 != NULL);
3653 if XNN_UNPREDICTABLE(i2 != zero) {
3654 i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
3655 }
3656 const uint8_t* i3 = input[3];
3657 assert(i3 != NULL);
3658 if XNN_UNPREDICTABLE(i3 != zero) {
3659 i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
3660 }
3661 const uint8_t* i4 = input[4];
3662 assert(i4 != NULL);
3663 if XNN_UNPREDICTABLE(i4 != zero) {
3664 i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
3665 }
3666 const uint8_t* i5 = input[5];
3667 assert(i5 != NULL);
3668 if XNN_UNPREDICTABLE(i5 != zero) {
3669 i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
3670 }
3671 const uint8_t* i6 = input[6];
3672 assert(i6 != NULL);
3673 if XNN_UNPREDICTABLE(i6 != zero) {
3674 i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
3675 }
3676 const uint8_t* i7 = input[7];
3677 assert(i7 != NULL);
3678 if XNN_UNPREDICTABLE(i7 != zero) {
3679 i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
3680 }
3681 const uint8_t* i8 = input[8];
3682 assert(i8 != NULL);
3683 if XNN_UNPREDICTABLE(i8 != zero) {
3684 i8 = (const uint8_t*) ((uintptr_t) i8 + input_offset);
3685 }
3686 const uint8_t* i9 = input[9];
3687 assert(i9 != NULL);
3688 if XNN_UNPREDICTABLE(i9 != zero) {
3689 i9 = (const uint8_t*) ((uintptr_t) i9 + input_offset);
3690 }
3691 const uint8_t* i10 = input[10];
3692 assert(i10 != NULL);
3693 if XNN_UNPREDICTABLE(i10 != zero) {
3694 i10 = (const uint8_t*) ((uintptr_t) i10 + input_offset);
3695 }
3696 const uint8_t* i11 = input[11];
3697 assert(i11 != NULL);
3698 if XNN_UNPREDICTABLE(i11 != zero) {
3699 i11 = (const uint8_t*) ((uintptr_t) i11 + input_offset);
3700 }
3701 const uint8_t* i12 = input[12];
3702 assert(i12 != NULL);
3703 if XNN_UNPREDICTABLE(i12 != zero) {
3704 i12 = (const uint8_t*) ((uintptr_t) i12 + input_offset);
3705 }
3706 const uint8_t* i13 = input[13];
3707 assert(i13 != NULL);
3708 if XNN_UNPREDICTABLE(i13 != zero) {
3709 i13 = (const uint8_t*) ((uintptr_t) i13 + input_offset);
3710 }
3711 const uint8_t* i14 = input[14];
3712 assert(i14 != NULL);
3713 if XNN_UNPREDICTABLE(i14 != zero) {
3714 i14 = (const uint8_t*) ((uintptr_t) i14 + input_offset);
3715 }
3716 const uint8_t* i15 = input[15];
3717 assert(i15 != NULL);
3718 if XNN_UNPREDICTABLE(i15 != zero) {
3719 i15 = (const uint8_t*) ((uintptr_t) i15 + input_offset);
3720 }
3721 const uint8_t* i16 = input[16];
3722 assert(i16 != NULL);
3723 if XNN_UNPREDICTABLE(i16 != zero) {
3724 i16 = (const uint8_t*) ((uintptr_t) i16 + input_offset);
3725 }
3726 const uint8_t* i17 = input[17];
3727 assert(i17 != NULL);
3728 if XNN_UNPREDICTABLE(i17 != zero) {
3729 i17 = (const uint8_t*) ((uintptr_t) i17 + input_offset);
3730 }
3731 const uint8_t* i18 = input[18];
3732 assert(i18 != NULL);
3733 if XNN_UNPREDICTABLE(i18 != zero) {
3734 i18 = (const uint8_t*) ((uintptr_t) i18 + input_offset);
3735 }
3736 const uint8_t* i19 = input[19];
3737 assert(i19 != NULL);
3738 if XNN_UNPREDICTABLE(i19 != zero) {
3739 i19 = (const uint8_t*) ((uintptr_t) i19 + input_offset);
3740 }
3741 const uint8_t* i20 = input[20];
3742 assert(i20 != NULL);
3743 if XNN_UNPREDICTABLE(i20 != zero) {
3744 i20 = (const uint8_t*) ((uintptr_t) i20 + input_offset);
3745 }
3746 const uint8_t* i21 = input[21];
3747 assert(i21 != NULL);
3748 if XNN_UNPREDICTABLE(i21 != zero) {
3749 i21 = (const uint8_t*) ((uintptr_t) i21 + input_offset);
3750 }
3751 const uint8_t* i22 = input[22];
3752 assert(i22 != NULL);
3753 if XNN_UNPREDICTABLE(i22 != zero) {
3754 i22 = (const uint8_t*) ((uintptr_t) i22 + input_offset);
3755 }
3756 const uint8_t* i23 = input[23];
3757 assert(i23 != NULL);
3758 if XNN_UNPREDICTABLE(i23 != zero) {
3759 i23 = (const uint8_t*) ((uintptr_t) i23 + input_offset);
3760 }
3761 const uint8_t* i24 = input[24];
3762 assert(i24 != NULL);
3763 if XNN_UNPREDICTABLE(i24 != zero) {
3764 i24 = (const uint8_t*) ((uintptr_t) i24 + input_offset);
3765 }
3766 input = (const uint8_t**) ((uintptr_t) input + input_stride);
3767
3768 size_t c = channels;
3769 const void* w = weights;
3770 for (; c >= 32; c -= 32) {
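      // Variable names label vector lanes with base-32 digits: 0-9 and A-F for
      // channels 0-15, G-V for channels 16-31; two zmm accumulators cover the
      // 32 channels handled per iteration.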
3771 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
3772 __m512i vaccGHIJKLMNOPQRSTUV = _mm512_loadu_si512((const void*) ((uintptr_t) w + 16 * sizeof(int32_t)));
3773
3774
3775 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i0));
3776 const __m512i vk0x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 0 * sizeof(uint8_t)))), vk_zero_point);
3777 const __m512i vi0xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i0 + 16)));
3778 const __m512i vk0xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 16 * sizeof(uint8_t)))), vk_zero_point);
3779 i0 += 32;
3780
3781 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
3782 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi0xGHIJKLMNOPQRSTUV, vk0xGHIJKLMNOPQRSTUV));
3783
3784 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i1));
3785 const __m512i vk1x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 32 * sizeof(uint8_t)))), vk_zero_point);
3786 const __m512i vi1xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i1 + 16)));
3787 const __m512i vk1xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 48 * sizeof(uint8_t)))), vk_zero_point);
3788 i1 += 32;
3789
3790 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
3791 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi1xGHIJKLMNOPQRSTUV, vk1xGHIJKLMNOPQRSTUV));
3792
3793 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i2));
3794 const __m512i vk2x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 64 * sizeof(uint8_t)))), vk_zero_point);
3795 const __m512i vi2xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i2 + 16)));
3796 const __m512i vk2xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 80 * sizeof(uint8_t)))), vk_zero_point);
3797 i2 += 32;
3798
3799 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
3800 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi2xGHIJKLMNOPQRSTUV, vk2xGHIJKLMNOPQRSTUV));
3801
3802 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i3));
3803 const __m512i vk3x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 96 * sizeof(uint8_t)))), vk_zero_point);
3804 const __m512i vi3xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i3 + 16)));
3805 const __m512i vk3xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 112 * sizeof(uint8_t)))), vk_zero_point);
3806 i3 += 32;
3807
3808 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
3809 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi3xGHIJKLMNOPQRSTUV, vk3xGHIJKLMNOPQRSTUV));
3810
3811 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i4));
3812 const __m512i vk4x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 128 * sizeof(uint8_t)))), vk_zero_point);
3813 const __m512i vi4xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i4 + 16)));
3814 const __m512i vk4xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 144 * sizeof(uint8_t)))), vk_zero_point);
3815 i4 += 32;
3816
3817 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
3818 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi4xGHIJKLMNOPQRSTUV, vk4xGHIJKLMNOPQRSTUV));
3819
3820 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i5));
3821 const __m512i vk5x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 160 * sizeof(uint8_t)))), vk_zero_point);
3822 const __m512i vi5xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i5 + 16)));
3823 const __m512i vk5xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 176 * sizeof(uint8_t)))), vk_zero_point);
3824 i5 += 32;
3825
3826 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
3827 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi5xGHIJKLMNOPQRSTUV, vk5xGHIJKLMNOPQRSTUV));
3828
3829 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i6));
3830 const __m512i vk6x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 192 * sizeof(uint8_t)))), vk_zero_point);
3831 const __m512i vi6xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i6 + 16)));
3832 const __m512i vk6xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 208 * sizeof(uint8_t)))), vk_zero_point);
3833 i6 += 32;
3834
3835 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
3836 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi6xGHIJKLMNOPQRSTUV, vk6xGHIJKLMNOPQRSTUV));
3837
3838 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i7));
3839 const __m512i vk7x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 224 * sizeof(uint8_t)))), vk_zero_point);
3840 const __m512i vi7xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i7 + 16)));
3841 const __m512i vk7xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 240 * sizeof(uint8_t)))), vk_zero_point);
3842 i7 += 32;
3843
3844 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
3845 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi7xGHIJKLMNOPQRSTUV, vk7xGHIJKLMNOPQRSTUV));
3846
3847 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i8));
3848 const __m512i vk8x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 256 * sizeof(uint8_t)))), vk_zero_point);
3849 const __m512i vi8xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i8 + 16)));
3850 const __m512i vk8xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 272 * sizeof(uint8_t)))), vk_zero_point);
3851 i8 += 32;
3852
3853 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
3854 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi8xGHIJKLMNOPQRSTUV, vk8xGHIJKLMNOPQRSTUV));
3855
3856 const __m512i vi9x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i9));
3857 const __m512i vk9x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 288 * sizeof(uint8_t)))), vk_zero_point);
3858 const __m512i vi9xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i9 + 16)));
3859 const __m512i vk9xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 304 * sizeof(uint8_t)))), vk_zero_point);
3860 i9 += 32;
3861
3862 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF));
3863 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi9xGHIJKLMNOPQRSTUV, vk9xGHIJKLMNOPQRSTUV));
3864
3865 const __m512i vi10x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i10));
3866 const __m512i vk10x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 320 * sizeof(uint8_t)))), vk_zero_point);
3867 const __m512i vi10xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i10 + 16)));
3868 const __m512i vk10xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 336 * sizeof(uint8_t)))), vk_zero_point);
3869 i10 += 32;
3870
3871 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF));
3872 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi10xGHIJKLMNOPQRSTUV, vk10xGHIJKLMNOPQRSTUV));
3873
3874 const __m512i vi11x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i11));
3875 const __m512i vk11x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 352 * sizeof(uint8_t)))), vk_zero_point);
3876 const __m512i vi11xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i11 + 16)));
3877 const __m512i vk11xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 368 * sizeof(uint8_t)))), vk_zero_point);
3878 i11 += 32;
3879
3880 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF));
3881 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi11xGHIJKLMNOPQRSTUV, vk11xGHIJKLMNOPQRSTUV));
3882
3883 const __m512i vi12x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i12));
3884 const __m512i vk12x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 384 * sizeof(uint8_t)))), vk_zero_point);
3885 const __m512i vi12xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i12 + 16)));
3886 const __m512i vk12xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 400 * sizeof(uint8_t)))), vk_zero_point);
3887 i12 += 32;
3888
3889 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF));
3890 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi12xGHIJKLMNOPQRSTUV, vk12xGHIJKLMNOPQRSTUV));
3891
3892 const __m512i vi13x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i13));
3893 const __m512i vk13x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 416 * sizeof(uint8_t)))), vk_zero_point);
3894 const __m512i vi13xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i13 + 16)));
3895 const __m512i vk13xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 432 * sizeof(uint8_t)))), vk_zero_point);
3896 i13 += 32;
3897
3898 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF));
3899 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi13xGHIJKLMNOPQRSTUV, vk13xGHIJKLMNOPQRSTUV));
3900
3901 const __m512i vi14x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i14));
3902 const __m512i vk14x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 448 * sizeof(uint8_t)))), vk_zero_point);
3903 const __m512i vi14xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i14 + 16)));
3904 const __m512i vk14xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 464 * sizeof(uint8_t)))), vk_zero_point);
3905 i14 += 32;
3906
3907 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF));
3908 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi14xGHIJKLMNOPQRSTUV, vk14xGHIJKLMNOPQRSTUV));
3909
3910 const __m512i vi15x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i15));
3911 const __m512i vk15x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 480 * sizeof(uint8_t)))), vk_zero_point);
3912 const __m512i vi15xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i15 + 16)));
3913 const __m512i vk15xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 496 * sizeof(uint8_t)))), vk_zero_point);
3914 i15 += 32;
3915
3916 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF));
3917 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi15xGHIJKLMNOPQRSTUV, vk15xGHIJKLMNOPQRSTUV));
3918
3919 const __m512i vi16x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i16));
3920 const __m512i vk16x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 512 * sizeof(uint8_t)))), vk_zero_point);
3921 const __m512i vi16xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i16 + 16)));
3922 const __m512i vk16xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 528 * sizeof(uint8_t)))), vk_zero_point);
3923 i16 += 32;
3924
3925 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF));
3926 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi16xGHIJKLMNOPQRSTUV, vk16xGHIJKLMNOPQRSTUV));
3927
3928 const __m512i vi17x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i17));
3929 const __m512i vk17x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 544 * sizeof(uint8_t)))), vk_zero_point);
3930 const __m512i vi17xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i17 + 16)));
3931 const __m512i vk17xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 560 * sizeof(uint8_t)))), vk_zero_point);
3932 i17 += 32;
3933
3934 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF));
3935 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi17xGHIJKLMNOPQRSTUV, vk17xGHIJKLMNOPQRSTUV));
3936
3937 const __m512i vi18x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i18));
3938 const __m512i vk18x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 576 * sizeof(uint8_t)))), vk_zero_point);
3939 const __m512i vi18xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i18 + 16)));
3940 const __m512i vk18xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 592 * sizeof(uint8_t)))), vk_zero_point);
3941 i18 += 32;
3942
3943 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF));
3944 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi18xGHIJKLMNOPQRSTUV, vk18xGHIJKLMNOPQRSTUV));
3945
3946 const __m512i vi19x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i19));
3947 const __m512i vk19x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 608 * sizeof(uint8_t)))), vk_zero_point);
3948 const __m512i vi19xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i19 + 16)));
3949 const __m512i vk19xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 624 * sizeof(uint8_t)))), vk_zero_point);
3950 i19 += 32;
3951
3952 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF));
3953 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi19xGHIJKLMNOPQRSTUV, vk19xGHIJKLMNOPQRSTUV));
3954
3955 const __m512i vi20x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i20));
3956 const __m512i vk20x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 640 * sizeof(uint8_t)))), vk_zero_point);
3957 const __m512i vi20xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i20 + 16)));
3958 const __m512i vk20xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 656 * sizeof(uint8_t)))), vk_zero_point);
3959 i20 += 32;
3960
3961 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF));
3962 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi20xGHIJKLMNOPQRSTUV, vk20xGHIJKLMNOPQRSTUV));
3963
3964 const __m512i vi21x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i21));
3965 const __m512i vk21x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 672 * sizeof(uint8_t)))), vk_zero_point);
3966 const __m512i vi21xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i21 + 16)));
3967 const __m512i vk21xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 688 * sizeof(uint8_t)))), vk_zero_point);
3968 i21 += 32;
3969
3970 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF));
3971 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi21xGHIJKLMNOPQRSTUV, vk21xGHIJKLMNOPQRSTUV));
3972
3973 const __m512i vi22x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i22));
3974 const __m512i vk22x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 704 * sizeof(uint8_t)))), vk_zero_point);
3975 const __m512i vi22xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i22 + 16)));
3976 const __m512i vk22xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 720 * sizeof(uint8_t)))), vk_zero_point);
3977 i22 += 32;
3978
3979 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF));
3980 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi22xGHIJKLMNOPQRSTUV, vk22xGHIJKLMNOPQRSTUV));
3981
3982 const __m512i vi23x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i23));
3983 const __m512i vk23x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 736 * sizeof(uint8_t)))), vk_zero_point);
3984 const __m512i vi23xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i23 + 16)));
3985 const __m512i vk23xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 752 * sizeof(uint8_t)))), vk_zero_point);
3986 i23 += 32;
3987
3988 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF));
3989 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi23xGHIJKLMNOPQRSTUV, vk23xGHIJKLMNOPQRSTUV));
3990
3991 const __m512i vi24x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i24));
3992 const __m512i vk24x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 768 * sizeof(uint8_t)))), vk_zero_point);
3993 const __m512i vi24xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i24 + 16)));
3994 const __m512i vk24xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 784 * sizeof(uint8_t)))), vk_zero_point);
3995 i24 += 32;
3996
3997 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF));
3998 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi24xGHIJKLMNOPQRSTUV, vk24xGHIJKLMNOPQRSTUV));
3999
4000 w = (const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 800 * sizeof(uint8_t));
4001
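    // Requantize: convert the int32 accumulators to fp32, apply the scale, clamp above at (output_max - output_zero_point), and round back to int32.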
4002 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
4003 __m512 vscaledGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vaccGHIJKLMNOPQRSTUV);
4004
4005 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale);
4006 vscaledGHIJKLMNOPQRSTUV = _mm512_mul_ps(vscaledGHIJKLMNOPQRSTUV, vscale);
4007
4008 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
4009 vscaledGHIJKLMNOPQRSTUV = _mm512_min_ps(vscaledGHIJKLMNOPQRSTUV, voutput_max_less_zero_point);
4010
4011 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
4012 vaccGHIJKLMNOPQRSTUV = _mm512_cvtps_epi32(vscaledGHIJKLMNOPQRSTUV);
4013
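    // Pack down to 8 bits: packs_epi32 saturates to int16 within each 128-bit lane (hence the interleaved 0123GHIJ... order),
    // the output zero point is added with saturation, packus_epi16 saturates to uint8, and the dword permute restores sequential channel order.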
4014 __m512i vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV = _mm512_adds_epi16(_mm512_packs_epi32(vacc0123456789ABCDEF, vaccGHIJKLMNOPQRSTUV), voutput_zero_point);
4015 __m256i voutGHIJOPQRKLMNSTUV = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vaccGHIJKLMNOPQRSTUV), _mm512_extracti32x8_epi32(vaccGHIJKLMNOPQRSTUV, 1)), _mm512_castsi512_si256(voutput_zero_point));
4016
4017 const __m256i vout0123GHIJ4567KLMN = _mm512_castsi512_si256(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV);
4018 const __m256i vout89ABOPQRCDEFSTUV = _mm512_extracti32x8_epi32(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV, 1);
4019 const __m256i vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV = _mm256_packus_epi16(vout0123GHIJ4567KLMN, vout89ABOPQRCDEFSTUV);
4020 __m256i vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_permutevar8x32_epi32(vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV, vpermute_mask);
4021 const __m128i voutGHIJOPQR = _mm256_castsi256_si128(voutGHIJOPQRKLMNSTUV);
4022 const __m128i voutKLMNSTUV = _mm256_extracti128_si256(voutGHIJOPQRKLMNSTUV, 1);
4023 __m128i voutGHIJKLMNOPQRSTUV = _mm_shuffle_epi32(_mm_packus_epi16(voutGHIJOPQR, voutKLMNSTUV), _MM_SHUFFLE(3, 1, 2, 0));
4024
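    // Clamp at the output minimum; the maximum was already applied in fp32.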
4025 vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_max_epu8(vout0123456789ABCDEFGHIJKLMNOPQRSTUV, voutput_min);
4026 voutGHIJKLMNOPQRSTUV = _mm_max_epu8(voutGHIJKLMNOPQRSTUV, _mm256_castsi256_si128(voutput_min));
4027
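    // Store all 32 channels; the 128-bit store rewrites bytes 16..31 of the preceding 256-bit store with identical values.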
4028 _mm256_storeu_si256((__m256i*) output, vout0123456789ABCDEFGHIJKLMNOPQRSTUV);
4029 _mm_storeu_si128((__m128i*) (output + 16), voutGHIJKLMNOPQRSTUV);
4030 output += 32;
4031 }
4032 if XNN_UNLIKELY(c != 0) {
4033    // Prepare mask for valid 8-bit elements (depends on c).
4034 const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << (c & 15)) - UINT32_C(1)));
4035 const uint8_t* k = (const uint8_t*) ((uintptr_t) w + 32 * sizeof(int32_t));
4036 do {
4037 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
4038
4039
4040 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i0));
4041 const __m512i vk0x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) k)), vk_zero_point);
4042 i0 += 16;
4043
4044 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
4045
4046 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i1));
4047 const __m512i vk1x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 32))), vk_zero_point);
4048 i1 += 16;
4049
4050 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
4051
4052 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i2));
4053 const __m512i vk2x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 64))), vk_zero_point);
4054 i2 += 16;
4055
4056 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
4057
4058 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i3));
4059 const __m512i vk3x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 96))), vk_zero_point);
4060 i3 += 16;
4061
4062 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
4063
4064 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i4));
4065 const __m512i vk4x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 128))), vk_zero_point);
4066 i4 += 16;
4067
4068 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
4069
4070 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i5));
4071 const __m512i vk5x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 160))), vk_zero_point);
4072 i5 += 16;
4073
4074 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
4075
4076 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i6));
4077 const __m512i vk6x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 192))), vk_zero_point);
4078 i6 += 16;
4079
4080 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
4081
4082 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i7));
4083 const __m512i vk7x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 224))), vk_zero_point);
4084 i7 += 16;
4085
4086 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
4087
4088 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i8));
4089 const __m512i vk8x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 256))), vk_zero_point);
4090 i8 += 16;
4091
4092 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
4093
4094 const __m512i vi9x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i9));
4095 const __m512i vk9x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 288))), vk_zero_point);
4096 i9 += 16;
4097
4098 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF));
4099
4100 const __m512i vi10x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i10));
4101 const __m512i vk10x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 320))), vk_zero_point);
4102 i10 += 16;
4103
4104 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF));
4105
4106 const __m512i vi11x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i11));
4107 const __m512i vk11x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 352))), vk_zero_point);
4108 i11 += 16;
4109
4110 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF));
4111
4112 const __m512i vi12x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i12));
4113 const __m512i vk12x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 384))), vk_zero_point);
4114 i12 += 16;
4115
4116 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF));
4117
4118 const __m512i vi13x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i13));
4119 const __m512i vk13x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 416))), vk_zero_point);
4120 i13 += 16;
4121
4122 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF));
4123
4124 const __m512i vi14x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i14));
4125 const __m512i vk14x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 448))), vk_zero_point);
4126 i14 += 16;
4127
4128 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF));
4129
4130 const __m512i vi15x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i15));
4131 const __m512i vk15x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 480))), vk_zero_point);
4132 i15 += 16;
4133
4134 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF));
4135
4136 const __m512i vi16x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i16));
4137 const __m512i vk16x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 512))), vk_zero_point);
4138 i16 += 16;
4139
4140 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF));
4141
4142 const __m512i vi17x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i17));
4143 const __m512i vk17x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 544))), vk_zero_point);
4144 i17 += 16;
4145
4146 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF));
4147
4148 const __m512i vi18x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i18));
4149 const __m512i vk18x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 576))), vk_zero_point);
4150 i18 += 16;
4151
4152 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF));
4153
4154 const __m512i vi19x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i19));
4155 const __m512i vk19x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 608))), vk_zero_point);
4156 i19 += 16;
4157
4158 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF));
4159
4160 const __m512i vi20x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i20));
4161 const __m512i vk20x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 640))), vk_zero_point);
4162 i20 += 16;
4163
4164 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF));
4165
4166 const __m512i vi21x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i21));
4167 const __m512i vk21x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 672))), vk_zero_point);
4168 i21 += 16;
4169
4170 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF));
4171
4172 const __m512i vi22x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i22));
4173 const __m512i vk22x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 704))), vk_zero_point);
4174 i22 += 16;
4175
4176 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF));
4177
4178 const __m512i vi23x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i23));
4179 const __m512i vk23x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 736))), vk_zero_point);
4180 i23 += 16;
4181
4182 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF));
4183
4184 const __m512i vi24x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i24));
4185 const __m512i vk24x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 768))), vk_zero_point);
4186 i24 += 16;
4187
4188 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF));
4189
4190 k += 16;
4191
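      // Requantize the 16 remaining channels: int32 -> fp32, scale, clamp, round back.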
4192 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
4193 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale);
4194 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
4195 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
4196
4197 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
4198
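      // Pack to uint8: the 256-bit packs interleave 128-bit lanes, so a dword shuffle restores sequential order.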
4199 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), _mm512_castsi512_si256(voutput_zero_point));
4200
4201 const __m128i vout012389AB = _mm256_castsi256_si128(vout012389AB4567CDEF);
4202 const __m128i vout4567CDEF = _mm256_extracti128_si256(vout012389AB4567CDEF, 1);
4203 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packus_epi16(vout012389AB, vout4567CDEF), _MM_SHUFFLE(3, 1, 2, 0));
4204 vout0123456789ABCDEF = _mm_max_epu8(vout0123456789ABCDEF, _mm256_castsi256_si128(voutput_min));
4205
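      // Store 16 full channels directly; the final partial group uses the precomputed byte mask.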
4206 if XNN_LIKELY(c >= 16) {
4207 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
4208 output += 16;
4209 c -= 16;
4210 } else {
4211 _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
4212 output = (uint8_t*) ((uintptr_t) output + c);
4213 c = 0;
4214 }
4215 } while (c != 0);
4216 }
4217
4218 output = (uint8_t*) ((uintptr_t) output + output_increment);
4219 } while (--output_width != 0);
4220 }
4221
4222 void xnn_qu8_dwconv_minmax_fp32_ukernel_up32x9__avx512skx_mul32(
4223 size_t channels,
4224 size_t output_width,
4225 const uint8_t** input,
4226 const void* weights,
4227 uint8_t* output,
4228 size_t input_stride,
4229 size_t output_increment,
4230 size_t input_offset,
4231 const uint8_t* zero,
4232 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_MSAN
4233 {
4234 assert(channels != 0);
4235 assert(output_width != 0);
4236
4237 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
4238 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
4239 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
4240 const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_min);
4241 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 5, 1, 6, 2, 4, 0);
4242
4243 const __m512i vk_zero_point = _mm512_cvtepu16_epi32(_mm256_load_si256((const __m256i*) params->fp32_avx512.kernel_zero_point));
4244 do {
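    // Gather the 9 input-row pointers for this output pixel; rows that point at the zero buffer skip the input_offset adjustment.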
4245 const uint8_t* i0 = input[0];
4246 assert(i0 != NULL);
4247 if XNN_UNPREDICTABLE(i0 != zero) {
4248 i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
4249 }
4250 const uint8_t* i1 = input[1];
4251 assert(i1 != NULL);
4252 if XNN_UNPREDICTABLE(i1 != zero) {
4253 i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
4254 }
4255 const uint8_t* i2 = input[2];
4256 assert(i2 != NULL);
4257 if XNN_UNPREDICTABLE(i2 != zero) {
4258 i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
4259 }
4260 const uint8_t* i3 = input[3];
4261 assert(i3 != NULL);
4262 if XNN_UNPREDICTABLE(i3 != zero) {
4263 i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
4264 }
4265 const uint8_t* i4 = input[4];
4266 assert(i4 != NULL);
4267 if XNN_UNPREDICTABLE(i4 != zero) {
4268 i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
4269 }
4270 const uint8_t* i5 = input[5];
4271 assert(i5 != NULL);
4272 if XNN_UNPREDICTABLE(i5 != zero) {
4273 i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
4274 }
4275 const uint8_t* i6 = input[6];
4276 assert(i6 != NULL);
4277 if XNN_UNPREDICTABLE(i6 != zero) {
4278 i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
4279 }
4280 const uint8_t* i7 = input[7];
4281 assert(i7 != NULL);
4282 if XNN_UNPREDICTABLE(i7 != zero) {
4283 i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
4284 }
4285 const uint8_t* i8 = input[8];
4286 assert(i8 != NULL);
4287 if XNN_UNPREDICTABLE(i8 != zero) {
4288 i8 = (const uint8_t*) ((uintptr_t) i8 + input_offset);
4289 }
4290 input = (const uint8_t**) ((uintptr_t) input + input_stride);
4291
4292 size_t c = channels;
4293 const void* w = weights;
4294 for (; c >= 32; c -= 32) {
4295 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
4296 __m512i vaccGHIJKLMNOPQRSTUV = _mm512_loadu_si512((const void*) ((uintptr_t) w + 16 * sizeof(int32_t)));
4297
4298
4299 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i0));
4300 const __m512i vk0x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 0 * sizeof(uint8_t)))), vk_zero_point);
4301 const __m512i vi0xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i0 + 16)));
4302 const __m512i vk0xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 16 * sizeof(uint8_t)))), vk_zero_point);
4303 i0 += 32;
4304
4305 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
4306 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi0xGHIJKLMNOPQRSTUV, vk0xGHIJKLMNOPQRSTUV));
4307
4308 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i1));
4309 const __m512i vk1x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 32 * sizeof(uint8_t)))), vk_zero_point);
4310 const __m512i vi1xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i1 + 16)));
4311 const __m512i vk1xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 48 * sizeof(uint8_t)))), vk_zero_point);
4312 i1 += 32;
4313
4314 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
4315 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi1xGHIJKLMNOPQRSTUV, vk1xGHIJKLMNOPQRSTUV));
4316
4317 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i2));
4318 const __m512i vk2x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 64 * sizeof(uint8_t)))), vk_zero_point);
4319 const __m512i vi2xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i2 + 16)));
4320 const __m512i vk2xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 80 * sizeof(uint8_t)))), vk_zero_point);
4321 i2 += 32;
4322
4323 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
4324 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi2xGHIJKLMNOPQRSTUV, vk2xGHIJKLMNOPQRSTUV));
4325
4326 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i3));
4327 const __m512i vk3x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 96 * sizeof(uint8_t)))), vk_zero_point);
4328 const __m512i vi3xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i3 + 16)));
4329 const __m512i vk3xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 112 * sizeof(uint8_t)))), vk_zero_point);
4330 i3 += 32;
4331
4332 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
4333 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi3xGHIJKLMNOPQRSTUV, vk3xGHIJKLMNOPQRSTUV));
4334
4335 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i4));
4336 const __m512i vk4x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 128 * sizeof(uint8_t)))), vk_zero_point);
4337 const __m512i vi4xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i4 + 16)));
4338 const __m512i vk4xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 144 * sizeof(uint8_t)))), vk_zero_point);
4339 i4 += 32;
4340
4341 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
4342 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi4xGHIJKLMNOPQRSTUV, vk4xGHIJKLMNOPQRSTUV));
4343
4344 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i5));
4345 const __m512i vk5x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 160 * sizeof(uint8_t)))), vk_zero_point);
4346 const __m512i vi5xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i5 + 16)));
4347 const __m512i vk5xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 176 * sizeof(uint8_t)))), vk_zero_point);
4348 i5 += 32;
4349
4350 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
4351 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi5xGHIJKLMNOPQRSTUV, vk5xGHIJKLMNOPQRSTUV));
4352
4353 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i6));
4354 const __m512i vk6x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 192 * sizeof(uint8_t)))), vk_zero_point);
4355 const __m512i vi6xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i6 + 16)));
4356 const __m512i vk6xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 208 * sizeof(uint8_t)))), vk_zero_point);
4357 i6 += 32;
4358
4359 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
4360 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi6xGHIJKLMNOPQRSTUV, vk6xGHIJKLMNOPQRSTUV));
4361
4362 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i7));
4363 const __m512i vk7x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 224 * sizeof(uint8_t)))), vk_zero_point);
4364 const __m512i vi7xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i7 + 16)));
4365 const __m512i vk7xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 240 * sizeof(uint8_t)))), vk_zero_point);
4366 i7 += 32;
4367
4368 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
4369 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi7xGHIJKLMNOPQRSTUV, vk7xGHIJKLMNOPQRSTUV));
4370
4371 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i8));
4372 const __m512i vk8x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 256 * sizeof(uint8_t)))), vk_zero_point);
4373 const __m512i vi8xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i8 + 16)));
4374 const __m512i vk8xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 272 * sizeof(uint8_t)))), vk_zero_point);
4375 i8 += 32;
4376
4377 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
4378 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi8xGHIJKLMNOPQRSTUV, vk8xGHIJKLMNOPQRSTUV));
4379
4380 w = (const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 288 * sizeof(uint8_t));
4381
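      // Requantize: int32 -> fp32, scale, clamp above at (output_max - output_zero_point), round back to int32.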
4382 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
4383 __m512 vscaledGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vaccGHIJKLMNOPQRSTUV);
4384
4385 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale);
4386 vscaledGHIJKLMNOPQRSTUV = _mm512_mul_ps(vscaledGHIJKLMNOPQRSTUV, vscale);
4387
4388 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
4389 vscaledGHIJKLMNOPQRSTUV = _mm512_min_ps(vscaledGHIJKLMNOPQRSTUV, voutput_max_less_zero_point);
4390
4391 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
4392 vaccGHIJKLMNOPQRSTUV = _mm512_cvtps_epi32(vscaledGHIJKLMNOPQRSTUV);
4393
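      // Add the output zero point and pack to uint8 with saturation; the lane-interleaved packing is undone by the final permute/shuffle.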
4394 __m512i vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV = _mm512_adds_epi16(_mm512_packs_epi32(vacc0123456789ABCDEF, vaccGHIJKLMNOPQRSTUV), voutput_zero_point);
4395 __m256i voutGHIJOPQRKLMNSTUV = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vaccGHIJKLMNOPQRSTUV), _mm512_extracti32x8_epi32(vaccGHIJKLMNOPQRSTUV, 1)), _mm512_castsi512_si256(voutput_zero_point));
4396
4397 const __m256i vout0123GHIJ4567KLMN = _mm512_castsi512_si256(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV);
4398 const __m256i vout89ABOPQRCDEFSTUV = _mm512_extracti32x8_epi32(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV, 1);
4399 const __m256i vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV = _mm256_packus_epi16(vout0123GHIJ4567KLMN, vout89ABOPQRCDEFSTUV);
4400 __m256i vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_permutevar8x32_epi32(vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV, vpermute_mask);
4401 const __m128i voutGHIJOPQR = _mm256_castsi256_si128(voutGHIJOPQRKLMNSTUV);
4402 const __m128i voutKLMNSTUV = _mm256_extracti128_si256(voutGHIJOPQRKLMNSTUV, 1);
4403 __m128i voutGHIJKLMNOPQRSTUV = _mm_shuffle_epi32(_mm_packus_epi16(voutGHIJOPQR, voutKLMNSTUV), _MM_SHUFFLE(3, 1, 2, 0));
4404
4405 vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_max_epu8(vout0123456789ABCDEFGHIJKLMNOPQRSTUV, voutput_min);
4406 voutGHIJKLMNOPQRSTUV = _mm_max_epu8(voutGHIJKLMNOPQRSTUV, _mm256_castsi256_si128(voutput_min));
4407
4408 _mm256_storeu_si256((__m256i*) output, vout0123456789ABCDEFGHIJKLMNOPQRSTUV);
4409 _mm_storeu_si128((__m128i*) (output + 16), voutGHIJKLMNOPQRSTUV);
4410 output += 32;
4411 }
4412 if XNN_UNLIKELY(c != 0) {
4413    // Prepare mask for valid 8-bit elements (depends on c).
4414 const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << (c & 15)) - UINT32_C(1)));
4415 const uint8_t* k = (const uint8_t*) ((uintptr_t) w + 32 * sizeof(int32_t));
4416 do {
4417 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
4418
4419
4420 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i0));
4421 const __m512i vk0x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) k)), vk_zero_point);
4422 i0 += 16;
4423
4424 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
4425
4426 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i1));
4427 const __m512i vk1x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 32))), vk_zero_point);
4428 i1 += 16;
4429
4430 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
4431
4432 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i2));
4433 const __m512i vk2x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 64))), vk_zero_point);
4434 i2 += 16;
4435
4436 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
4437
4438 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i3));
4439 const __m512i vk3x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 96))), vk_zero_point);
4440 i3 += 16;
4441
4442 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
4443
4444 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i4));
4445 const __m512i vk4x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 128))), vk_zero_point);
4446 i4 += 16;
4447
4448 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
4449
4450 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i5));
4451 const __m512i vk5x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 160))), vk_zero_point);
4452 i5 += 16;
4453
4454 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
4455
4456 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i6));
4457 const __m512i vk6x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 192))), vk_zero_point);
4458 i6 += 16;
4459
4460 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
4461
4462 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i7));
4463 const __m512i vk7x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 224))), vk_zero_point);
4464 i7 += 16;
4465
4466 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
4467
4468 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i8));
4469 const __m512i vk8x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 256))), vk_zero_point);
4470 i8 += 16;
4471
4472 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
4473
4474 k += 16;
4475
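      // Requantize the 16 remaining channels.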
4476 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
4477 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale);
4478 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
4479 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
4480
4481 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
4482
4483 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), _mm512_castsi512_si256(voutput_zero_point));
4484
4485 const __m128i vout012389AB = _mm256_castsi256_si128(vout012389AB4567CDEF);
4486 const __m128i vout4567CDEF = _mm256_extracti128_si256(vout012389AB4567CDEF, 1);
4487 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packus_epi16(vout012389AB, vout4567CDEF), _MM_SHUFFLE(3, 1, 2, 0));
4488 vout0123456789ABCDEF = _mm_max_epu8(vout0123456789ABCDEF, _mm256_castsi256_si128(voutput_min));
4489
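      // Store 16 full channels, or use the precomputed byte mask for the final partial group.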
4490 if XNN_LIKELY(c >= 16) {
4491 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
4492 output += 16;
4493 c -= 16;
4494 } else {
4495 _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
4496 output = (uint8_t*) ((uintptr_t) output + c);
4497 c = 0;
4498 }
4499 } while (c != 0);
4500 }
4501
4502 output = (uint8_t*) ((uintptr_t) output + output_increment);
4503 } while (--output_width != 0);
4504 }
4505
4506 void xnn_qu8_f32_vcvt_ukernel__avx512skx_x32(
4507 size_t n,
4508 const uint8_t* x,
4509 float* y,
4510 const union xnn_qu8_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
4511 {
4512 assert(n != 0);
4513 assert(n % sizeof(uint8_t) == 0);
4514 assert(x != NULL);
4515 assert(y != NULL);
4516
4517 const __m512i vminus_zero_point = _mm512_load_si512(params->avx512.minus_zero_point);
4518 const __m512 vscale = _mm512_load_ps(params->avx512.scale);
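  // Dequantize y = (x - zero_point) * scale; minus_zero_point holds the negated zero point, so it is added after widening.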
4519 for (; n >= 32 * sizeof(uint8_t); n -= 32 * sizeof(uint8_t)) {
4520 __m512i vx0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) x));
4521 __m512i vxGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (x + 16)));
4522 x += 32;
4523
4524 vx0123456789ABCDEF = _mm512_add_epi32(vx0123456789ABCDEF, vminus_zero_point);
4525 vxGHIJKLMNOPQRSTUV = _mm512_add_epi32(vxGHIJKLMNOPQRSTUV, vminus_zero_point);
4526
4527 __m512 vy0123456789ABCDEF = _mm512_cvtepi32_ps(vx0123456789ABCDEF);
4528 __m512 vyGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vxGHIJKLMNOPQRSTUV);
4529
4530 vy0123456789ABCDEF = _mm512_mul_ps(vy0123456789ABCDEF, vscale);
4531 vyGHIJKLMNOPQRSTUV = _mm512_mul_ps(vyGHIJKLMNOPQRSTUV, vscale);
4532
4533 _mm512_storeu_ps(y, vy0123456789ABCDEF);
4534 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
4535 y += 32;
4536 }
4537 for (; n >= 16 * sizeof(uint8_t); n -= 16 * sizeof(uint8_t)) {
4538 __m512i vx = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) x));
4539 vx = _mm512_add_epi32(vx, vminus_zero_point);
4540 x += 16;
4541
4542 __m512 vy = _mm512_cvtepi32_ps(vx);
4543 vy = _mm512_mul_ps(vy, vscale);
4544
4545 _mm512_storeu_ps(y, vy);
4546 y += 16;
4547 }
4548 if XNN_UNLIKELY(n != 0) {
4549 assert(n >= 1 * sizeof(uint8_t));
4550 assert(n <= 15 * sizeof(uint8_t));
4551
4552 // Prepare mask for valid elements (depends on n).
4553 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
4554
4555 __m512i vx = _mm512_cvtepu8_epi32(_mm_maskz_loadu_epi8(vmask, x));
4556 vx = _mm512_add_epi32(vx, vminus_zero_point);
4557
4558 __m512 vy = _mm512_cvtepi32_ps(vx);
4559 vy = _mm512_mul_ps(vy, vscale);
4560
4561 _mm512_mask_storeu_ps(y, vmask, vy);
4562 }
4563 }
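// For reference, a scalar sketch of the transform the kernel above vectorizes
// (illustrative only; zero_point and scale stand for the values baked into params):
//   for (size_t i = 0; i < n; i++) {
//     y[i] = (float) ((int32_t) x[i] - zero_point) * scale;
//   }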
4564
4565 void xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx(
4566 size_t mr,
4567 size_t nc,
4568 size_t kc,
4569 const uint8_t* restrict a,
4570 size_t a_stride,
4571 const void* restrict w,
4572 uint8_t* restrict c,
4573 size_t cm_stride,
4574 size_t cn_stride,
4575 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
4576 {
4577 assert(mr != 0);
4578 assert(mr <= 1);
4579 assert(nc != 0);
4580 assert(kc != 0);
4581 assert(kc % sizeof(uint8_t) == 0);
4582 assert(a != NULL);
4583 assert(w != NULL);
4584 assert(c != NULL);
4585
4586 kc = round_up_po2(kc, 8);
4587 const uint8_t* a0 = a;
4588 uint8_t* c0 = c;
4589
4590 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
4591 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
4592 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
4593 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point);
4594 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min);
4595 do {
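    // maskz_expandloadu with mask 0x1111 places one int32 bias in dword 0 of each 128-bit lane; each lane accumulates one output channel group.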
4596 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
4597 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
4598 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
4599 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
4600 w = (const void*) ((const int32_t*) w + 16);
4601
4602 size_t k = 0;
4603 const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point);
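    // c8 dot product: broadcast 8 zero-extended activations to every 128-bit lane, widen 8 weights per output channel to int16
    // and subtract the kernel zero point, then madd_epi16 multiplies and pairwise-adds into int32 partial sums.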
4604 while (k < kc) {
4605 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
4606 a0 += 8;
4607
4608 const __m512i vb0123 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);
4609
4610 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
4611 const __m512i vb4567 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 32))), vb_zero_point);
4612
4613 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
4614 const __m512i vb89AB = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 64))), vb_zero_point);
4615
4616 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
4617 const __m512i vbCDEF = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 96))), vb_zero_point);
4618
4619 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
4620
4621 w = (const void*) ((const uint8_t*) w + 128);
4622 k += 8 * sizeof(uint8_t);
4623 }
4624
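    // Reduce the 4 partial sums per 128-bit lane: two rounds of unpack+add leave the 16 channel sums in the interleaved dword order named 084C195D2A6E3B7F.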
4625 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
4626 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
4627
4628 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
4629
4630 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
4631
4632 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
4633
4634 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
4635
4636 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
4637
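    // Add the output zero point, pack with saturation to uint8, and byte-shuffle the interleaved result back to channel order 0..F.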
4638 const __m256i vacc0x084C2A6E195D3B7F = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0x084C195D2A6E3B7F), _mm512_extracti32x8_epi32(vacc0x084C195D2A6E3B7F, 1)), voutput_zero_point);
4639
4640 const __m128i vout0x084C2A6E195D3B7F = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
4641 __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x084C2A6E195D3B7F, _mm_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
4642 vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min);
4643
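    // Full 16-wide stores take the fast path; the final columns use a masked byte store.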
4644 if (nc >= 16) {
4645 _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF);
4646
4647 a0 = (const uint8_t*) ((uintptr_t) a0 - k);
4648
4649 c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
4650
4651 nc -= 16;
4652 } else {
4653 // Prepare mask for valid 8-bit elements (depends on nc).
4654 const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
4655
4656 _mm_mask_storeu_epi8(c0, vmask, vout0x0123456789ABCDEF);
4657
4658 nc = 0;
4659 }
4660 } while (nc != 0);
4661 }
4662
4663 void xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx(
4664 size_t mr,
4665 size_t nc,
4666 size_t kc,
4667 const uint8_t* restrict a,
4668 size_t a_stride,
4669 const void* restrict w,
4670 uint8_t* restrict c,
4671 size_t cm_stride,
4672 size_t cn_stride,
4673 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
4674 {
4675 assert(mr != 0);
4676 assert(mr <= 4);
4677 assert(nc != 0);
4678 assert(kc != 0);
4679 assert(kc % sizeof(uint8_t) == 0);
4680 assert(a != NULL);
4681 assert(w != NULL);
4682 assert(c != NULL);
4683
4684 kc = round_up_po2(kc, 8);
4685 const uint8_t* a0 = a;
4686 uint8_t* c0 = c;
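  // Derive per-row pointers from mr; when mr < 4 the extra rows alias the previous row so loads and stores stay in bounds.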
4687 const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride);
4688 uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
4689 if XNN_UNPREDICTABLE(mr < 2) {
4690 a1 = a0;
4691 c1 = c0;
4692 }
4693 const uint8_t* a2 = (const uint8_t*) ((uintptr_t) a1 + a_stride);
4694 uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
4695 if XNN_UNPREDICTABLE(mr <= 2) {
4696 a2 = a1;
4697 c2 = c1;
4698 }
4699 const uint8_t* a3 = (const uint8_t*) ((uintptr_t) a2 + a_stride);
4700 uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + cm_stride);
4701 if XNN_UNPREDICTABLE(mr != 4) {
4702 a3 = a2;
4703 c3 = c2;
4704 }
4705
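// 0x1111 sets 32-bit lanes 0, 4, 8 and 12, so each expand-load below places one of four
// consecutive bias values at the start of the 128-bit lane that accumulates its channel.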
4706 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
4707 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
4708 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
4709 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
4710 const __m512i voutput_min = _mm512_load_si512(params->fp32_avx512.output_min);
4711 do {
4712 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
4713 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
4714 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
4715 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
4716 __m512i vacc1x0123 = vacc0x0123;
4717 __m512i vacc1x4567 = vacc0x4567;
4718 __m512i vacc1x89AB = vacc0x89AB;
4719 __m512i vacc1xCDEF = vacc0xCDEF;
4720 __m512i vacc2x0123 = vacc0x0123;
4721 __m512i vacc2x4567 = vacc0x4567;
4722 __m512i vacc2x89AB = vacc0x89AB;
4723 __m512i vacc2xCDEF = vacc0xCDEF;
4724 __m512i vacc3x0123 = vacc0x0123;
4725 __m512i vacc3x4567 = vacc0x4567;
4726 __m512i vacc3x89AB = vacc0x89AB;
4727 __m512i vacc3xCDEF = vacc0xCDEF;
4728 w = (const void*) ((const int32_t*) w + 16);
4729
4730 size_t k = 0;
4731 const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point);
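    // Each step consumes 8 bytes of K per row: the 8 input bytes are widened to int16 and
    // broadcast to all four 128-bit lanes, the uint8 weights are widened and re-biased by the
    // kernel zero point, and _mm512_madd_epi16 accumulates pairwise int16 products into int32.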
4732 while (k < kc) {
4733 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
4734 a0 += 8;
4735 const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
4736 a1 += 8;
4737 const __m512i va2 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a2)));
4738 a2 += 8;
4739 const __m512i va3 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a3)));
4740 a3 += 8;
4741
4742 const __m512i vb0123 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);
4743
4744 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
4745 vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
4746 vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
4747 vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
4748 const __m512i vb4567 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 32))), vb_zero_point);
4749
4750 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
4751 vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
4752 vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
4753 vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
4754 const __m512i vb89AB = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 64))), vb_zero_point);
4755
4756 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
4757 vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
4758 vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
4759 vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
4760 const __m512i vbCDEF = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 96))), vb_zero_point);
4761
4762 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
4763 vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
4764 vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
4765 vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
4766
4767 w = (const void*) ((const uint8_t*) w + 128);
4768 k += 8 * sizeof(uint8_t);
4769 }
4770
4771 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
4772 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
4773 const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
4774 const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
4775 const __m512i vacc2x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x0123, vacc2x4567), _mm512_unpackhi_epi32(vacc2x0123, vacc2x4567));
4776 const __m512i vacc2x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x89AB, vacc2xCDEF), _mm512_unpackhi_epi32(vacc2x89AB, vacc2xCDEF));
4777 const __m512i vacc3x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x0123, vacc3x4567), _mm512_unpackhi_epi32(vacc3x0123, vacc3x4567));
4778 const __m512i vacc3x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x89AB, vacc3xCDEF), _mm512_unpackhi_epi32(vacc3x89AB, vacc3xCDEF));
4779
4780 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
4781 __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
4782 __m512i vacc2x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x04152637, vacc2x8C9DAEBF), _mm512_unpackhi_epi32(vacc2x04152637, vacc2x8C9DAEBF));
4783 __m512i vacc3x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x04152637, vacc3x8C9DAEBF), _mm512_unpackhi_epi32(vacc3x04152637, vacc3x8C9DAEBF));
4784
4785 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
4786 __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
4787 __m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
4788 __m512 vscaled3x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc3x084C195D2A6E3B7F);
4789
4790 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
4791 vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale);
4792 vscaled2x084C195D2A6E3B7F = _mm512_mul_ps(vscaled2x084C195D2A6E3B7F, vscale);
4793 vscaled3x084C195D2A6E3B7F = _mm512_mul_ps(vscaled3x084C195D2A6E3B7F, vscale);
4794
4795 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
4796 vscaled1x084C195D2A6E3B7F = _mm512_min_ps(vscaled1x084C195D2A6E3B7F, voutput_max_less_zero_point);
4797 vscaled2x084C195D2A6E3B7F = _mm512_min_ps(vscaled2x084C195D2A6E3B7F, voutput_max_less_zero_point);
4798 vscaled3x084C195D2A6E3B7F = _mm512_min_ps(vscaled3x084C195D2A6E3B7F, voutput_max_less_zero_point);
4799
4800 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
4801 vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
4802 vacc2x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled2x084C195D2A6E3B7F);
4803 vacc3x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled3x084C195D2A6E3B7F);
4804
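// Pack pairs of rows: _mm512_packs_epi32 interleaves within 128-bit lanes, so rows 0/1 and
// 2/3 come out interleaved; the permutexvar and shuffle_epi8 below undo both levels of
// interleaving to recover linear 0..F column order within each row.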
4805 const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
4806 const __m512i vacc23x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc2x084C195D2A6E3B7F, vacc3x084C195D2A6E3B7F), voutput_zero_point);
4807
4808 __m512i vout0123x084Cx195Dx2A6Ex3B7F = _mm512_packus_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc23x084Cx195Dx2A6Ex3B7F);
4809 vout0123x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout0123x084Cx195Dx2A6Ex3B7F);
4810 __m512i vout0123x0123456789ABCDEF = _mm512_shuffle_epi8(vout0123x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
4811 vout0123x0123456789ABCDEF = _mm512_max_epu8(vout0123x0123456789ABCDEF, voutput_min);
4812
4813 if (nc >= 16) {
4814 _mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout0123x0123456789ABCDEF));
4815 _mm_storeu_si128((__m128i*) c1, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 1));
4816 _mm_storeu_si128((__m128i*) c2, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 2));
4817 _mm_storeu_si128((__m128i*) c3, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 3));
4818
4819 a0 = (const uint8_t*) ((uintptr_t) a0 - k);
4820 a1 = (const uint8_t*) ((uintptr_t) a1 - k);
4821 a2 = (const uint8_t*) ((uintptr_t) a2 - k);
4822 a3 = (const uint8_t*) ((uintptr_t) a3 - k);
4823
4824 c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
4825 c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
4826 c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride);
4827 c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride);
4828
4829 nc -= 16;
4830 } else {
4831 // Prepare mask for valid 8-bit elements (depends on nc).
4832 __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
4833
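// Store row 0 through the low 16 mask bits, then shift the mask left by 16 for each
// following row; the cN - 16*N pointer adjustment compensates so each store lands at cN.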
4834 _mm512_mask_storeu_epi8(c0, vmask, vout0123x0123456789ABCDEF);
4835 vmask = _kshiftli_mask64(vmask, 16);
4836 _mm512_mask_storeu_epi8(c1 - 16, vmask, vout0123x0123456789ABCDEF);
4837 vmask = _kshiftli_mask64(vmask, 16);
4838 _mm512_mask_storeu_epi8(c2 - 32, vmask, vout0123x0123456789ABCDEF);
4839 vmask = _kshiftli_mask64(vmask, 16);
4840 _mm512_mask_storeu_epi8(c3 - 48, vmask, vout0123x0123456789ABCDEF);
4841
4842 nc = 0;
4843 }
4844 } while (nc != 0);
4845 }
4846
4847 void xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx(
4848 size_t mr,
4849 size_t nc,
4850 size_t kc,
4851 size_t ks,
4852 const uint8_t** restrict a,
4853 const void* restrict w,
4854 uint8_t* restrict c,
4855 size_t cm_stride,
4856 size_t cn_stride,
4857 size_t a_offset,
4858 const uint8_t* zero,
4859 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
4860 {
4861 assert(mr != 0);
4862 assert(mr <= 1);
4863 assert(nc != 0);
4864 assert(kc != 0);
4865 assert(kc % sizeof(uint8_t) == 0);
4866 assert(a != NULL);
4867 assert(w != NULL);
4868 assert(c != NULL);
4869
4870 kc = round_up_po2(kc, 8);
4871 uint8_t* c0 = c;
4872
4873 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
4874 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
4875 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
4876 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point);
4877 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min);
4878 do {
4879 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
4880 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
4881 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
4882 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
4883 w = (const void*) ((const int32_t*) w + 16);
4884
4885 size_t p = ks;
4886 do {
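      // ks counts bytes of the indirection buffer (one const uint8_t* entry per iteration);
      // entries equal to the caller's `zero` buffer must not be adjusted by a_offset.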
4887 const uint8_t* restrict a0 = a[0];
4888 if XNN_UNPREDICTABLE(a0 != zero) {
4889 a0 = (const uint8_t*) ((uintptr_t) a0 + a_offset);
4890 }
4891 a += 1;
4892
4893 size_t k = 0;
4894 const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point);
4895 while (k < kc) {
4896 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
4897 a0 += 8;
4898
4899 const __m512i vb0123 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);
4900
4901 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
4902 const __m512i vb4567 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 32))), vb_zero_point);
4903
4904 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
4905 const __m512i vb89AB = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 64))), vb_zero_point);
4906
4907 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
4908 const __m512i vbCDEF = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 96))), vb_zero_point);
4909
4910 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
4911
4912 w = (const void*) ((const uint8_t*) w + 128);
4913 k += 8 * sizeof(uint8_t);
4914 }
4915 p -= 1 * sizeof(void*);
4916 } while (p != 0);
4917
4918 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
4919 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
4920
4921 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
4922
4923 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
4924
4925 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
4926
4927 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
4928
4929 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
4930
4931 const __m256i vacc0x084C2A6E195D3B7F = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0x084C195D2A6E3B7F), _mm512_extracti32x8_epi32(vacc0x084C195D2A6E3B7F, 1)), voutput_zero_point);
4932
4933 const __m128i vout0x084C2A6E195D3B7F = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
4934 __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x084C2A6E195D3B7F, _mm_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
4935 vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min);
4936
4937 if (nc >= 16) {
4938 _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF);
4939
4940 c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
4941
4942 a = (const uint8_t**restrict) ((uintptr_t) a - ks);
4943
4944 nc -= 16;
4945 } else {
4946 // Prepare mask for valid 8-bit elements (depends on nc).
4947 const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
4948
4949 _mm_mask_storeu_epi8(c0, vmask, vout0x0123456789ABCDEF);
4950
4951 nc = 0;
4952 }
4953 } while (nc != 0);
4954 }
4955
4956 void xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx(
4957 size_t mr,
4958 size_t nc,
4959 size_t kc,
4960 size_t ks,
4961 const uint8_t** restrict a,
4962 const void* restrict w,
4963 uint8_t* restrict c,
4964 size_t cm_stride,
4965 size_t cn_stride,
4966 size_t a_offset,
4967 const uint8_t* zero,
4968 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
4969 {
4970 assert(mr != 0);
4971 assert(mr <= 4);
4972 assert(nc != 0);
4973 assert(kc != 0);
4974 assert(kc % sizeof(uint8_t) == 0);
4975 assert(a != NULL);
4976 assert(w != NULL);
4977 assert(c != NULL);
4978
4979 kc = round_up_po2(kc, 8);
4980 uint8_t* c0 = c;
4981 uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
4982 if XNN_UNPREDICTABLE(mr < 2) {
4983 c1 = c0;
4984 }
4985 uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
4986 if XNN_UNPREDICTABLE(mr <= 2) {
4987 c2 = c1;
4988 }
4989 uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + cm_stride);
4990 if XNN_UNPREDICTABLE(mr != 4) {
4991 c3 = c2;
4992 }
4993
4994 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
4995 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
4996 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
4997 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
4998 const __m512i voutput_min = _mm512_load_si512(params->fp32_avx512.output_min);
4999 do {
5000 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
5001 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
5002 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
5003 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
5004 __m512i vacc1x0123 = vacc0x0123;
5005 __m512i vacc1x4567 = vacc0x4567;
5006 __m512i vacc1x89AB = vacc0x89AB;
5007 __m512i vacc1xCDEF = vacc0xCDEF;
5008 __m512i vacc2x0123 = vacc0x0123;
5009 __m512i vacc2x4567 = vacc0x4567;
5010 __m512i vacc2x89AB = vacc0x89AB;
5011 __m512i vacc2xCDEF = vacc0xCDEF;
5012 __m512i vacc3x0123 = vacc0x0123;
5013 __m512i vacc3x4567 = vacc0x4567;
5014 __m512i vacc3x89AB = vacc0x89AB;
5015 __m512i vacc3xCDEF = vacc0xCDEF;
5016 w = (const void*) ((const int32_t*) w + 16);
5017
5018 size_t p = ks;
5019 do {
5020 const uint8_t* restrict a0 = a[0];
5021 if XNN_UNPREDICTABLE(a0 != zero) {
5022 a0 = (const uint8_t*) ((uintptr_t) a0 + a_offset);
5023 }
5024 const uint8_t* restrict a1 = a[1];
5025 if XNN_UNPREDICTABLE(a1 != zero) {
5026 a1 = (const uint8_t*) ((uintptr_t) a1 + a_offset);
5027 }
5028 const uint8_t* restrict a2 = a[2];
5029 if XNN_UNPREDICTABLE(a2 != zero) {
5030 a2 = (const uint8_t*) ((uintptr_t) a2 + a_offset);
5031 }
5032 const uint8_t* restrict a3 = a[3];
5033 if XNN_UNPREDICTABLE(a3 != zero) {
5034 a3 = (const uint8_t*) ((uintptr_t) a3 + a_offset);
5035 }
5036 a += 4;
5037
5038 size_t k = 0;
5039 const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point);
5040 while (k < kc) {
5041 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
5042 a0 += 8;
5043 const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
5044 a1 += 8;
5045 const __m512i va2 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a2)));
5046 a2 += 8;
5047 const __m512i va3 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a3)));
5048 a3 += 8;
5049
5050 const __m512i vb0123 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);
5051
5052 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
5053 vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
5054 vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
5055 vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
5056 const __m512i vb4567 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 32))), vb_zero_point);
5057
5058 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
5059 vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
5060 vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
5061 vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
5062 const __m512i vb89AB = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 64))), vb_zero_point);
5063
5064 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
5065 vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
5066 vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
5067 vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
5068 const __m512i vbCDEF = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 96))), vb_zero_point);
5069
5070 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
5071 vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
5072 vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
5073 vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
5074
5075 w = (const void*) ((const uint8_t*) w + 128);
5076 k += 8 * sizeof(uint8_t);
5077 }
5078 p -= 4 * sizeof(void*);
5079 } while (p != 0);
5080
5081 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
5082 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
5083 const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
5084 const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
5085 const __m512i vacc2x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x0123, vacc2x4567), _mm512_unpackhi_epi32(vacc2x0123, vacc2x4567));
5086 const __m512i vacc2x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x89AB, vacc2xCDEF), _mm512_unpackhi_epi32(vacc2x89AB, vacc2xCDEF));
5087 const __m512i vacc3x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x0123, vacc3x4567), _mm512_unpackhi_epi32(vacc3x0123, vacc3x4567));
5088 const __m512i vacc3x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x89AB, vacc3xCDEF), _mm512_unpackhi_epi32(vacc3x89AB, vacc3xCDEF));
5089
5090 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
5091 __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
5092 __m512i vacc2x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x04152637, vacc2x8C9DAEBF), _mm512_unpackhi_epi32(vacc2x04152637, vacc2x8C9DAEBF));
5093 __m512i vacc3x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x04152637, vacc3x8C9DAEBF), _mm512_unpackhi_epi32(vacc3x04152637, vacc3x8C9DAEBF));
5094
5095 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
5096 __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
5097 __m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
5098 __m512 vscaled3x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc3x084C195D2A6E3B7F);
5099
5100 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
5101 vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale);
5102 vscaled2x084C195D2A6E3B7F = _mm512_mul_ps(vscaled2x084C195D2A6E3B7F, vscale);
5103 vscaled3x084C195D2A6E3B7F = _mm512_mul_ps(vscaled3x084C195D2A6E3B7F, vscale);
5104
5105 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
5106 vscaled1x084C195D2A6E3B7F = _mm512_min_ps(vscaled1x084C195D2A6E3B7F, voutput_max_less_zero_point);
5107 vscaled2x084C195D2A6E3B7F = _mm512_min_ps(vscaled2x084C195D2A6E3B7F, voutput_max_less_zero_point);
5108 vscaled3x084C195D2A6E3B7F = _mm512_min_ps(vscaled3x084C195D2A6E3B7F, voutput_max_less_zero_point);
5109
5110 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
5111 vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
5112 vacc2x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled2x084C195D2A6E3B7F);
5113 vacc3x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled3x084C195D2A6E3B7F);
5114
5115 const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
5116 const __m512i vacc23x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc2x084C195D2A6E3B7F, vacc3x084C195D2A6E3B7F), voutput_zero_point);
5117
5118 __m512i vout0123x084Cx195Dx2A6Ex3B7F = _mm512_packus_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc23x084Cx195Dx2A6Ex3B7F);
5119 vout0123x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout0123x084Cx195Dx2A6Ex3B7F);
5120 __m512i vout0123x0123456789ABCDEF = _mm512_shuffle_epi8(vout0123x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
5121 vout0123x0123456789ABCDEF = _mm512_max_epu8(vout0123x0123456789ABCDEF, voutput_min);
5122
5123 if (nc >= 16) {
5124 _mm_storeu_si128((__m128i*) c3, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 3));
5125 _mm_storeu_si128((__m128i*) c2, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 2));
5126 _mm_storeu_si128((__m128i*) c1, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 1));
5127 _mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout0123x0123456789ABCDEF));
5128
5129 c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride);
5130 c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride);
5131 c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
5132 c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
5133
5134 a = (const uint8_t**restrict) ((uintptr_t) a - ks);
5135
5136 nc -= 16;
5137 } else {
5138 // Prepare mask for valid 8-bit elements (depends on nc).
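      // The valid bits are built in the top 16 mask bits (rows are stored c3 down to c0) and
      // shifted right by 16 per row; this mirrors the shift-left scheme in the GEMM kernel.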
5139 __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT64_C(1) << (nc + 48)) - (UINT64_C(1) << 48)));
5140
5141 _mm512_mask_storeu_epi8(c3 - 48, vmask, vout0123x0123456789ABCDEF);
5142 vmask = _kshiftri_mask64(vmask, 16);
5143 _mm512_mask_storeu_epi8(c2 - 32, vmask, vout0123x0123456789ABCDEF);
5144 vmask = _kshiftri_mask64(vmask, 16);
5145 _mm512_mask_storeu_epi8(c1 - 16, vmask, vout0123x0123456789ABCDEF);
5146 vmask = _kshiftri_mask64(vmask, 16);
5147 _mm512_mask_storeu_epi8(c0, vmask, vout0123x0123456789ABCDEF);
5148
5149 nc = 0;
5150 }
5151 } while (nc != 0);
5152 }
5153
5154 void xnn_qu8_vadd_minmax_ukernel__avx512skx_mul32_ld128_x16(
5155 size_t n,
5156 const uint8_t* input_a,
5157 const uint8_t* input_b,
5158 uint8_t* output,
5159 const union xnn_qu8_add_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
5160 {
5161 const __m512i vbias = _mm512_load_si512(params->avx512.bias);
5162 const __m512i va_multiplier = _mm512_load_si512(params->avx512.a_multiplier);
5163 const __m512i vb_multiplier = _mm512_load_si512(params->avx512.b_multiplier);
5164 const __m128i vshift = _mm_load_si128((const __m128i*) params->avx512.shift);
5165 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx512.output_zero_point);
5166 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->avx512.output_min);
5167 const __m128i voutput_max = _mm_load_si128((const __m128i*) params->avx512.output_max);
5168
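// Fixed-point addition: widen both inputs to int32, scale by the per-input multipliers, add
// the precomputed bias, arithmetic-shift right, then saturating-pack with the output zero
// point and clamp to [output_min, output_max].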
5169 for (; n >= 16 * sizeof(uint8_t); n -= 16 * sizeof(uint8_t)) {
5170 const __m512i va0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) input_a));
5171 const __m512i vb0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) input_b));
5172 input_a += 16;
5173 input_b += 16;
5174
5175 __m512i vacc0123456789ABCDEF = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va0123456789ABCDEF, va_multiplier));
5176
5177 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vb0123456789ABCDEF, vb_multiplier));
5178
5179 vacc0123456789ABCDEF = _mm512_sra_epi32(vacc0123456789ABCDEF, vshift);
5180
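    // packs_epi32 interleaves the two 128-bit lanes of each 256-bit half, giving the
    // 012389AB4567CDEF order; the _MM_SHUFFLE(3, 1, 2, 0) dword shuffle below restores it.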
5181 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), voutput_zero_point);
5182
5183 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packus_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
5184
5185 vout0123456789ABCDEF = _mm_max_epu8(vout0123456789ABCDEF, voutput_min);
5186
5187 vout0123456789ABCDEF = _mm_min_epu8(vout0123456789ABCDEF, voutput_max);
5188
5189 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
5190 output += 16;
5191 }
5192 if XNN_UNLIKELY(n != 0) {
5193 {
5194 const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << n) - UINT32_C(1)));
5195 const __m512i va0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_maskz_loadu_epi8(vmask, input_a));
5196 const __m512i vb0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_maskz_loadu_epi8(vmask, input_b));
5197
5198 __m512i vacc0123456789ABCDEF = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va0123456789ABCDEF, va_multiplier));
5199
5200 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vb0123456789ABCDEF, vb_multiplier));
5201
5202 vacc0123456789ABCDEF = _mm512_sra_epi32(vacc0123456789ABCDEF, vshift);
5203
5204 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), voutput_zero_point);
5205 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packus_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
5206 vout0123456789ABCDEF = _mm_max_epu8(vout0123456789ABCDEF, voutput_min);
5207 vout0123456789ABCDEF = _mm_min_epu8(vout0123456789ABCDEF, voutput_max);
5208
5209 _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
5210 }
5211 }
5212 }
5213
5214 void xnn_qu8_vaddc_minmax_ukernel__avx512skx_mul32_ld128_x16(
5215 size_t n,
5216 const uint8_t* input_a,
5217 const uint8_t* input_b,
5218 uint8_t* output,
5219 const union xnn_qu8_add_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
5220 {
5221 const __m512i va_multiplier = _mm512_load_si512(params->avx512.a_multiplier);
5222 const __m128i vshift = _mm_load_si128((const __m128i*) params->avx512.shift);
5223 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx512.output_zero_point);
5224 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->avx512.output_min);
5225 const __m128i voutput_max = _mm_load_si128((const __m128i*) params->avx512.output_max);
5226
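// The broadcast addend *input_b is constant across the batch, so b * b_multiplier is folded
// into the bias once up front and the loop only reads input_a.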
5227 const __m512i vbias = _mm512_add_epi32(
5228 _mm512_broadcastd_epi32(_mm_cvtsi32_si128(params->avx512.b_multiplier[0] * (int32_t) *input_b)),
5229 _mm512_load_si512(params->avx512.bias));
5230 for (; n >= 16 * sizeof(uint8_t); n -= 16 * sizeof(uint8_t)) {
5231 const __m512i va0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) input_a));
5232 input_a += 16;
5233
5234 __m512i vacc0123456789ABCDEF = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va0123456789ABCDEF, va_multiplier));
5235
5236 vacc0123456789ABCDEF = _mm512_sra_epi32(vacc0123456789ABCDEF, vshift);
5237
5238 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), voutput_zero_point);
5239
5240 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packus_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
5241
5242 vout0123456789ABCDEF = _mm_max_epu8(vout0123456789ABCDEF, voutput_min);
5243
5244 vout0123456789ABCDEF = _mm_min_epu8(vout0123456789ABCDEF, voutput_max);
5245
5246 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
5247 output += 16;
5248 }
5249 if XNN_UNLIKELY(n != 0) {
5250 {
5251 const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << n) - UINT32_C(1)));
5252 const __m512i va0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_maskz_loadu_epi8(vmask, input_a));
5253
5254 __m512i vacc0123456789ABCDEF = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va0123456789ABCDEF, va_multiplier));
5255
5256 vacc0123456789ABCDEF = _mm512_sra_epi32(vacc0123456789ABCDEF, vshift);
5257
5258 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), voutput_zero_point);
5259 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packus_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
5260 vout0123456789ABCDEF = _mm_max_epu8(vout0123456789ABCDEF, voutput_min);
5261 vout0123456789ABCDEF = _mm_min_epu8(vout0123456789ABCDEF, voutput_max);
5262
5263 _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
5264 }
5265 }
5266 }
5267
5268 void xnn_x8_lut_ukernel__avx512skx_vpshufb_x64(
5269 size_t n,
5270 const uint8_t* x,
5271 uint8_t* y,
5272 const uint8_t t[restrict XNN_MIN_ELEMENTS(256)])
5273 {
5274 assert(n != 0);
5275 assert(x != NULL);
5276 assert(y != NULL);
5277
5278 const __m512i vt0 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) t));
5279 const __m512i vt1 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 16)));
5280 const __m512i vt2 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 32)));
5281 const __m512i vt3 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 48)));
5282 const __m512i vt4 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 64)));
5283 const __m512i vt5 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 80)));
5284 const __m512i vt6 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 96)));
5285 const __m512i vt7 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 112)));
5286 const __m512i vt8 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 128)));
5287 const __m512i vt9 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 144)));
5288 const __m512i vtA = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 160)));
5289 const __m512i vtB = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 176)));
5290 const __m512i vtC = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 192)));
5291 const __m512i vtD = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 208)));
5292 const __m512i vtE = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 224)));
5293 const __m512i vtF = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 240)));
5294
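// Build XOR-delta tables for the vpshufb-based 256-entry lookup: vpshufb indexes only the
// low 4 bits within each 16-byte lane and yields 0 for bytes whose MSB is set. As vx is
// decremented by 16 per step, the per-step lookups telescope so each byte accumulates
// exactly t[x]. Plain sub_epi8 is used for the first 8 steps (the extra XOR with
// vtable0..7 folded into vtable8..F cancels the wraparound), while saturating subs_epi8
// in the last 8 steps keeps already-finished bytes negative, so their lookups stay zero.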
5295 const __m512i vtable0 = vt0;
5296 const __m512i vtable1 = _mm512_xor_si512(vt0, vt1);
5297 const __m512i vtable2 = _mm512_xor_si512(vt1, vt2);
5298 const __m512i vtable3 = _mm512_xor_si512(vt2, vt3);
5299 const __m512i vtable4 = _mm512_xor_si512(vt3, vt4);
5300 const __m512i vtable5 = _mm512_xor_si512(vt4, vt5);
5301 const __m512i vtable6 = _mm512_xor_si512(vt5, vt6);
5302 const __m512i vtable7 = _mm512_xor_si512(vt6, vt7);
5303 const __m512i vtable8 = _mm512_xor_si512(_mm512_xor_si512(vt7, vt8), vtable0);
5304 const __m512i vtable9 = _mm512_xor_si512(_mm512_xor_si512(vt8, vt9), vtable1);
5305 const __m512i vtableA = _mm512_xor_si512(_mm512_xor_si512(vt9, vtA), vtable2);
5306 const __m512i vtableB = _mm512_xor_si512(_mm512_xor_si512(vtA, vtB), vtable3);
5307 const __m512i vtableC = _mm512_xor_si512(_mm512_xor_si512(vtB, vtC), vtable4);
5308 const __m512i vtableD = _mm512_xor_si512(_mm512_xor_si512(vtC, vtD), vtable5);
5309 const __m512i vtableE = _mm512_xor_si512(_mm512_xor_si512(vtD, vtE), vtable6);
5310 const __m512i vtableF = _mm512_xor_si512(_mm512_xor_si512(vtE, vtF), vtable7);
5311
5312 const __m512i voffset = _mm512_set1_epi8(16);
5313 for (; n >= 64 * sizeof(uint8_t); n -= 64 * sizeof(uint8_t)) {
5314 __m512i vx = _mm512_loadu_si512(x);
5315 x += 64;
5316
5317 __m512i vy = _mm512_shuffle_epi8(vtable0, vx);
5318
5319 vx = _mm512_sub_epi8(vx, voffset);
5320 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable1, vx));
5321 vx = _mm512_sub_epi8(vx, voffset);
5322 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable2, vx));
5323 vx = _mm512_sub_epi8(vx, voffset);
5324 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable3, vx));
5325 vx = _mm512_sub_epi8(vx, voffset);
5326 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable4, vx));
5327 vx = _mm512_sub_epi8(vx, voffset);
5328 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable5, vx));
5329 vx = _mm512_sub_epi8(vx, voffset);
5330 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable6, vx));
5331 vx = _mm512_sub_epi8(vx, voffset);
5332 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable7, vx));
5333 vx = _mm512_sub_epi8(vx, voffset);
5334 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable8, vx));
5335
5336 vx = _mm512_subs_epi8(vx, voffset);
5337 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable9, vx));
5338 vx = _mm512_subs_epi8(vx, voffset);
5339 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableA, vx));
5340 vx = _mm512_subs_epi8(vx, voffset);
5341 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableB, vx));
5342 vx = _mm512_subs_epi8(vx, voffset);
5343 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableC, vx));
5344 vx = _mm512_subs_epi8(vx, voffset);
5345 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableD, vx));
5346 vx = _mm512_subs_epi8(vx, voffset);
5347 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableE, vx));
5348 vx = _mm512_subs_epi8(vx, voffset);
5349 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableF, vx));
5350
5351 _mm512_storeu_si512(y, vy);
5352 y += 64;
5353 }
5354 if XNN_UNLIKELY(n != 0) {
5355 assert(n < 64);
5356 const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT64_C(1) << n) - UINT64_C(1)));
5357
5358 __m512i vx = _mm512_maskz_loadu_epi8(vmask, x);
5359
5360 __m512i vy = _mm512_shuffle_epi8(vtable0, vx);
5361
5362 vx = _mm512_sub_epi8(vx, voffset);
5363 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable1, vx));
5364 vx = _mm512_sub_epi8(vx, voffset);
5365 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable2, vx));
5366 vx = _mm512_sub_epi8(vx, voffset);
5367 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable3, vx));
5368 vx = _mm512_sub_epi8(vx, voffset);
5369 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable4, vx));
5370 vx = _mm512_sub_epi8(vx, voffset);
5371 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable5, vx));
5372 vx = _mm512_sub_epi8(vx, voffset);
5373 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable6, vx));
5374 vx = _mm512_sub_epi8(vx, voffset);
5375 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable7, vx));
5376 vx = _mm512_sub_epi8(vx, voffset);
5377 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable8, vx));
5378
5379 vx = _mm512_subs_epi8(vx, voffset);
5380 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable9, vx));
5381 vx = _mm512_subs_epi8(vx, voffset);
5382 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableA, vx));
5383 vx = _mm512_subs_epi8(vx, voffset);
5384 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableB, vx));
5385 vx = _mm512_subs_epi8(vx, voffset);
5386 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableC, vx));
5387 vx = _mm512_subs_epi8(vx, voffset);
5388 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableD, vx));
5389 vx = _mm512_subs_epi8(vx, voffset);
5390 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableE, vx));
5391 vx = _mm512_subs_epi8(vx, voffset);
5392 vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableF, vx));
5393
5394 _mm512_mask_storeu_epi8(y, vmask, vy);
5395 }
5396 }
5397