// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/dwconv.h>
#include <xnnpack/gemm.h>
#include <xnnpack/ibilinear.h>
#include <xnnpack/igemm.h>
#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/math.h>
#include <xnnpack/vmulcaddc.h>
#include <xnnpack/vunary.h>

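// Depthwise convolution microkernel: 3 taps, up to 16 channels per loop iteration.
// Data is stored as IEEE fp16 (uint16_t); each step widens to fp32 with F16C, applies an
// FMA3 fused multiply-add, and immediately rounds back to fp16 to preserve fp16
// accumulation semantics, then clamps the result to [min, max] from the params struct.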
void xnn_f16_dwconv_minmax_ukernel_up16x3__fma3(
    size_t channels,
    size_t output_width,
    const void** input,
    const void* weights,
    void* output,
    size_t input_stride,
    size_t output_increment,
    size_t input_offset,
    const void* zero,
    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
33 assert(channels != 0);
34 assert(output_width != 0);
35
36 const __m256 vmax = _mm256_load_ps(params->avx.max);
37 const __m256 vmin = _mm256_load_ps(params->avx.min);
38
39 uint16_t* o = (uint16_t*) output;
40 do {
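    // Apply the batch offset to each tap's row pointer unless it points at the shared
    // `zero` padding buffer.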
41 const uint16_t* i0 = input[0];
42 assert(i0 != NULL);
43 if XNN_UNPREDICTABLE(i0 != zero) {
44 i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
45 }
46 const uint16_t* i1 = input[1];
47 assert(i1 != NULL);
48 if XNN_UNPREDICTABLE(i1 != zero) {
49 i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
50 }
51 const uint16_t* i2 = input[2];
52 assert(i2 != NULL);
53 if XNN_UNPREDICTABLE(i2 != zero) {
54 i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
55 }
56 input = (const void**) ((uintptr_t) input + input_stride);
57
58 size_t c = channels;
59 const uint16_t* w = weights;
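    // Packed weight layout per 16-channel group: 16 fp16 biases, then 16 fp16 weights for
    // each of the 3 taps (64 halfwords total).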
60 for (; c >= 16; c -= 16) {
61 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
62 __m256 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 8)));
63
64
65 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
66 const __m256 vi0x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i0 + 8)));
67 i0 += 16;
68
69 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 16)));
70 const __m256 vk0x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 24)));
71 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
72 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x89ABCDEF, vk0x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
73
74 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
75 const __m256 vi1x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i1 + 8)));
76 i1 += 16;
77
78 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 32)));
79 const __m256 vk1x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 40)));
80 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
81 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x89ABCDEF, vk1x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
82
83 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
84 const __m256 vi2x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i2 + 8)));
85 i2 += 16;
86
87 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 48)));
88 const __m256 vk2x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 56)));
89 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
90 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x89ABCDEF, vk2x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
91
92 w += 64;
93
94
95 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
96 __m256 vacc89ABCDEF = _mm256_max_ps(vacc89ABCDEFp0, vmin);
97 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
98 vacc89ABCDEF = _mm256_min_ps(vacc89ABCDEF, vmax);
99
100 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC));
101 _mm_storeu_si128((__m128i*) (o + 8), _mm256_cvtps_ph(vacc89ABCDEF, _MM_FROUND_NO_EXC));
102 o += 16;
103 }
104 for (; c >= 8; c -= 8) {
105 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
106
107 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
108 i0 += 8;
109
110 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 16)));
111 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
112
113 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
114 i1 += 8;
115
116 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 32)));
117 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
118
119 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
120 i2 += 8;
121
122 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 48)));
123 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
124
125 w += 8;
126
127
128 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
129 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
130
131 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC));
132 o += 8;
133 }
134 if XNN_UNLIKELY(c != 0) {
135 assert(c >= 1);
136 assert(c <= 7);
137
138 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
139
140 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
141
142 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 16)));
143 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
144
145 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
146
147 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 32)));
148 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
149
150 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
151
152 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 48)));
153 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
154
155
156 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
157 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
158
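      // Store the 1-7 remaining fp16 outputs 4, 2, and 1 at a time, following the bits of c.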
159 __m128i vh01234567 = _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC);
160 if (c & 4) {
161 _mm_storel_epi64((__m128i*) o, vh01234567);
162 vh01234567 = _mm_unpackhi_epi64(vh01234567, vh01234567);
163 o += 4;
164 }
165 if (c & 2) {
166 _mm_storeu_si32(o, vh01234567);
167 vh01234567 = _mm_srli_epi64(vh01234567, 32);
168 o += 2;
169 }
170 if (c & 1) {
171 *o = (uint16_t) _mm_extract_epi16(vh01234567, 0);
172 o += 1;
173 }
174 }
175
176 o = (uint16_t*) ((uintptr_t) o + output_increment);
177 } while (--output_width != 0);
178 }
179
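// 4-tap variant of the depthwise convolution kernel above; same fp16-via-fp32-FMA scheme.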
void xnn_f16_dwconv_minmax_ukernel_up16x4__fma3(
    size_t channels,
    size_t output_width,
    const void** input,
    const void* weights,
    void* output,
    size_t input_stride,
    size_t output_increment,
    size_t input_offset,
    const void* zero,
    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
192 assert(channels != 0);
193 assert(output_width != 0);
194
195 const __m256 vmax = _mm256_load_ps(params->avx.max);
196 const __m256 vmin = _mm256_load_ps(params->avx.min);
197
198 uint16_t* o = (uint16_t*) output;
199 do {
200 const uint16_t* i0 = input[0];
201 assert(i0 != NULL);
202 if XNN_UNPREDICTABLE(i0 != zero) {
203 i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
204 }
205 const uint16_t* i1 = input[1];
206 assert(i1 != NULL);
207 if XNN_UNPREDICTABLE(i1 != zero) {
208 i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
209 }
210 const uint16_t* i2 = input[2];
211 assert(i2 != NULL);
212 if XNN_UNPREDICTABLE(i2 != zero) {
213 i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
214 }
215 const uint16_t* i3 = input[3];
216 assert(i3 != NULL);
217 if XNN_UNPREDICTABLE(i3 != zero) {
218 i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
219 }
220 input = (const void**) ((uintptr_t) input + input_stride);
221
222 size_t c = channels;
223 const uint16_t* w = weights;
224 for (; c >= 16; c -= 16) {
225 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
226 __m256 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 8)));
227
228
229 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
230 const __m256 vi0x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i0 + 8)));
231 i0 += 16;
232
233 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 16)));
234 const __m256 vk0x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 24)));
235 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
236 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x89ABCDEF, vk0x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
237
238 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
239 const __m256 vi1x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i1 + 8)));
240 i1 += 16;
241
242 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 32)));
243 const __m256 vk1x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 40)));
244 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
245 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x89ABCDEF, vk1x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
246
247 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
248 const __m256 vi2x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i2 + 8)));
249 i2 += 16;
250
251 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 48)));
252 const __m256 vk2x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 56)));
253 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
254 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x89ABCDEF, vk2x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
255
256 const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
257 const __m256 vi3x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i3 + 8)));
258 i3 += 16;
259
260 const __m256 vk3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 64)));
261 const __m256 vk3x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 72)));
262 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
263 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x89ABCDEF, vk3x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
264
265 w += 80;
266
267
268 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
269 __m256 vacc89ABCDEF = _mm256_max_ps(vacc89ABCDEFp0, vmin);
270 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
271 vacc89ABCDEF = _mm256_min_ps(vacc89ABCDEF, vmax);
272
273 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC));
274 _mm_storeu_si128((__m128i*) (o + 8), _mm256_cvtps_ph(vacc89ABCDEF, _MM_FROUND_NO_EXC));
275 o += 16;
276 }
277 for (; c >= 8; c -= 8) {
278 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
279
280 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
281 i0 += 8;
282
283 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 16)));
284 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
285
286 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
287 i1 += 8;
288
289 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 32)));
290 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
291
292 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
293 i2 += 8;
294
295 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 48)));
296 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
297
298 const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
299 i3 += 8;
300
301 const __m256 vk3x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 64)));
302 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
303
304 w += 8;
305
306
307 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
308 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
309
310 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC));
311 o += 8;
312 }
313 if XNN_UNLIKELY(c != 0) {
314 assert(c >= 1);
315 assert(c <= 7);
316
317 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
318
319 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
320
321 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 16)));
322 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
323
324 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
325
326 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 32)));
327 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
328
329 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
330
331 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 48)));
332 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
333
334 const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
335
336 const __m256 vk3x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 64)));
337 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
338
339
340 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
341 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
342
343 __m128i vh01234567 = _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC);
344 if (c & 4) {
345 _mm_storel_epi64((__m128i*) o, vh01234567);
346 vh01234567 = _mm_unpackhi_epi64(vh01234567, vh01234567);
347 o += 4;
348 }
349 if (c & 2) {
350 _mm_storeu_si32(o, vh01234567);
351 vh01234567 = _mm_srli_epi64(vh01234567, 32);
352 o += 2;
353 }
354 if (c & 1) {
355 *o = (uint16_t) _mm_extract_epi16(vh01234567, 0);
356 o += 1;
357 }
358 }
359
360 o = (uint16_t*) ((uintptr_t) o + output_increment);
361 } while (--output_width != 0);
362 }
363
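// 9-tap variant (e.g. a 3x3 convolution window), 16 channels per iteration.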
void xnn_f16_dwconv_minmax_ukernel_up16x9__fma3(
    size_t channels,
    size_t output_width,
    const void** input,
    const void* weights,
    void* output,
    size_t input_stride,
    size_t output_increment,
    size_t input_offset,
    const void* zero,
    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
376 assert(channels != 0);
377 assert(output_width != 0);
378
379 const __m256 vmax = _mm256_load_ps(params->avx.max);
380 const __m256 vmin = _mm256_load_ps(params->avx.min);
381
382 uint16_t* o = (uint16_t*) output;
383 do {
384 const uint16_t* i0 = input[0];
385 assert(i0 != NULL);
386 if XNN_UNPREDICTABLE(i0 != zero) {
387 i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
388 }
389 const uint16_t* i1 = input[1];
390 assert(i1 != NULL);
391 if XNN_UNPREDICTABLE(i1 != zero) {
392 i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
393 }
394 const uint16_t* i2 = input[2];
395 assert(i2 != NULL);
396 if XNN_UNPREDICTABLE(i2 != zero) {
397 i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
398 }
399 const uint16_t* i3 = input[3];
400 assert(i3 != NULL);
401 if XNN_UNPREDICTABLE(i3 != zero) {
402 i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
403 }
404 const uint16_t* i4 = input[4];
405 assert(i4 != NULL);
406 if XNN_UNPREDICTABLE(i4 != zero) {
407 i4 = (const uint16_t*) ((uintptr_t) i4 + input_offset);
408 }
409 const uint16_t* i5 = input[5];
410 assert(i5 != NULL);
411 if XNN_UNPREDICTABLE(i5 != zero) {
412 i5 = (const uint16_t*) ((uintptr_t) i5 + input_offset);
413 }
414 const uint16_t* i6 = input[6];
415 assert(i6 != NULL);
416 if XNN_UNPREDICTABLE(i6 != zero) {
417 i6 = (const uint16_t*) ((uintptr_t) i6 + input_offset);
418 }
419 const uint16_t* i7 = input[7];
420 assert(i7 != NULL);
421 if XNN_UNPREDICTABLE(i7 != zero) {
422 i7 = (const uint16_t*) ((uintptr_t) i7 + input_offset);
423 }
424 const uint16_t* i8 = input[8];
425 assert(i8 != NULL);
426 if XNN_UNPREDICTABLE(i8 != zero) {
427 i8 = (const uint16_t*) ((uintptr_t) i8 + input_offset);
428 }
429 input = (const void**) ((uintptr_t) input + input_stride);
430
431 size_t c = channels;
432 const uint16_t* w = weights;
433 for (; c >= 16; c -= 16) {
434 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
435 __m256 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 8)));
436
437
438 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
439 const __m256 vi0x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i0 + 8)));
440 i0 += 16;
441
442 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 16)));
443 const __m256 vk0x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 24)));
444 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
445 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x89ABCDEF, vk0x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
446
447 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
448 const __m256 vi1x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i1 + 8)));
449 i1 += 16;
450
451 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 32)));
452 const __m256 vk1x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 40)));
453 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
454 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x89ABCDEF, vk1x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
455
456 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
457 const __m256 vi2x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i2 + 8)));
458 i2 += 16;
459
460 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 48)));
461 const __m256 vk2x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 56)));
462 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
463 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x89ABCDEF, vk2x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
464
465 const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
466 const __m256 vi3x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i3 + 8)));
467 i3 += 16;
468
469 const __m256 vk3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 64)));
470 const __m256 vk3x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 72)));
471 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
472 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x89ABCDEF, vk3x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
473
474 const __m256 vi4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
475 const __m256 vi4x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i4 + 8)));
476 i4 += 16;
477
478 const __m256 vk4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 80)));
479 const __m256 vk4x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 88)));
480 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
481 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi4x89ABCDEF, vk4x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
482
483 const __m256 vi5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
484 const __m256 vi5x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i5 + 8)));
485 i5 += 16;
486
487 const __m256 vk5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 96)));
488 const __m256 vk5x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 104)));
489 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
490 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi5x89ABCDEF, vk5x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
491
492 const __m256 vi6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
493 const __m256 vi6x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i6 + 8)));
494 i6 += 16;
495
496 const __m256 vk6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 112)));
497 const __m256 vk6x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 120)));
498 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
499 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi6x89ABCDEF, vk6x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
500
501 const __m256 vi7x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
502 const __m256 vi7x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i7 + 8)));
503 i7 += 16;
504
505 const __m256 vk7x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 128)));
506 const __m256 vk7x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 136)));
507 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
508 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi7x89ABCDEF, vk7x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
509
510 const __m256 vi8x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
511 const __m256 vi8x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i8 + 8)));
512 i8 += 16;
513
514 const __m256 vk8x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 144)));
515 const __m256 vk8x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 152)));
516 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
517 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi8x89ABCDEF, vk8x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
518
519 w += 160;
520
521
522 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
523 __m256 vacc89ABCDEF = _mm256_max_ps(vacc89ABCDEFp0, vmin);
524 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
525 vacc89ABCDEF = _mm256_min_ps(vacc89ABCDEF, vmax);
526
527 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC));
528 _mm_storeu_si128((__m128i*) (o + 8), _mm256_cvtps_ph(vacc89ABCDEF, _MM_FROUND_NO_EXC));
529 o += 16;
530 }
531 for (; c >= 8; c -= 8) {
532 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
533
534 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
535 i0 += 8;
536
537 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 16)));
538 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
539
540 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
541 i1 += 8;
542
543 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 32)));
544 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
545
546 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
547 i2 += 8;
548
549 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 48)));
550 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
551
552 const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
553 i3 += 8;
554
555 const __m256 vk3x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 64)));
556 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
557
558 const __m256 vi4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
559 i4 += 8;
560
561 const __m256 vk4x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 80)));
562 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
563
564 const __m256 vi5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
565 i5 += 8;
566
567 const __m256 vk5x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 96)));
568 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
569
570 const __m256 vi6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
571 i6 += 8;
572
573 const __m256 vk6x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 112)));
574 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
575
576 const __m256 vi7x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
577 i7 += 8;
578
579 const __m256 vk7x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 128)));
580 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
581
582 const __m256 vi8x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
583 i8 += 8;
584
585 const __m256 vk8x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 144)));
586 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
587
588 w += 8;
589
590
591 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
592 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
593
594 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC));
595 o += 8;
596 }
597 if XNN_UNLIKELY(c != 0) {
598 assert(c >= 1);
599 assert(c <= 7);
600
601 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
602
603 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
604
605 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 16)));
606 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
607
608 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
609
610 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 32)));
611 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
612
613 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
614
615 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 48)));
616 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
617
618 const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
619
620 const __m256 vk3x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 64)));
621 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
622
623 const __m256 vi4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
624
625 const __m256 vk4x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 80)));
626 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
627
628 const __m256 vi5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
629
630 const __m256 vk5x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 96)));
631 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
632
633 const __m256 vi6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
634
635 const __m256 vk6x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 112)));
636 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
637
638 const __m256 vi7x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
639
640 const __m256 vk7x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 128)));
641 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
642
643 const __m256 vi8x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
644
645 const __m256 vk8x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 144)));
646 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
647
648
649 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
650 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
651
652 __m128i vh01234567 = _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC);
653 if (c & 4) {
654 _mm_storel_epi64((__m128i*) o, vh01234567);
655 vh01234567 = _mm_unpackhi_epi64(vh01234567, vh01234567);
656 o += 4;
657 }
658 if (c & 2) {
659 _mm_storeu_si32(o, vh01234567);
660 vh01234567 = _mm_srli_epi64(vh01234567, 32);
661 o += 2;
662 }
663 if (c & 1) {
664 *o = (uint16_t) _mm_extract_epi16(vh01234567, 0);
665 o += 1;
666 }
667 }
668
669 o = (uint16_t*) ((uintptr_t) o + output_increment);
670 } while (--output_width != 0);
671 }
672
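// 25-tap variant (e.g. a 5x5 window), 8 channels per iteration, using two accumulators
// (p0/p1) to shorten the dependency chain; the accumulators are summed before clamping.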
void xnn_f16_dwconv_minmax_ukernel_up8x25__fma3_acc2(
    size_t channels,
    size_t output_width,
    const void** input,
    const void* weights,
    void* output,
    size_t input_stride,
    size_t output_increment,
    size_t input_offset,
    const void* zero,
    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
685 assert(channels != 0);
686 assert(output_width != 0);
687
688 const __m256 vmax = _mm256_load_ps(params->avx.max);
689 const __m256 vmin = _mm256_load_ps(params->avx.min);
690
691 uint16_t* o = (uint16_t*) output;
692 do {
693 const uint16_t* i0 = input[0];
694 assert(i0 != NULL);
695 if XNN_UNPREDICTABLE(i0 != zero) {
696 i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
697 }
698 const uint16_t* i1 = input[1];
699 assert(i1 != NULL);
700 if XNN_UNPREDICTABLE(i1 != zero) {
701 i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
702 }
703 const uint16_t* i2 = input[2];
704 assert(i2 != NULL);
705 if XNN_UNPREDICTABLE(i2 != zero) {
706 i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
707 }
708 const uint16_t* i3 = input[3];
709 assert(i3 != NULL);
710 if XNN_UNPREDICTABLE(i3 != zero) {
711 i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
712 }
713 const uint16_t* i4 = input[4];
714 assert(i4 != NULL);
715 if XNN_UNPREDICTABLE(i4 != zero) {
716 i4 = (const uint16_t*) ((uintptr_t) i4 + input_offset);
717 }
718 const uint16_t* i5 = input[5];
719 assert(i5 != NULL);
720 if XNN_UNPREDICTABLE(i5 != zero) {
721 i5 = (const uint16_t*) ((uintptr_t) i5 + input_offset);
722 }
723 const uint16_t* i6 = input[6];
724 assert(i6 != NULL);
725 if XNN_UNPREDICTABLE(i6 != zero) {
726 i6 = (const uint16_t*) ((uintptr_t) i6 + input_offset);
727 }
728 const uint16_t* i7 = input[7];
729 assert(i7 != NULL);
730 if XNN_UNPREDICTABLE(i7 != zero) {
731 i7 = (const uint16_t*) ((uintptr_t) i7 + input_offset);
732 }
733 const uint16_t* i8 = input[8];
734 assert(i8 != NULL);
735 if XNN_UNPREDICTABLE(i8 != zero) {
736 i8 = (const uint16_t*) ((uintptr_t) i8 + input_offset);
737 }
738 const uint16_t* i9 = input[9];
739 assert(i9 != NULL);
740 if XNN_UNPREDICTABLE(i9 != zero) {
741 i9 = (const uint16_t*) ((uintptr_t) i9 + input_offset);
742 }
743 const uint16_t* i10 = input[10];
744 assert(i10 != NULL);
745 if XNN_UNPREDICTABLE(i10 != zero) {
746 i10 = (const uint16_t*) ((uintptr_t) i10 + input_offset);
747 }
748 const uint16_t* i11 = input[11];
749 assert(i11 != NULL);
750 if XNN_UNPREDICTABLE(i11 != zero) {
751 i11 = (const uint16_t*) ((uintptr_t) i11 + input_offset);
752 }
753 const uint16_t* i12 = input[12];
754 assert(i12 != NULL);
755 if XNN_UNPREDICTABLE(i12 != zero) {
756 i12 = (const uint16_t*) ((uintptr_t) i12 + input_offset);
757 }
758 const uint16_t* i13 = input[13];
759 assert(i13 != NULL);
760 if XNN_UNPREDICTABLE(i13 != zero) {
761 i13 = (const uint16_t*) ((uintptr_t) i13 + input_offset);
762 }
763 const uint16_t* i14 = input[14];
764 assert(i14 != NULL);
765 if XNN_UNPREDICTABLE(i14 != zero) {
766 i14 = (const uint16_t*) ((uintptr_t) i14 + input_offset);
767 }
768 const uint16_t* i15 = input[15];
769 assert(i15 != NULL);
770 if XNN_UNPREDICTABLE(i15 != zero) {
771 i15 = (const uint16_t*) ((uintptr_t) i15 + input_offset);
772 }
773 const uint16_t* i16 = input[16];
774 assert(i16 != NULL);
775 if XNN_UNPREDICTABLE(i16 != zero) {
776 i16 = (const uint16_t*) ((uintptr_t) i16 + input_offset);
777 }
778 const uint16_t* i17 = input[17];
779 assert(i17 != NULL);
780 if XNN_UNPREDICTABLE(i17 != zero) {
781 i17 = (const uint16_t*) ((uintptr_t) i17 + input_offset);
782 }
783 const uint16_t* i18 = input[18];
784 assert(i18 != NULL);
785 if XNN_UNPREDICTABLE(i18 != zero) {
786 i18 = (const uint16_t*) ((uintptr_t) i18 + input_offset);
787 }
788 const uint16_t* i19 = input[19];
789 assert(i19 != NULL);
790 if XNN_UNPREDICTABLE(i19 != zero) {
791 i19 = (const uint16_t*) ((uintptr_t) i19 + input_offset);
792 }
793 const uint16_t* i20 = input[20];
794 assert(i20 != NULL);
795 if XNN_UNPREDICTABLE(i20 != zero) {
796 i20 = (const uint16_t*) ((uintptr_t) i20 + input_offset);
797 }
798 const uint16_t* i21 = input[21];
799 assert(i21 != NULL);
800 if XNN_UNPREDICTABLE(i21 != zero) {
801 i21 = (const uint16_t*) ((uintptr_t) i21 + input_offset);
802 }
803 const uint16_t* i22 = input[22];
804 assert(i22 != NULL);
805 if XNN_UNPREDICTABLE(i22 != zero) {
806 i22 = (const uint16_t*) ((uintptr_t) i22 + input_offset);
807 }
808 const uint16_t* i23 = input[23];
809 assert(i23 != NULL);
810 if XNN_UNPREDICTABLE(i23 != zero) {
811 i23 = (const uint16_t*) ((uintptr_t) i23 + input_offset);
812 }
813 const uint16_t* i24 = input[24];
814 assert(i24 != NULL);
815 if XNN_UNPREDICTABLE(i24 != zero) {
816 i24 = (const uint16_t*) ((uintptr_t) i24 + input_offset);
817 }
818 input = (const void**) ((uintptr_t) input + input_stride);
819
820 size_t c = channels;
821 const uint16_t* w = weights;
822 for (; c >= 8; c -= 8) {
823 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
824
825
826 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
827 i0 += 8;
828
829 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 8)));
830 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
831
832 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
833 i1 += 8;
834
835 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 16)));
836 __m256 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vi1x01234567, vk1x01234567), _MM_FROUND_NO_EXC));
837
838 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
839 i2 += 8;
840
841 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 24)));
842 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
843
844 const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
845 i3 += 8;
846
847 const __m256 vk3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 32)));
848 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
849
850 const __m256 vi4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
851 i4 += 8;
852
853 const __m256 vk4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 40)));
854 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
855
856 const __m256 vi5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
857 i5 += 8;
858
859 const __m256 vk5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 48)));
860 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
861
862 const __m256 vi6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
863 i6 += 8;
864
865 const __m256 vk6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 56)));
866 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
867
868 const __m256 vi7x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
869 i7 += 8;
870
871 const __m256 vk7x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 64)));
872 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
873
874 const __m256 vi8x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
875 i8 += 8;
876
877 const __m256 vk8x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 72)));
878 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
879
880 const __m256 vi9x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i9));
881 i9 += 8;
882
883 const __m256 vk9x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 80)));
884 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi9x01234567, vk9x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
885
886 const __m256 vi10x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i10));
887 i10 += 8;
888
889 const __m256 vk10x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 88)));
890 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi10x01234567, vk10x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
891
892 const __m256 vi11x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i11));
893 i11 += 8;
894
895 const __m256 vk11x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 96)));
896 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi11x01234567, vk11x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
897
898 const __m256 vi12x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i12));
899 i12 += 8;
900
901 const __m256 vk12x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 104)));
902 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi12x01234567, vk12x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
903
904 const __m256 vi13x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i13));
905 i13 += 8;
906
907 const __m256 vk13x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 112)));
908 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi13x01234567, vk13x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
909
910 const __m256 vi14x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i14));
911 i14 += 8;
912
913 const __m256 vk14x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 120)));
914 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi14x01234567, vk14x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
915
916 const __m256 vi15x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i15));
917 i15 += 8;
918
919 const __m256 vk15x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 128)));
920 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi15x01234567, vk15x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
921
922 const __m256 vi16x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i16));
923 i16 += 8;
924
925 const __m256 vk16x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 136)));
926 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi16x01234567, vk16x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
927
928 const __m256 vi17x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i17));
929 i17 += 8;
930
931 const __m256 vk17x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 144)));
932 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi17x01234567, vk17x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
933
934 const __m256 vi18x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i18));
935 i18 += 8;
936
937 const __m256 vk18x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 152)));
938 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi18x01234567, vk18x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
939
940 const __m256 vi19x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i19));
941 i19 += 8;
942
943 const __m256 vk19x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 160)));
944 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi19x01234567, vk19x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
945
946 const __m256 vi20x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i20));
947 i20 += 8;
948
949 const __m256 vk20x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 168)));
950 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi20x01234567, vk20x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
951
952 const __m256 vi21x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i21));
953 i21 += 8;
954
955 const __m256 vk21x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 176)));
956 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi21x01234567, vk21x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
957
958 const __m256 vi22x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i22));
959 i22 += 8;
960
961 const __m256 vk22x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 184)));
962 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi22x01234567, vk22x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
963
964 const __m256 vi23x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i23));
965 i23 += 8;
966
967 const __m256 vk23x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 192)));
968 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi23x01234567, vk23x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
969
970 const __m256 vi24x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i24));
971 i24 += 8;
972
973 const __m256 vk24x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 200)));
974 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi24x01234567, vk24x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
975
976 w += 208;
977
978 // Add up all accumulators to vacc01234567p0
979 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vacc01234567p0, vacc01234567p1), _MM_FROUND_NO_EXC));
980
981 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
982 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
983
984 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC));
985 o += 8;
986 }
987 if XNN_UNLIKELY(c != 0) {
988 assert(c >= 1);
989 assert(c <= 7);
990
991 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
992
993 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
994
995 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 8)));
996 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
997
998 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
999
1000 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 16)));
1001 __m256 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vi1x01234567, vk1x01234567), _MM_FROUND_NO_EXC));
1002
1003 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
1004
1005 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 24)));
1006 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
1007
1008 const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
1009
1010 const __m256 vk3x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 32)));
1011 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
1012
1013 const __m256 vi4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
1014
1015 const __m256 vk4x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 40)));
1016 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
1017
1018 const __m256 vi5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
1019
1020 const __m256 vk5x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 48)));
1021 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
1022
1023 const __m256 vi6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
1024
1025 const __m256 vk6x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 56)));
1026 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
1027
1028 const __m256 vi7x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
1029
1030 const __m256 vk7x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 64)));
1031 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
1032
1033 const __m256 vi8x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
1034
1035 const __m256 vk8x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 72)));
1036 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
1037
1038 const __m256 vi9x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i9));
1039
1040 const __m256 vk9x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 80)));
1041 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi9x01234567, vk9x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
1042
1043 const __m256 vi10x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i10));
1044
1045 const __m256 vk10x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 88)));
1046 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi10x01234567, vk10x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
1047
1048 const __m256 vi11x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i11));
1049
1050 const __m256 vk11x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 96)));
1051 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi11x01234567, vk11x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
1052
1053 const __m256 vi12x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i12));
1054
1055 const __m256 vk12x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 104)));
1056 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi12x01234567, vk12x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
1057
1058 const __m256 vi13x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i13));
1059
1060 const __m256 vk13x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 112)));
1061 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi13x01234567, vk13x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
1062
1063 const __m256 vi14x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i14));
1064
1065 const __m256 vk14x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 120)));
1066 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi14x01234567, vk14x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
1067
1068 const __m256 vi15x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i15));
1069
1070 const __m256 vk15x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 128)));
1071 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi15x01234567, vk15x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
1072
1073 const __m256 vi16x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i16));
1074
1075 const __m256 vk16x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 136)));
1076 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi16x01234567, vk16x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
1077
1078 const __m256 vi17x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i17));
1079
1080 const __m256 vk17x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 144)));
1081 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi17x01234567, vk17x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
1082
1083 const __m256 vi18x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i18));
1084
1085 const __m256 vk18x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 152)));
1086 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi18x01234567, vk18x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
1087
1088 const __m256 vi19x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i19));
1089
1090 const __m256 vk19x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 160)));
1091 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi19x01234567, vk19x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
1092
1093 const __m256 vi20x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i20));
1094
1095 const __m256 vk20x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 168)));
1096 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi20x01234567, vk20x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
1097
1098 const __m256 vi21x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i21));
1099
1100 const __m256 vk21x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 176)));
1101 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi21x01234567, vk21x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
1102
1103 const __m256 vi22x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i22));
1104
1105 const __m256 vk22x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 184)));
1106 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi22x01234567, vk22x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
1107
1108 const __m256 vi23x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i23));
1109
1110 const __m256 vk23x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 192)));
1111 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi23x01234567, vk23x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
1112
1113 const __m256 vi24x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i24));
1114
1115 const __m256 vk24x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 200)));
1116 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi24x01234567, vk24x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
1117
1118 // Add up all accumulators to vacc01234567p0
1119 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vacc01234567p0, vacc01234567p1), _MM_FROUND_NO_EXC));
1120
1121 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1122 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1123
1124 __m128i vh01234567 = _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC);
1125 if (c & 4) {
1126 _mm_storel_epi64((__m128i*) o, vh01234567);
1127 vh01234567 = _mm_unpackhi_epi64(vh01234567, vh01234567);
1128 o += 4;
1129 }
1130 if (c & 2) {
1131 _mm_storeu_si32(o, vh01234567);
1132 vh01234567 = _mm_srli_epi64(vh01234567, 32);
1133 o += 2;
1134 }
1135 if (c & 1) {
1136 *o = (uint16_t) _mm_extract_epi16(vh01234567, 0);
1137 o += 1;
1138 }
1139 }
1140
1141 o = (uint16_t*) ((uintptr_t) o + output_increment);
1142 } while (--output_width != 0);
1143 }
1144
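// Bilinear interpolation microkernel: for each output pixel, the four corner pointers
// (top-left, top-right, bottom-left, bottom-right) are blended with the per-pixel
// horizontal and vertical weights, 8 fp16 channels at a time.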
void xnn_f16_ibilinear_ukernel__fma3_c8(
    size_t output_pixels,
    size_t channels,
    const void**restrict input,
    size_t input_offset,
    const void*restrict weights,
    void*restrict output,
    size_t output_increment) XNN_OOB_READS
{
1154 assert(output_pixels != 0);
1155 assert(channels != 0);
1156 assert(channels % sizeof(uint16_t) == 0);
1157
1158 uint16_t* o = (uint16_t*) output;
1159 do {
1160 const uint16_t* i0 = (const uint16_t*) ((uintptr_t) input[0] + input_offset);
1161 const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input[1] + input_offset);
1162 const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input[2] + input_offset);
1163 const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input[3] + input_offset);
1164 input += 4;
1165
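    // Each output pixel carries two packed f16 weights (horizontal and vertical
    // alpha); broadcast them together, then split them into separate vectors.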
1166 const __m256 valphahv = _mm256_cvtph_ps(_mm_castps_si128(_mm_broadcast_ss(weights)));
1167 const __m256 valphah = _mm256_permute_ps(valphahv, _MM_SHUFFLE(2, 0, 2, 0));
1168 const __m256 valphav = _mm256_permute_ps(valphahv, _MM_SHUFFLE(3, 1, 3, 1));
1169 weights = (const uint16_t*) weights + 2;
1170
1171 size_t c = channels;
1172 for (; c >= 8 * sizeof(uint16_t); c -= 8 * sizeof(uint16_t)) {
1173 const __m256 vtl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
1174 i0 += 8;
1175 const __m256 vtr = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
1176 i1 += 8;
1177 const __m256 vbl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
1178 i2 += 8;
1179 const __m256 vbr = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
1180 i3 += 8;
1181
1182 const __m256 vtd = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_sub_ps(vtr, vtl), _MM_FROUND_NO_EXC));
1183 const __m256 vbd = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_sub_ps(vbr, vbl), _MM_FROUND_NO_EXC));
1184
1185 const __m256 vt = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vtd, valphah, vtl), _MM_FROUND_NO_EXC));
1186 const __m256 vb = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vbd, valphah, vbl), _MM_FROUND_NO_EXC));
1187
1188 const __m256 vd = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_sub_ps(vb, vt), _MM_FROUND_NO_EXC));
1189
1190 const __m128i vo = _mm256_cvtps_ph(_mm256_fmadd_ps(vd, valphav, vt), _MM_FROUND_NO_EXC);
1191
1192 _mm_storeu_si128((__m128i*) o, vo);
1193 o += 8;
1194 }
1195 if XNN_UNLIKELY(c != 0) {
1196 const __m256 vtl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
1197 i0 += 8;
1198 const __m256 vtr = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
1199 i1 += 8;
1200 const __m256 vbl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
1201 i2 += 8;
1202 const __m256 vbr = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
1203 i3 += 8;
1204
1205 const __m256 vtd = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_sub_ps(vtr, vtl), _MM_FROUND_NO_EXC));
1206 const __m256 vbd = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_sub_ps(vbr, vbl), _MM_FROUND_NO_EXC));
1207
1208 const __m256 vt = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vtd, valphah, vtl), _MM_FROUND_NO_EXC));
1209 const __m256 vb = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vbd, valphah, vbl), _MM_FROUND_NO_EXC));
1210
1211 const __m256 vd = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_sub_ps(vb, vt), _MM_FROUND_NO_EXC));
1212
1213 __m128i vo = _mm256_cvtps_ph(_mm256_fmadd_ps(vd, valphav, vt), _MM_FROUND_NO_EXC);
1214 if (c & (4 * sizeof(uint16_t))) {
1215 _mm_storel_epi64((__m128i*) o, vo);
1216 vo = _mm_unpackhi_epi64(vo, vo);
1217 o += 4;
1218 }
1219 if (c & (2 * sizeof(uint16_t))) {
1220 _mm_storeu_si32(o, vo);
1221 vo = _mm_srli_epi64(vo, 32);
1222 o += 2;
1223 }
1224 if (c & (1 * sizeof(uint16_t))) {
1225 *o = (uint16_t) _mm_extract_epi16(vo, 0);
1226 o += 1;
1227 }
1228 }
1229
1230 o = (uint16_t*) ((uintptr_t) o + output_increment);
1231 } while (--output_pixels != 0);
1232 }
1233
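// f16 multiply-add-by-channel-constants micro-kernel: output = clamp(input * scale + bias)
// per channel, processing two rows and 8 channels per iteration.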
1234 void xnn_f16_vmulcaddc_minmax_ukernel_c8__fma3_2x(
1235 size_t rows,
1236 size_t channels,
1237 const void*restrict input,
1238 size_t input_stride,
1239 const void*restrict weights,
1240 void*restrict output,
1241 size_t output_stride,
1242 const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
1243 {
1244 assert(rows != 0);
1245 assert(channels != 0);
1246 assert(channels % sizeof(uint16_t) == 0);
1247
1248 const uint16_t* i0 = (const uint16_t*) input;
1249 uint16_t* o0 = (uint16_t*) output;
1250 const uint16_t* i1 = (const uint16_t*) ((uintptr_t) i0 + input_stride);
1251 uint16_t* o1 = (uint16_t*) ((uintptr_t) o0 + output_stride);
1252
1253 const size_t input_increment = input_stride * 2 - channels;
1254 const size_t output_increment = output_stride * 2 - channels;
1255
1256 const __m256 vmin = _mm256_load_ps(params->avx.min);
1257 const __m256 vmax = _mm256_load_ps(params->avx.max);
1258 do {
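    // With an odd trailing row, alias the second row's pointers to the first so
    // the two-row code path can be reused; the redundant work writes identical
    // values to the same location.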
1259 if XNN_UNPREDICTABLE(rows < 2) {
1260 i1 = i0;
1261 o1 = o0;
1262 }
1263
1264 const uint16_t* w = (const uint16_t*) weights;
1265 size_t c = channels;
1266 for (; c >= 8 * sizeof(uint16_t); c -= 8 * sizeof(uint16_t)) {
1267 const __m256 vscale = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) w));
1268
1269 __m256 vacc0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
1270 i0 += 8;
1271 __m256 vacc1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
1272 i1 += 8;
1273
1274 const __m256 vbias = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 8)));
1275 w += 16;
1276
1277 vacc0 = _mm256_fmadd_ps(vacc0, vscale, vbias);
1278 vacc1 = _mm256_fmadd_ps(vacc1, vscale, vbias);
1279
1280 vacc0 = _mm256_max_ps(vacc0, vmin);
1281 vacc1 = _mm256_max_ps(vacc1, vmin);
1282
1283 vacc0 = _mm256_min_ps(vacc0, vmax);
1284 vacc1 = _mm256_min_ps(vacc1, vmax);
1285
1286 _mm_storeu_si128((__m128i*) o0, _mm256_cvtps_ph(vacc0, _MM_FROUND_NO_EXC));
1287 o0 += 8;
1288 _mm_storeu_si128((__m128i*) o1, _mm256_cvtps_ph(vacc1, _MM_FROUND_NO_EXC));
1289 o1 += 8;
1290 }
1291 if XNN_UNLIKELY(c != 0) {
1292 const __m256 vscale = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) w));
1293
1294 __m256 vacc0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
1295 i0 = (const uint16_t*) ((uintptr_t) i0 + c);
1296 __m256 vacc1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
1297 i1 = (const uint16_t*) ((uintptr_t) i1 + c);
1298
1299 const __m256 vbias = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 8)));
1300
1301 vacc0 = _mm256_fmadd_ps(vacc0, vscale, vbias);
1302 vacc1 = _mm256_fmadd_ps(vacc1, vscale, vbias);
1303
1304 vacc0 = _mm256_max_ps(vacc0, vmin);
1305 vacc1 = _mm256_max_ps(vacc1, vmin);
1306
1307 vacc0 = _mm256_min_ps(vacc0, vmax);
1308 vacc1 = _mm256_min_ps(vacc1, vmax);
1309
1310 __m128i vh0 = _mm256_cvtps_ph(vacc0, _MM_FROUND_NO_EXC);
1311 __m128i vh1 = _mm256_cvtps_ph(vacc1, _MM_FROUND_NO_EXC);
1312
1313 if (c & (4 * sizeof(uint16_t))) {
1314 _mm_storel_epi64((__m128i*) o0, vh0);
1315 _mm_storel_epi64((__m128i*) o1, vh1);
1316
1317 vh0 = _mm_unpackhi_epi64(vh0, vh0);
1318 vh1 = _mm_unpackhi_epi64(vh1, vh1);
1319
1320 o0 += 4;
1321 o1 += 4;
1322 }
1323 if (c & (2 * sizeof(uint16_t))) {
1324 _mm_storeu_si32(o0, vh0);
1325 _mm_storeu_si32(o1, vh1);
1326
1327 vh0 = _mm_srli_epi64(vh0, 32);
1328 vh1 = _mm_srli_epi64(vh1, 32);
1329
1330 o0 += 2;
1331 o1 += 2;
1332 }
1333 if (c & (1 * sizeof(uint16_t))) {
1334 *o0 = (uint16_t) _mm_extract_epi16(vh0, 0);
1335 *o1 = (uint16_t) _mm_extract_epi16(vh1, 0);
1336
1337 o0 += 1;
1338 o1 += 1;
1339 }
1340 }
1341 i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment);
1342 o0 = (uint16_t*) ((uintptr_t) o0 + output_increment);
1343 i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment);
1344 o1 = (uint16_t*) ((uintptr_t) o1 + output_increment);
1345 rows = doz(rows, 2);
1346 } while (rows != 0);
1347 }
1348
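// f32 depthwise-convolution micro-kernel for a 3-tap filter, processing up to
// 16 channels per iteration with FMA3 and min/max clamping. A channel remainder
// of 1-7 is handled with a mask load and partial stores.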
1349 void xnn_f32_dwconv_minmax_ukernel_up16x3__fma3(
1350 size_t channels,
1351 size_t output_width,
1352 const float** input,
1353 const float* weights,
1354 float* output,
1355 size_t input_stride,
1356 size_t output_increment,
1357 size_t input_offset,
1358 const float* zero,
1359 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
1360 {
1361 assert(channels != 0);
1362 assert(output_width != 0);
1363
1364 const __m256 vmax = _mm256_load_ps(params->avx.max);
1365 const __m256 vmin = _mm256_load_ps(params->avx.min);
1366 do {
1367 const float* i0 = input[0];
1368 assert(i0 != NULL);
1369 if XNN_UNPREDICTABLE(i0 != zero) {
1370 i0 = (const float*) ((uintptr_t) i0 + input_offset);
1371 }
1372 const float* i1 = input[1];
1373 assert(i1 != NULL);
1374 if XNN_UNPREDICTABLE(i1 != zero) {
1375 i1 = (const float*) ((uintptr_t) i1 + input_offset);
1376 }
1377 const float* i2 = input[2];
1378 assert(i2 != NULL);
1379 if XNN_UNPREDICTABLE(i2 != zero) {
1380 i2 = (const float*) ((uintptr_t) i2 + input_offset);
1381 }
1382 input = (const float**) ((uintptr_t) input + input_stride);
1383
1384 size_t c = channels;
1385 const float* w = weights;
1386 for (; c >= 16; c -= 16) {
1387 __m256 vacc01234567p0 = _mm256_load_ps(w);
1388 __m256 vacc89ABCDEFp0 = _mm256_load_ps(w + 8);
1389
1390
1391 const __m256 vi0x01234567 = _mm256_loadu_ps(i0);
1392 const __m256 vi0x89ABCDEF = _mm256_loadu_ps(i0 + 8);
1393 i0 += 16;
1394
1395 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1396 const __m256 vk0x89ABCDEF = _mm256_load_ps(w + 24);
1397 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1398 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi0x89ABCDEF, vk0x89ABCDEF, vacc89ABCDEFp0);
1399
1400 const __m256 vi1x01234567 = _mm256_loadu_ps(i1);
1401 const __m256 vi1x89ABCDEF = _mm256_loadu_ps(i1 + 8);
1402 i1 += 16;
1403
1404 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1405 const __m256 vk1x89ABCDEF = _mm256_load_ps(w + 40);
1406 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1407 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi1x89ABCDEF, vk1x89ABCDEF, vacc89ABCDEFp0);
1408
1409 const __m256 vi2x01234567 = _mm256_loadu_ps(i2);
1410 const __m256 vi2x89ABCDEF = _mm256_loadu_ps(i2 + 8);
1411 i2 += 16;
1412
1413 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1414 const __m256 vk2x89ABCDEF = _mm256_load_ps(w + 56);
1415 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1416 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi2x89ABCDEF, vk2x89ABCDEF, vacc89ABCDEFp0);
1417
1418 w += 64;
1419
1420
1421 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1422 __m256 vacc89ABCDEF = _mm256_max_ps(vacc89ABCDEFp0, vmin);
1423 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1424 vacc89ABCDEF = _mm256_min_ps(vacc89ABCDEF, vmax);
1425
1426 _mm256_storeu_ps(output, vacc01234567);
1427 _mm256_storeu_ps(output + 8, vacc89ABCDEF);
1428 output += 16;
1429 }
1430 for (; c >= 8; c -= 8) {
1431 __m256 vacc01234567p0 = _mm256_load_ps(w);
1432
1433 const __m256 vi0x01234567 = _mm256_loadu_ps(i0);
1434 i0 += 8;
1435
1436 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1437 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1438
1439 const __m256 vi1x01234567 = _mm256_loadu_ps(i1);
1440 i1 += 8;
1441
1442 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1443 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1444
1445 const __m256 vi2x01234567 = _mm256_loadu_ps(i2);
1446 i2 += 8;
1447
1448 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1449 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1450
1451 w += 8;
1452
1453
1454 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1455 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1456
1457 _mm256_storeu_ps(output, vacc01234567);
1458 output += 8;
1459 }
1460 if XNN_UNLIKELY(c != 0) {
1461 assert(c >= 1);
1462 assert(c <= 7);
1463       const __m256i vmask = _mm256_loadu_si256((const __m256i*) &params->avx.mask_table[7 - c]);
1464
1465 __m256 vacc01234567p0 = _mm256_load_ps(w);
1466
1467 const __m256 vi0x01234567 = _mm256_maskload_ps(i0, vmask);
1468 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1469 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1470
1471 const __m256 vi1x01234567 = _mm256_maskload_ps(i1, vmask);
1472 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1473 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1474
1475 const __m256 vi2x01234567 = _mm256_maskload_ps(i2, vmask);
1476 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1477 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1478
1479
1480 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1481 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1482
1483 __m128 vacc0123 = _mm256_castps256_ps128(vacc01234567);
1484 if (c & 4) {
1485 _mm_storeu_ps(output, vacc0123);
1486 vacc0123 = _mm256_extractf128_ps(vacc01234567, 1);
1487 output += 4;
1488 }
1489 if (c & 2) {
1490 _mm_storel_pi((__m64*) output, vacc0123);
1491 vacc0123 = _mm_movehl_ps(vacc0123, vacc0123);
1492 output += 2;
1493 }
1494 if (c & 1) {
1495 _mm_store_ss(output, vacc0123);
1496 output += 1;
1497 }
1498 }
1499
1500 output = (float*) ((uintptr_t) output + output_increment);
1501 } while (--output_width != 0);
1502 }
1503
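// f32 depthwise-convolution micro-kernel for a 4-tap filter, otherwise
// structured like the 3-tap variant above.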
1504 void xnn_f32_dwconv_minmax_ukernel_up16x4__fma3(
1505 size_t channels,
1506 size_t output_width,
1507 const float** input,
1508 const float* weights,
1509 float* output,
1510 size_t input_stride,
1511 size_t output_increment,
1512 size_t input_offset,
1513 const float* zero,
1514 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
1515 {
1516 assert(channels != 0);
1517 assert(output_width != 0);
1518
1519 const __m256 vmax = _mm256_load_ps(params->avx.max);
1520 const __m256 vmin = _mm256_load_ps(params->avx.min);
1521 do {
1522 const float* i0 = input[0];
1523 assert(i0 != NULL);
1524 if XNN_UNPREDICTABLE(i0 != zero) {
1525 i0 = (const float*) ((uintptr_t) i0 + input_offset);
1526 }
1527 const float* i1 = input[1];
1528 assert(i1 != NULL);
1529 if XNN_UNPREDICTABLE(i1 != zero) {
1530 i1 = (const float*) ((uintptr_t) i1 + input_offset);
1531 }
1532 const float* i2 = input[2];
1533 assert(i2 != NULL);
1534 if XNN_UNPREDICTABLE(i2 != zero) {
1535 i2 = (const float*) ((uintptr_t) i2 + input_offset);
1536 }
1537 const float* i3 = input[3];
1538 assert(i3 != NULL);
1539 if XNN_UNPREDICTABLE(i3 != zero) {
1540 i3 = (const float*) ((uintptr_t) i3 + input_offset);
1541 }
1542 input = (const float**) ((uintptr_t) input + input_stride);
1543
1544 size_t c = channels;
1545 const float* w = weights;
1546 for (; c >= 16; c -= 16) {
1547 __m256 vacc01234567p0 = _mm256_load_ps(w);
1548 __m256 vacc89ABCDEFp0 = _mm256_load_ps(w + 8);
1549
1550
1551 const __m256 vi0x01234567 = _mm256_loadu_ps(i0);
1552 const __m256 vi0x89ABCDEF = _mm256_loadu_ps(i0 + 8);
1553 i0 += 16;
1554
1555 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1556 const __m256 vk0x89ABCDEF = _mm256_load_ps(w + 24);
1557 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1558 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi0x89ABCDEF, vk0x89ABCDEF, vacc89ABCDEFp0);
1559
1560 const __m256 vi1x01234567 = _mm256_loadu_ps(i1);
1561 const __m256 vi1x89ABCDEF = _mm256_loadu_ps(i1 + 8);
1562 i1 += 16;
1563
1564 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1565 const __m256 vk1x89ABCDEF = _mm256_load_ps(w + 40);
1566 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1567 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi1x89ABCDEF, vk1x89ABCDEF, vacc89ABCDEFp0);
1568
1569 const __m256 vi2x01234567 = _mm256_loadu_ps(i2);
1570 const __m256 vi2x89ABCDEF = _mm256_loadu_ps(i2 + 8);
1571 i2 += 16;
1572
1573 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1574 const __m256 vk2x89ABCDEF = _mm256_load_ps(w + 56);
1575 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1576 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi2x89ABCDEF, vk2x89ABCDEF, vacc89ABCDEFp0);
1577
1578 const __m256 vi3x01234567 = _mm256_loadu_ps(i3);
1579 const __m256 vi3x89ABCDEF = _mm256_loadu_ps(i3 + 8);
1580 i3 += 16;
1581
1582 const __m256 vk3x01234567 = _mm256_load_ps(w + 64);
1583 const __m256 vk3x89ABCDEF = _mm256_load_ps(w + 72);
1584 vacc01234567p0 = _mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0);
1585 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi3x89ABCDEF, vk3x89ABCDEF, vacc89ABCDEFp0);
1586
1587 w += 80;
1588
1589
1590 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1591 __m256 vacc89ABCDEF = _mm256_max_ps(vacc89ABCDEFp0, vmin);
1592 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1593 vacc89ABCDEF = _mm256_min_ps(vacc89ABCDEF, vmax);
1594
1595 _mm256_storeu_ps(output, vacc01234567);
1596 _mm256_storeu_ps(output + 8, vacc89ABCDEF);
1597 output += 16;
1598 }
1599 for (; c >= 8; c -= 8) {
1600 __m256 vacc01234567p0 = _mm256_load_ps(w);
1601
1602 const __m256 vi0x01234567 = _mm256_loadu_ps(i0);
1603 i0 += 8;
1604
1605 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1606 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1607
1608 const __m256 vi1x01234567 = _mm256_loadu_ps(i1);
1609 i1 += 8;
1610
1611 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1612 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1613
1614 const __m256 vi2x01234567 = _mm256_loadu_ps(i2);
1615 i2 += 8;
1616
1617 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1618 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1619
1620 const __m256 vi3x01234567 = _mm256_loadu_ps(i3);
1621 i3 += 8;
1622
1623 const __m256 vk3x01234567 = _mm256_load_ps(w + 64);
1624 vacc01234567p0 = _mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0);
1625
1626 w += 8;
1627
1628
1629 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1630 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1631
1632 _mm256_storeu_ps(output, vacc01234567);
1633 output += 8;
1634 }
1635 if XNN_UNLIKELY(c != 0) {
1636 assert(c >= 1);
1637 assert(c <= 7);
1638       const __m256i vmask = _mm256_loadu_si256((const __m256i*) &params->avx.mask_table[7 - c]);
1639
1640 __m256 vacc01234567p0 = _mm256_load_ps(w);
1641
1642 const __m256 vi0x01234567 = _mm256_maskload_ps(i0, vmask);
1643 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1644 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1645
1646 const __m256 vi1x01234567 = _mm256_maskload_ps(i1, vmask);
1647 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1648 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1649
1650 const __m256 vi2x01234567 = _mm256_maskload_ps(i2, vmask);
1651 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1652 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1653
1654 const __m256 vi3x01234567 = _mm256_maskload_ps(i3, vmask);
1655 const __m256 vk3x01234567 = _mm256_load_ps(w + 64);
1656 vacc01234567p0 = _mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0);
1657
1658
1659 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1660 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1661
1662 __m128 vacc0123 = _mm256_castps256_ps128(vacc01234567);
1663 if (c & 4) {
1664 _mm_storeu_ps(output, vacc0123);
1665 vacc0123 = _mm256_extractf128_ps(vacc01234567, 1);
1666 output += 4;
1667 }
1668 if (c & 2) {
1669 _mm_storel_pi((__m64*) output, vacc0123);
1670 vacc0123 = _mm_movehl_ps(vacc0123, vacc0123);
1671 output += 2;
1672 }
1673 if (c & 1) {
1674 _mm_store_ss(output, vacc0123);
1675 output += 1;
1676 }
1677 }
1678
1679 output = (float*) ((uintptr_t) output + output_increment);
1680 } while (--output_width != 0);
1681 }
1682
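// f32 depthwise-convolution micro-kernel for a 9-tap filter (e.g. a 3x3 window),
// processing up to 16 channels per iteration.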
1683 void xnn_f32_dwconv_minmax_ukernel_up16x9__fma3(
1684 size_t channels,
1685 size_t output_width,
1686 const float** input,
1687 const float* weights,
1688 float* output,
1689 size_t input_stride,
1690 size_t output_increment,
1691 size_t input_offset,
1692 const float* zero,
1693 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
1694 {
1695 assert(channels != 0);
1696 assert(output_width != 0);
1697
1698 const __m256 vmax = _mm256_load_ps(params->avx.max);
1699 const __m256 vmin = _mm256_load_ps(params->avx.min);
1700 do {
1701 const float* i0 = input[0];
1702 assert(i0 != NULL);
1703 if XNN_UNPREDICTABLE(i0 != zero) {
1704 i0 = (const float*) ((uintptr_t) i0 + input_offset);
1705 }
1706 const float* i1 = input[1];
1707 assert(i1 != NULL);
1708 if XNN_UNPREDICTABLE(i1 != zero) {
1709 i1 = (const float*) ((uintptr_t) i1 + input_offset);
1710 }
1711 const float* i2 = input[2];
1712 assert(i2 != NULL);
1713 if XNN_UNPREDICTABLE(i2 != zero) {
1714 i2 = (const float*) ((uintptr_t) i2 + input_offset);
1715 }
1716 const float* i3 = input[3];
1717 assert(i3 != NULL);
1718 if XNN_UNPREDICTABLE(i3 != zero) {
1719 i3 = (const float*) ((uintptr_t) i3 + input_offset);
1720 }
1721 const float* i4 = input[4];
1722 assert(i4 != NULL);
1723 if XNN_UNPREDICTABLE(i4 != zero) {
1724 i4 = (const float*) ((uintptr_t) i4 + input_offset);
1725 }
1726 const float* i5 = input[5];
1727 assert(i5 != NULL);
1728 if XNN_UNPREDICTABLE(i5 != zero) {
1729 i5 = (const float*) ((uintptr_t) i5 + input_offset);
1730 }
1731 const float* i6 = input[6];
1732 assert(i6 != NULL);
1733 if XNN_UNPREDICTABLE(i6 != zero) {
1734 i6 = (const float*) ((uintptr_t) i6 + input_offset);
1735 }
1736 const float* i7 = input[7];
1737 assert(i7 != NULL);
1738 if XNN_UNPREDICTABLE(i7 != zero) {
1739 i7 = (const float*) ((uintptr_t) i7 + input_offset);
1740 }
1741 const float* i8 = input[8];
1742 assert(i8 != NULL);
1743 if XNN_UNPREDICTABLE(i8 != zero) {
1744 i8 = (const float*) ((uintptr_t) i8 + input_offset);
1745 }
1746 input = (const float**) ((uintptr_t) input + input_stride);
1747
1748 size_t c = channels;
1749 const float* w = weights;
1750 for (; c >= 16; c -= 16) {
1751 __m256 vacc01234567p0 = _mm256_load_ps(w);
1752 __m256 vacc89ABCDEFp0 = _mm256_load_ps(w + 8);
1753
1754
1755 const __m256 vi0x01234567 = _mm256_loadu_ps(i0);
1756 const __m256 vi0x89ABCDEF = _mm256_loadu_ps(i0 + 8);
1757 i0 += 16;
1758
1759 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1760 const __m256 vk0x89ABCDEF = _mm256_load_ps(w + 24);
1761 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1762 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi0x89ABCDEF, vk0x89ABCDEF, vacc89ABCDEFp0);
1763
1764 const __m256 vi1x01234567 = _mm256_loadu_ps(i1);
1765 const __m256 vi1x89ABCDEF = _mm256_loadu_ps(i1 + 8);
1766 i1 += 16;
1767
1768 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1769 const __m256 vk1x89ABCDEF = _mm256_load_ps(w + 40);
1770 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1771 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi1x89ABCDEF, vk1x89ABCDEF, vacc89ABCDEFp0);
1772
1773 const __m256 vi2x01234567 = _mm256_loadu_ps(i2);
1774 const __m256 vi2x89ABCDEF = _mm256_loadu_ps(i2 + 8);
1775 i2 += 16;
1776
1777 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1778 const __m256 vk2x89ABCDEF = _mm256_load_ps(w + 56);
1779 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1780 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi2x89ABCDEF, vk2x89ABCDEF, vacc89ABCDEFp0);
1781
1782 const __m256 vi3x01234567 = _mm256_loadu_ps(i3);
1783 const __m256 vi3x89ABCDEF = _mm256_loadu_ps(i3 + 8);
1784 i3 += 16;
1785
1786 const __m256 vk3x01234567 = _mm256_load_ps(w + 64);
1787 const __m256 vk3x89ABCDEF = _mm256_load_ps(w + 72);
1788 vacc01234567p0 = _mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0);
1789 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi3x89ABCDEF, vk3x89ABCDEF, vacc89ABCDEFp0);
1790
1791 const __m256 vi4x01234567 = _mm256_loadu_ps(i4);
1792 const __m256 vi4x89ABCDEF = _mm256_loadu_ps(i4 + 8);
1793 i4 += 16;
1794
1795 const __m256 vk4x01234567 = _mm256_load_ps(w + 80);
1796 const __m256 vk4x89ABCDEF = _mm256_load_ps(w + 88);
1797 vacc01234567p0 = _mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0);
1798 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi4x89ABCDEF, vk4x89ABCDEF, vacc89ABCDEFp0);
1799
1800 const __m256 vi5x01234567 = _mm256_loadu_ps(i5);
1801 const __m256 vi5x89ABCDEF = _mm256_loadu_ps(i5 + 8);
1802 i5 += 16;
1803
1804 const __m256 vk5x01234567 = _mm256_load_ps(w + 96);
1805 const __m256 vk5x89ABCDEF = _mm256_load_ps(w + 104);
1806 vacc01234567p0 = _mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p0);
1807 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi5x89ABCDEF, vk5x89ABCDEF, vacc89ABCDEFp0);
1808
1809 const __m256 vi6x01234567 = _mm256_loadu_ps(i6);
1810 const __m256 vi6x89ABCDEF = _mm256_loadu_ps(i6 + 8);
1811 i6 += 16;
1812
1813 const __m256 vk6x01234567 = _mm256_load_ps(w + 112);
1814 const __m256 vk6x89ABCDEF = _mm256_load_ps(w + 120);
1815 vacc01234567p0 = _mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0);
1816 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi6x89ABCDEF, vk6x89ABCDEF, vacc89ABCDEFp0);
1817
1818 const __m256 vi7x01234567 = _mm256_loadu_ps(i7);
1819 const __m256 vi7x89ABCDEF = _mm256_loadu_ps(i7 + 8);
1820 i7 += 16;
1821
1822 const __m256 vk7x01234567 = _mm256_load_ps(w + 128);
1823 const __m256 vk7x89ABCDEF = _mm256_load_ps(w + 136);
1824 vacc01234567p0 = _mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p0);
1825 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi7x89ABCDEF, vk7x89ABCDEF, vacc89ABCDEFp0);
1826
1827 const __m256 vi8x01234567 = _mm256_loadu_ps(i8);
1828 const __m256 vi8x89ABCDEF = _mm256_loadu_ps(i8 + 8);
1829 i8 += 16;
1830
1831 const __m256 vk8x01234567 = _mm256_load_ps(w + 144);
1832 const __m256 vk8x89ABCDEF = _mm256_load_ps(w + 152);
1833 vacc01234567p0 = _mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0);
1834 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi8x89ABCDEF, vk8x89ABCDEF, vacc89ABCDEFp0);
1835
1836 w += 160;
1837
1838
1839 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1840 __m256 vacc89ABCDEF = _mm256_max_ps(vacc89ABCDEFp0, vmin);
1841 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1842 vacc89ABCDEF = _mm256_min_ps(vacc89ABCDEF, vmax);
1843
1844 _mm256_storeu_ps(output, vacc01234567);
1845 _mm256_storeu_ps(output + 8, vacc89ABCDEF);
1846 output += 16;
1847 }
1848 for (; c >= 8; c -= 8) {
1849 __m256 vacc01234567p0 = _mm256_load_ps(w);
1850
1851 const __m256 vi0x01234567 = _mm256_loadu_ps(i0);
1852 i0 += 8;
1853
1854 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1855 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1856
1857 const __m256 vi1x01234567 = _mm256_loadu_ps(i1);
1858 i1 += 8;
1859
1860 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1861 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1862
1863 const __m256 vi2x01234567 = _mm256_loadu_ps(i2);
1864 i2 += 8;
1865
1866 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1867 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1868
1869 const __m256 vi3x01234567 = _mm256_loadu_ps(i3);
1870 i3 += 8;
1871
1872 const __m256 vk3x01234567 = _mm256_load_ps(w + 64);
1873 vacc01234567p0 = _mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0);
1874
1875 const __m256 vi4x01234567 = _mm256_loadu_ps(i4);
1876 i4 += 8;
1877
1878 const __m256 vk4x01234567 = _mm256_load_ps(w + 80);
1879 vacc01234567p0 = _mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0);
1880
1881 const __m256 vi5x01234567 = _mm256_loadu_ps(i5);
1882 i5 += 8;
1883
1884 const __m256 vk5x01234567 = _mm256_load_ps(w + 96);
1885 vacc01234567p0 = _mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p0);
1886
1887 const __m256 vi6x01234567 = _mm256_loadu_ps(i6);
1888 i6 += 8;
1889
1890 const __m256 vk6x01234567 = _mm256_load_ps(w + 112);
1891 vacc01234567p0 = _mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0);
1892
1893 const __m256 vi7x01234567 = _mm256_loadu_ps(i7);
1894 i7 += 8;
1895
1896 const __m256 vk7x01234567 = _mm256_load_ps(w + 128);
1897 vacc01234567p0 = _mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p0);
1898
1899 const __m256 vi8x01234567 = _mm256_loadu_ps(i8);
1900 i8 += 8;
1901
1902 const __m256 vk8x01234567 = _mm256_load_ps(w + 144);
1903 vacc01234567p0 = _mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0);
1904
1905 w += 8;
1906
1907
1908 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1909 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1910
1911 _mm256_storeu_ps(output, vacc01234567);
1912 output += 8;
1913 }
1914 if XNN_UNLIKELY(c != 0) {
1915 assert(c >= 1);
1916 assert(c <= 7);
1917       const __m256i vmask = _mm256_loadu_si256((const __m256i*) &params->avx.mask_table[7 - c]);
1918
1919 __m256 vacc01234567p0 = _mm256_load_ps(w);
1920
1921 const __m256 vi0x01234567 = _mm256_maskload_ps(i0, vmask);
1922 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1923 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1924
1925 const __m256 vi1x01234567 = _mm256_maskload_ps(i1, vmask);
1926 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1927 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1928
1929 const __m256 vi2x01234567 = _mm256_maskload_ps(i2, vmask);
1930 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1931 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1932
1933 const __m256 vi3x01234567 = _mm256_maskload_ps(i3, vmask);
1934 const __m256 vk3x01234567 = _mm256_load_ps(w + 64);
1935 vacc01234567p0 = _mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0);
1936
1937 const __m256 vi4x01234567 = _mm256_maskload_ps(i4, vmask);
1938 const __m256 vk4x01234567 = _mm256_load_ps(w + 80);
1939 vacc01234567p0 = _mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0);
1940
1941 const __m256 vi5x01234567 = _mm256_maskload_ps(i5, vmask);
1942 const __m256 vk5x01234567 = _mm256_load_ps(w + 96);
1943 vacc01234567p0 = _mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p0);
1944
1945 const __m256 vi6x01234567 = _mm256_maskload_ps(i6, vmask);
1946 const __m256 vk6x01234567 = _mm256_load_ps(w + 112);
1947 vacc01234567p0 = _mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0);
1948
1949 const __m256 vi7x01234567 = _mm256_maskload_ps(i7, vmask);
1950 const __m256 vk7x01234567 = _mm256_load_ps(w + 128);
1951 vacc01234567p0 = _mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p0);
1952
1953 const __m256 vi8x01234567 = _mm256_maskload_ps(i8, vmask);
1954 const __m256 vk8x01234567 = _mm256_load_ps(w + 144);
1955 vacc01234567p0 = _mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0);
1956
1957
1958 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1959 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1960
1961 __m128 vacc0123 = _mm256_castps256_ps128(vacc01234567);
1962 if (c & 4) {
1963 _mm_storeu_ps(output, vacc0123);
1964 vacc0123 = _mm256_extractf128_ps(vacc01234567, 1);
1965 output += 4;
1966 }
1967 if (c & 2) {
1968 _mm_storel_pi((__m64*) output, vacc0123);
1969 vacc0123 = _mm_movehl_ps(vacc0123, vacc0123);
1970 output += 2;
1971 }
1972 if (c & 1) {
1973 _mm_store_ss(output, vacc0123);
1974 output += 1;
1975 }
1976 }
1977
1978 output = (float*) ((uintptr_t) output + output_increment);
1979 } while (--output_width != 0);
1980 }
1981
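// f32 depthwise-convolution micro-kernel for a 25-tap filter (e.g. a 5x5 window),
// processing 8 channels per iteration.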
1982 void xnn_f32_dwconv_minmax_ukernel_up8x25__fma3(
1983 size_t channels,
1984 size_t output_width,
1985 const float** input,
1986 const float* weights,
1987 float* output,
1988 size_t input_stride,
1989 size_t output_increment,
1990 size_t input_offset,
1991 const float* zero,
1992 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
1993 {
1994 assert(channels != 0);
1995 assert(output_width != 0);
1996
1997 const __m256 vmax = _mm256_load_ps(params->avx.max);
1998 const __m256 vmin = _mm256_load_ps(params->avx.min);
1999 do {
2000 const float* i0 = input[0];
2001 assert(i0 != NULL);
2002 if XNN_UNPREDICTABLE(i0 != zero) {
2003 i0 = (const float*) ((uintptr_t) i0 + input_offset);
2004 }
2005 const float* i1 = input[1];
2006 assert(i1 != NULL);
2007 if XNN_UNPREDICTABLE(i1 != zero) {
2008 i1 = (const float*) ((uintptr_t) i1 + input_offset);
2009 }
2010 const float* i2 = input[2];
2011 assert(i2 != NULL);
2012 if XNN_UNPREDICTABLE(i2 != zero) {
2013 i2 = (const float*) ((uintptr_t) i2 + input_offset);
2014 }
2015 const float* i3 = input[3];
2016 assert(i3 != NULL);
2017 if XNN_UNPREDICTABLE(i3 != zero) {
2018 i3 = (const float*) ((uintptr_t) i3 + input_offset);
2019 }
2020 const float* i4 = input[4];
2021 assert(i4 != NULL);
2022 if XNN_UNPREDICTABLE(i4 != zero) {
2023 i4 = (const float*) ((uintptr_t) i4 + input_offset);
2024 }
2025 const float* i5 = input[5];
2026 assert(i5 != NULL);
2027 if XNN_UNPREDICTABLE(i5 != zero) {
2028 i5 = (const float*) ((uintptr_t) i5 + input_offset);
2029 }
2030 const float* i6 = input[6];
2031 assert(i6 != NULL);
2032 if XNN_UNPREDICTABLE(i6 != zero) {
2033 i6 = (const float*) ((uintptr_t) i6 + input_offset);
2034 }
2035 const float* i7 = input[7];
2036 assert(i7 != NULL);
2037 if XNN_UNPREDICTABLE(i7 != zero) {
2038 i7 = (const float*) ((uintptr_t) i7 + input_offset);
2039 }
2040 const float* i8 = input[8];
2041 assert(i8 != NULL);
2042 if XNN_UNPREDICTABLE(i8 != zero) {
2043 i8 = (const float*) ((uintptr_t) i8 + input_offset);
2044 }
2045 const float* i9 = input[9];
2046 assert(i9 != NULL);
2047 if XNN_UNPREDICTABLE(i9 != zero) {
2048 i9 = (const float*) ((uintptr_t) i9 + input_offset);
2049 }
2050 const float* i10 = input[10];
2051 assert(i10 != NULL);
2052 if XNN_UNPREDICTABLE(i10 != zero) {
2053 i10 = (const float*) ((uintptr_t) i10 + input_offset);
2054 }
2055 const float* i11 = input[11];
2056 assert(i11 != NULL);
2057 if XNN_UNPREDICTABLE(i11 != zero) {
2058 i11 = (const float*) ((uintptr_t) i11 + input_offset);
2059 }
2060 const float* i12 = input[12];
2061 assert(i12 != NULL);
2062 if XNN_UNPREDICTABLE(i12 != zero) {
2063 i12 = (const float*) ((uintptr_t) i12 + input_offset);
2064 }
2065 const float* i13 = input[13];
2066 assert(i13 != NULL);
2067 if XNN_UNPREDICTABLE(i13 != zero) {
2068 i13 = (const float*) ((uintptr_t) i13 + input_offset);
2069 }
2070 const float* i14 = input[14];
2071 assert(i14 != NULL);
2072 if XNN_UNPREDICTABLE(i14 != zero) {
2073 i14 = (const float*) ((uintptr_t) i14 + input_offset);
2074 }
2075 const float* i15 = input[15];
2076 assert(i15 != NULL);
2077 if XNN_UNPREDICTABLE(i15 != zero) {
2078 i15 = (const float*) ((uintptr_t) i15 + input_offset);
2079 }
2080 const float* i16 = input[16];
2081 assert(i16 != NULL);
2082 if XNN_UNPREDICTABLE(i16 != zero) {
2083 i16 = (const float*) ((uintptr_t) i16 + input_offset);
2084 }
2085 const float* i17 = input[17];
2086 assert(i17 != NULL);
2087 if XNN_UNPREDICTABLE(i17 != zero) {
2088 i17 = (const float*) ((uintptr_t) i17 + input_offset);
2089 }
2090 const float* i18 = input[18];
2091 assert(i18 != NULL);
2092 if XNN_UNPREDICTABLE(i18 != zero) {
2093 i18 = (const float*) ((uintptr_t) i18 + input_offset);
2094 }
2095 const float* i19 = input[19];
2096 assert(i19 != NULL);
2097 if XNN_UNPREDICTABLE(i19 != zero) {
2098 i19 = (const float*) ((uintptr_t) i19 + input_offset);
2099 }
2100 const float* i20 = input[20];
2101 assert(i20 != NULL);
2102 if XNN_UNPREDICTABLE(i20 != zero) {
2103 i20 = (const float*) ((uintptr_t) i20 + input_offset);
2104 }
2105 const float* i21 = input[21];
2106 assert(i21 != NULL);
2107 if XNN_UNPREDICTABLE(i21 != zero) {
2108 i21 = (const float*) ((uintptr_t) i21 + input_offset);
2109 }
2110 const float* i22 = input[22];
2111 assert(i22 != NULL);
2112 if XNN_UNPREDICTABLE(i22 != zero) {
2113 i22 = (const float*) ((uintptr_t) i22 + input_offset);
2114 }
2115 const float* i23 = input[23];
2116 assert(i23 != NULL);
2117 if XNN_UNPREDICTABLE(i23 != zero) {
2118 i23 = (const float*) ((uintptr_t) i23 + input_offset);
2119 }
2120 const float* i24 = input[24];
2121 assert(i24 != NULL);
2122 if XNN_UNPREDICTABLE(i24 != zero) {
2123 i24 = (const float*) ((uintptr_t) i24 + input_offset);
2124 }
2125 input = (const float**) ((uintptr_t) input + input_stride);
2126
2127 size_t c = channels;
2128 const float* w = weights;
2129 for (; c >= 8; c -= 8) {
2130 __m256 vacc01234567p0 = _mm256_load_ps(w);
2131
2132
2133 const __m256 vi0x01234567 = _mm256_loadu_ps(i0);
2134 i0 += 8;
2135
2136 const __m256 vk0x01234567 = _mm256_load_ps(w + 8);
2137 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
2138
2139 const __m256 vi1x01234567 = _mm256_loadu_ps(i1);
2140 i1 += 8;
2141
2142 const __m256 vk1x01234567 = _mm256_load_ps(w + 16);
2143 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
2144
2145 const __m256 vi2x01234567 = _mm256_loadu_ps(i2);
2146 i2 += 8;
2147
2148 const __m256 vk2x01234567 = _mm256_load_ps(w + 24);
2149 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
2150
2151 const __m256 vi3x01234567 = _mm256_loadu_ps(i3);
2152 i3 += 8;
2153
2154 const __m256 vk3x01234567 = _mm256_load_ps(w + 32);
2155 vacc01234567p0 = _mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0);
2156
2157 const __m256 vi4x01234567 = _mm256_loadu_ps(i4);
2158 i4 += 8;
2159
2160 const __m256 vk4x01234567 = _mm256_load_ps(w + 40);
2161 vacc01234567p0 = _mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0);
2162
2163 const __m256 vi5x01234567 = _mm256_loadu_ps(i5);
2164 i5 += 8;
2165
2166 const __m256 vk5x01234567 = _mm256_load_ps(w + 48);
2167 vacc01234567p0 = _mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p0);
2168
2169 const __m256 vi6x01234567 = _mm256_loadu_ps(i6);
2170 i6 += 8;
2171
2172 const __m256 vk6x01234567 = _mm256_load_ps(w + 56);
2173 vacc01234567p0 = _mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0);
2174
2175 const __m256 vi7x01234567 = _mm256_loadu_ps(i7);
2176 i7 += 8;
2177
2178 const __m256 vk7x01234567 = _mm256_load_ps(w + 64);
2179 vacc01234567p0 = _mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p0);
2180
2181 const __m256 vi8x01234567 = _mm256_loadu_ps(i8);
2182 i8 += 8;
2183
2184 const __m256 vk8x01234567 = _mm256_load_ps(w + 72);
2185 vacc01234567p0 = _mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0);
2186
2187 const __m256 vi9x01234567 = _mm256_loadu_ps(i9);
2188 i9 += 8;
2189
2190 const __m256 vk9x01234567 = _mm256_load_ps(w + 80);
2191 vacc01234567p0 = _mm256_fmadd_ps(vi9x01234567, vk9x01234567, vacc01234567p0);
2192
2193 const __m256 vi10x01234567 = _mm256_loadu_ps(i10);
2194 i10 += 8;
2195
2196 const __m256 vk10x01234567 = _mm256_load_ps(w + 88);
2197 vacc01234567p0 = _mm256_fmadd_ps(vi10x01234567, vk10x01234567, vacc01234567p0);
2198
2199 const __m256 vi11x01234567 = _mm256_loadu_ps(i11);
2200 i11 += 8;
2201
2202 const __m256 vk11x01234567 = _mm256_load_ps(w + 96);
2203 vacc01234567p0 = _mm256_fmadd_ps(vi11x01234567, vk11x01234567, vacc01234567p0);
2204
2205 const __m256 vi12x01234567 = _mm256_loadu_ps(i12);
2206 i12 += 8;
2207
2208 const __m256 vk12x01234567 = _mm256_load_ps(w + 104);
2209 vacc01234567p0 = _mm256_fmadd_ps(vi12x01234567, vk12x01234567, vacc01234567p0);
2210
2211 const __m256 vi13x01234567 = _mm256_loadu_ps(i13);
2212 i13 += 8;
2213
2214 const __m256 vk13x01234567 = _mm256_load_ps(w + 112);
2215 vacc01234567p0 = _mm256_fmadd_ps(vi13x01234567, vk13x01234567, vacc01234567p0);
2216
2217 const __m256 vi14x01234567 = _mm256_loadu_ps(i14);
2218 i14 += 8;
2219
2220 const __m256 vk14x01234567 = _mm256_load_ps(w + 120);
2221 vacc01234567p0 = _mm256_fmadd_ps(vi14x01234567, vk14x01234567, vacc01234567p0);
2222
2223 const __m256 vi15x01234567 = _mm256_loadu_ps(i15);
2224 i15 += 8;
2225
2226 const __m256 vk15x01234567 = _mm256_load_ps(w + 128);
2227 vacc01234567p0 = _mm256_fmadd_ps(vi15x01234567, vk15x01234567, vacc01234567p0);
2228
2229 const __m256 vi16x01234567 = _mm256_loadu_ps(i16);
2230 i16 += 8;
2231
2232 const __m256 vk16x01234567 = _mm256_load_ps(w + 136);
2233 vacc01234567p0 = _mm256_fmadd_ps(vi16x01234567, vk16x01234567, vacc01234567p0);
2234
2235 const __m256 vi17x01234567 = _mm256_loadu_ps(i17);
2236 i17 += 8;
2237
2238 const __m256 vk17x01234567 = _mm256_load_ps(w + 144);
2239 vacc01234567p0 = _mm256_fmadd_ps(vi17x01234567, vk17x01234567, vacc01234567p0);
2240
2241 const __m256 vi18x01234567 = _mm256_loadu_ps(i18);
2242 i18 += 8;
2243
2244 const __m256 vk18x01234567 = _mm256_load_ps(w + 152);
2245 vacc01234567p0 = _mm256_fmadd_ps(vi18x01234567, vk18x01234567, vacc01234567p0);
2246
2247 const __m256 vi19x01234567 = _mm256_loadu_ps(i19);
2248 i19 += 8;
2249
2250 const __m256 vk19x01234567 = _mm256_load_ps(w + 160);
2251 vacc01234567p0 = _mm256_fmadd_ps(vi19x01234567, vk19x01234567, vacc01234567p0);
2252
2253 const __m256 vi20x01234567 = _mm256_loadu_ps(i20);
2254 i20 += 8;
2255
2256 const __m256 vk20x01234567 = _mm256_load_ps(w + 168);
2257 vacc01234567p0 = _mm256_fmadd_ps(vi20x01234567, vk20x01234567, vacc01234567p0);
2258
2259 const __m256 vi21x01234567 = _mm256_loadu_ps(i21);
2260 i21 += 8;
2261
2262 const __m256 vk21x01234567 = _mm256_load_ps(w + 176);
2263 vacc01234567p0 = _mm256_fmadd_ps(vi21x01234567, vk21x01234567, vacc01234567p0);
2264
2265 const __m256 vi22x01234567 = _mm256_loadu_ps(i22);
2266 i22 += 8;
2267
2268 const __m256 vk22x01234567 = _mm256_load_ps(w + 184);
2269 vacc01234567p0 = _mm256_fmadd_ps(vi22x01234567, vk22x01234567, vacc01234567p0);
2270
2271 const __m256 vi23x01234567 = _mm256_loadu_ps(i23);
2272 i23 += 8;
2273
2274 const __m256 vk23x01234567 = _mm256_load_ps(w + 192);
2275 vacc01234567p0 = _mm256_fmadd_ps(vi23x01234567, vk23x01234567, vacc01234567p0);
2276
2277 const __m256 vi24x01234567 = _mm256_loadu_ps(i24);
2278 i24 += 8;
2279
2280 const __m256 vk24x01234567 = _mm256_load_ps(w + 200);
2281 vacc01234567p0 = _mm256_fmadd_ps(vi24x01234567, vk24x01234567, vacc01234567p0);
2282
2283 w += 208;
2284
2285
2286 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
2287 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
2288
2289 _mm256_storeu_ps(output, vacc01234567);
2290 output += 8;
2291 }
2292 if XNN_UNLIKELY(c != 0) {
2293 assert(c >= 1);
2294 assert(c <= 7);
2295       const __m256i vmask = _mm256_loadu_si256((const __m256i*) &params->avx.mask_table[7 - c]);
2296
2297 __m256 vacc01234567p0 = _mm256_load_ps(w);
2298
2299 const __m256 vi0x01234567 = _mm256_maskload_ps(i0, vmask);
2300 const __m256 vk0x01234567 = _mm256_load_ps(w + 8);
2301 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
2302
2303 const __m256 vi1x01234567 = _mm256_maskload_ps(i1, vmask);
2304 const __m256 vk1x01234567 = _mm256_load_ps(w + 16);
2305 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
2306
2307 const __m256 vi2x01234567 = _mm256_maskload_ps(i2, vmask);
2308 const __m256 vk2x01234567 = _mm256_load_ps(w + 24);
2309 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
2310
2311 const __m256 vi3x01234567 = _mm256_maskload_ps(i3, vmask);
2312 const __m256 vk3x01234567 = _mm256_load_ps(w + 32);
2313 vacc01234567p0 = _mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0);
2314
2315 const __m256 vi4x01234567 = _mm256_maskload_ps(i4, vmask);
2316 const __m256 vk4x01234567 = _mm256_load_ps(w + 40);
2317 vacc01234567p0 = _mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0);
2318
2319 const __m256 vi5x01234567 = _mm256_maskload_ps(i5, vmask);
2320 const __m256 vk5x01234567 = _mm256_load_ps(w + 48);
2321 vacc01234567p0 = _mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p0);
2322
2323 const __m256 vi6x01234567 = _mm256_maskload_ps(i6, vmask);
2324 const __m256 vk6x01234567 = _mm256_load_ps(w + 56);
2325 vacc01234567p0 = _mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0);
2326
2327 const __m256 vi7x01234567 = _mm256_maskload_ps(i7, vmask);
2328 const __m256 vk7x01234567 = _mm256_load_ps(w + 64);
2329 vacc01234567p0 = _mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p0);
2330
2331 const __m256 vi8x01234567 = _mm256_maskload_ps(i8, vmask);
2332 const __m256 vk8x01234567 = _mm256_load_ps(w + 72);
2333 vacc01234567p0 = _mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0);
2334
2335 const __m256 vi9x01234567 = _mm256_maskload_ps(i9, vmask);
2336 const __m256 vk9x01234567 = _mm256_load_ps(w + 80);
2337 vacc01234567p0 = _mm256_fmadd_ps(vi9x01234567, vk9x01234567, vacc01234567p0);
2338
2339 const __m256 vi10x01234567 = _mm256_maskload_ps(i10, vmask);
2340 const __m256 vk10x01234567 = _mm256_load_ps(w + 88);
2341 vacc01234567p0 = _mm256_fmadd_ps(vi10x01234567, vk10x01234567, vacc01234567p0);
2342
2343 const __m256 vi11x01234567 = _mm256_maskload_ps(i11, vmask);
2344 const __m256 vk11x01234567 = _mm256_load_ps(w + 96);
2345 vacc01234567p0 = _mm256_fmadd_ps(vi11x01234567, vk11x01234567, vacc01234567p0);
2346
2347 const __m256 vi12x01234567 = _mm256_maskload_ps(i12, vmask);
2348 const __m256 vk12x01234567 = _mm256_load_ps(w + 104);
2349 vacc01234567p0 = _mm256_fmadd_ps(vi12x01234567, vk12x01234567, vacc01234567p0);
2350
2351 const __m256 vi13x01234567 = _mm256_maskload_ps(i13, vmask);
2352 const __m256 vk13x01234567 = _mm256_load_ps(w + 112);
2353 vacc01234567p0 = _mm256_fmadd_ps(vi13x01234567, vk13x01234567, vacc01234567p0);
2354
2355 const __m256 vi14x01234567 = _mm256_maskload_ps(i14, vmask);
2356 const __m256 vk14x01234567 = _mm256_load_ps(w + 120);
2357 vacc01234567p0 = _mm256_fmadd_ps(vi14x01234567, vk14x01234567, vacc01234567p0);
2358
2359 const __m256 vi15x01234567 = _mm256_maskload_ps(i15, vmask);
2360 const __m256 vk15x01234567 = _mm256_load_ps(w + 128);
2361 vacc01234567p0 = _mm256_fmadd_ps(vi15x01234567, vk15x01234567, vacc01234567p0);
2362
2363 const __m256 vi16x01234567 = _mm256_maskload_ps(i16, vmask);
2364 const __m256 vk16x01234567 = _mm256_load_ps(w + 136);
2365 vacc01234567p0 = _mm256_fmadd_ps(vi16x01234567, vk16x01234567, vacc01234567p0);
2366
2367 const __m256 vi17x01234567 = _mm256_maskload_ps(i17, vmask);
2368 const __m256 vk17x01234567 = _mm256_load_ps(w + 144);
2369 vacc01234567p0 = _mm256_fmadd_ps(vi17x01234567, vk17x01234567, vacc01234567p0);
2370
2371 const __m256 vi18x01234567 = _mm256_maskload_ps(i18, vmask);
2372 const __m256 vk18x01234567 = _mm256_load_ps(w + 152);
2373 vacc01234567p0 = _mm256_fmadd_ps(vi18x01234567, vk18x01234567, vacc01234567p0);
2374
2375 const __m256 vi19x01234567 = _mm256_maskload_ps(i19, vmask);
2376 const __m256 vk19x01234567 = _mm256_load_ps(w + 160);
2377 vacc01234567p0 = _mm256_fmadd_ps(vi19x01234567, vk19x01234567, vacc01234567p0);
2378
2379 const __m256 vi20x01234567 = _mm256_maskload_ps(i20, vmask);
2380 const __m256 vk20x01234567 = _mm256_load_ps(w + 168);
2381 vacc01234567p0 = _mm256_fmadd_ps(vi20x01234567, vk20x01234567, vacc01234567p0);
2382
2383 const __m256 vi21x01234567 = _mm256_maskload_ps(i21, vmask);
2384 const __m256 vk21x01234567 = _mm256_load_ps(w + 176);
2385 vacc01234567p0 = _mm256_fmadd_ps(vi21x01234567, vk21x01234567, vacc01234567p0);
2386
2387 const __m256 vi22x01234567 = _mm256_maskload_ps(i22, vmask);
2388 const __m256 vk22x01234567 = _mm256_load_ps(w + 184);
2389 vacc01234567p0 = _mm256_fmadd_ps(vi22x01234567, vk22x01234567, vacc01234567p0);
2390
2391 const __m256 vi23x01234567 = _mm256_maskload_ps(i23, vmask);
2392 const __m256 vk23x01234567 = _mm256_load_ps(w + 192);
2393 vacc01234567p0 = _mm256_fmadd_ps(vi23x01234567, vk23x01234567, vacc01234567p0);
2394
2395 const __m256 vi24x01234567 = _mm256_maskload_ps(i24, vmask);
2396 const __m256 vk24x01234567 = _mm256_load_ps(w + 200);
2397 vacc01234567p0 = _mm256_fmadd_ps(vi24x01234567, vk24x01234567, vacc01234567p0);
2398
2399
2400 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
2401 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
2402
2403 __m128 vacc0123 = _mm256_castps256_ps128(vacc01234567);
2404 if (c & 4) {
2405 _mm_storeu_ps(output, vacc0123);
2406 vacc0123 = _mm256_extractf128_ps(vacc01234567, 1);
2407 output += 4;
2408 }
2409 if (c & 2) {
2410 _mm_storel_pi((__m64*) output, vacc0123);
2411 vacc0123 = _mm_movehl_ps(vacc0123, vacc0123);
2412 output += 2;
2413 }
2414 if (c & 1) {
2415 _mm_store_ss(output, vacc0123);
2416 output += 1;
2417 }
2418 }
2419
2420 output = (float*) ((uintptr_t) output + output_increment);
2421 } while (--output_width != 0);
2422 }
2423
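// f32 GEMM micro-kernel computing a 1x16 output tile: each k step broadcasts one
// element of A and accumulates it against 16 packed B columns with FMA3, then
// clamps the result to [min, max] before storing.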
2424 void xnn_f32_gemm_minmax_ukernel_1x16__fma3_broadcast(
2425 size_t mr,
2426 size_t nc,
2427 size_t kc,
2428 const float*restrict a,
2429 size_t a_stride,
2430 const float*restrict w,
2431 float*restrict c,
2432 size_t cm_stride,
2433 size_t cn_stride,
2434 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
2435 {
2436 assert(mr != 0);
2437 assert(mr <= 1);
2438 assert(nc != 0);
2439 assert(kc != 0);
2440 assert(kc % sizeof(float) == 0);
2441 assert(a != NULL);
2442 assert(w != NULL);
2443 assert(c != NULL);
2444
2445 const float* a0 = a;
2446 float* c0 = c;
2447
2448 do {
2449 __m256 vacc0x01234567 = _mm256_load_ps(w + 0);
2450 __m256 vacc0x89ABCDEF = _mm256_load_ps(w + 8);
2451 w += 16;
2452
2453 size_t k = kc;
2454 do {
2455 const __m256 va0 = _mm256_broadcast_ss(a0);
2456 a0 += 1;
2457
2458 const __m256 vb01234567 = _mm256_load_ps(w);
2459 const __m256 vb89ABCDEF = _mm256_load_ps(w + 8);
2460 w += 16;
2461
2462 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567, vacc0x01234567);
2463 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEF, vacc0x89ABCDEF);
2464
2465 k -= sizeof(float);
2466 } while (k != 0);
2467
2468 const __m256 vmin = _mm256_load_ps(params->avx.min);
2469 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
2470 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
2471
2472 const __m256 vmax = _mm256_load_ps(params->avx.max);
2473 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
2474 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
2475
2476 if XNN_LIKELY(nc >= 16) {
2477 _mm256_storeu_ps(c0, vacc0x01234567);
2478 _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
2479 c0 = (float*) ((uintptr_t) c0 + cn_stride);
2480
2481 a0 = (const float*) ((uintptr_t) a0 - kc);
2482
2483 nc -= 16;
2484 } else {
2485 if (nc & 8) {
2486 _mm256_storeu_ps(c0, vacc0x01234567);
2487
2488 vacc0x01234567 = vacc0x89ABCDEF;
2489
2490 c0 += 8;
2491 }
2492 __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
2493 if (nc & 4) {
2494 _mm_storeu_ps(c0, vacc0x0123);
2495
2496 vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
2497
2498 c0 += 4;
2499 }
2500 if (nc & 2) {
2501 _mm_storel_pi((__m64*) c0, vacc0x0123);
2502
2503 vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
2504
2505 c0 += 2;
2506 }
2507 if (nc & 1) {
2508 _mm_store_ss(c0, vacc0x0123);
2509 }
2510
2511 nc = 0;
2512 }
2513 } while (nc != 0);
2514 }
2515
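// f32 GEMM micro-kernel for a 1x16 tile using the s4 layout: four A elements are
// loaded at once and rotated between the four FMA groups of each k step, with B
// packed in the matching shuffled order.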
2516 void xnn_f32_gemm_minmax_ukernel_1x16s4__fma3_broadcast(
2517 size_t mr,
2518 size_t nc,
2519 size_t kc,
2520 const float*restrict a,
2521 size_t a_stride,
2522 const float*restrict w,
2523 float*restrict c,
2524 size_t cm_stride,
2525 size_t cn_stride,
2526 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
2527 {
2528 assert(mr != 0);
2529 assert(mr <= 1);
2530 assert(nc != 0);
2531 assert(kc != 0);
2532 assert(kc % sizeof(float) == 0);
2533 assert(a != NULL);
2534 assert(w != NULL);
2535 assert(c != NULL);
2536
2537 const float* a0 = a;
2538 float* c0 = c;
2539
2540 do {
2541 __m256 vacc0x01234567 = _mm256_load_ps(w + 0);
2542 __m256 vacc0x89ABCDEF = _mm256_load_ps(w + 8);
2543 w += 16;
2544
2545 size_t k = kc;
2546 while (k >= 4 * sizeof(float)) {
2547 __m256 va0 = _mm256_broadcast_ps((const __m128*) a0);
2548 a0 += 4;
2549
2550
2551 const __m256 vb01234567c0 = _mm256_load_ps(w + 0);
2552 const __m256 vb89ABCDEFc0 = _mm256_load_ps(w + 8);
2553
2554 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c0, vacc0x01234567);
2555 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc0, vacc0x89ABCDEF);
2556
2557 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2558
2559 const __m256 vb01234567c1 = _mm256_load_ps(w + 16);
2560 const __m256 vb89ABCDEFc1 = _mm256_load_ps(w + 24);
2561
2562 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c1, vacc0x01234567);
2563 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc1, vacc0x89ABCDEF);
2564
2565 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2566
2567 const __m256 vb01234567c2 = _mm256_load_ps(w + 32);
2568 const __m256 vb89ABCDEFc2 = _mm256_load_ps(w + 40);
2569
2570 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c2, vacc0x01234567);
2571 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc2, vacc0x89ABCDEF);
2572
2573 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2574
2575 const __m256 vb01234567c3 = _mm256_load_ps(w + 48);
2576 const __m256 vb89ABCDEFc3 = _mm256_load_ps(w + 56);
2577
2578 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c3, vacc0x01234567);
2579 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc3, vacc0x89ABCDEF);
2580
2581
2582 w += 64;
2583 k -= 4 * sizeof(float);
2584 }
2585 if XNN_UNLIKELY(k != 0) {
2586 __m256 va0 = _mm256_broadcast_ps((const __m128*) a0);
2587 a0 = (const float*) ((uintptr_t) a0 + k);
2588
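      // Partial k remainder: B is zero-padded past the end of k, so zero the A
      // lanes wherever the packed B value is zero to keep stale lanes from
      // contaminating the accumulators.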
2589 const __m256 vzero = _mm256_setzero_ps();
2590
2591 const __m256 vb01234567c0 = _mm256_load_ps(w + 0);
2592 const __m256 vb89ABCDEFc0 = _mm256_load_ps(w + 8);
2593
2594 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc0x01234567);
2595 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc0x89ABCDEF);
2596
2597 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2598
2599 const __m256 vb01234567c1 = _mm256_load_ps(w + 16);
2600 const __m256 vb89ABCDEFc1 = _mm256_load_ps(w + 24);
2601
2602 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc0x01234567);
2603 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc0x89ABCDEF);
2604
2605 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2606
2607 const __m256 vb01234567c2 = _mm256_load_ps(w + 32);
2608 const __m256 vb89ABCDEFc2 = _mm256_load_ps(w + 40);
2609
2610 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc0x01234567);
2611 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc0x89ABCDEF);
2612
2613 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2614
2615 const __m256 vb01234567c3 = _mm256_load_ps(w + 48);
2616 const __m256 vb89ABCDEFc3 = _mm256_load_ps(w + 56);
2617
2618 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc0x01234567);
2619 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc0x89ABCDEF);
2620
2621
2622 w += 64;
2623 }
2624
2625 const __m256 vmin = _mm256_load_ps(params->avx.min);
2626 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
2627 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
2628
2629 const __m256 vmax = _mm256_load_ps(params->avx.max);
2630 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
2631 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
2632
2633 if XNN_LIKELY(nc >= 16) {
2634 _mm256_storeu_ps(c0, vacc0x01234567);
2635 _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
2636 c0 = (float*) ((uintptr_t) c0 + cn_stride);
2637
2638 a0 = (const float*) ((uintptr_t) a0 - kc);
2639
2640 nc -= 16;
2641 } else {
2642 if (nc & 8) {
2643 _mm256_storeu_ps(c0, vacc0x01234567);
2644
2645 vacc0x01234567 = vacc0x89ABCDEF;
2646
2647 c0 += 8;
2648 }
2649 __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
2650 if (nc & 4) {
2651 _mm_storeu_ps(c0, vacc0x0123);
2652
2653 vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
2654
2655 c0 += 4;
2656 }
2657 if (nc & 2) {
2658 _mm_storel_pi((__m64*) c0, vacc0x0123);
2659
2660 vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
2661
2662 c0 += 2;
2663 }
2664 if (nc & 1) {
2665 _mm_store_ss(c0, vacc0x0123);
2666 }
2667
2668 nc = 0;
2669 }
2670 } while (nc != 0);
2671 }
2672
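// F32 GEMM microkernel (FMA3, broadcast variant): computes a 4x16 output tile from packed
// weights w (the first 16 floats of each panel initialize the accumulators, i.e. the bias),
// then clamps results to [params->avx.min, params->avx.max]. The "s4" in the name refers to
// the 4-step rotation of the input vector in the main K loop below.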
2673 void xnn_f32_gemm_minmax_ukernel_4x16s4__fma3_broadcast(
2674 size_t mr,
2675 size_t nc,
2676 size_t kc,
2677 const float*restrict a,
2678 size_t a_stride,
2679 const float*restrict w,
2680 float*restrict c,
2681 size_t cm_stride,
2682 size_t cn_stride,
2683 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
2684 {
2685 assert(mr != 0);
2686 assert(mr <= 4);
2687 assert(nc != 0);
2688 assert(kc != 0);
2689 assert(kc % sizeof(float) == 0);
2690 assert(a != NULL);
2691 assert(w != NULL);
2692 assert(c != NULL);
2693
2694 const float* a0 = a;
2695 float* c0 = c;
2696 const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
2697 float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
2698 if XNN_UNPREDICTABLE(mr < 2) {
2699 a1 = a0;
2700 c1 = c0;
2701 }
2702 const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
2703 float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
2704 if XNN_UNPREDICTABLE(mr <= 2) {
2705 a2 = a1;
2706 c2 = c1;
2707 }
2708 const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
2709 float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
2710 if XNN_UNPREDICTABLE(mr != 4) {
2711 a3 = a2;
2712 c3 = c2;
2713 }
2714
2715 do {
2716 __m256 vacc0x01234567 = _mm256_load_ps(w + 0);
2717 __m256 vacc0x89ABCDEF = _mm256_load_ps(w + 8);
2718 __m256 vacc1x01234567 = vacc0x01234567;
2719 __m256 vacc1x89ABCDEF = vacc0x89ABCDEF;
2720 __m256 vacc2x01234567 = vacc0x01234567;
2721 __m256 vacc2x89ABCDEF = vacc0x89ABCDEF;
2722 __m256 vacc3x01234567 = vacc0x01234567;
2723 __m256 vacc3x89ABCDEF = vacc0x89ABCDEF;
2724 w += 16;
2725
2726 size_t k = kc;
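      // Main K loop: consume 4 floats of each input row per iteration. The 4 values are
      // broadcast into both 128-bit lanes and rotated with _mm256_permute_ps between the four
      // FMA rounds, so each round pairs the next input element with the matching packed
      // weight columns (w is assumed to hold four 16-wide columns per group of 4 K values).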
2727 while (k >= 4 * sizeof(float)) {
2728 __m256 va0 = _mm256_broadcast_ps((const __m128*) a0);
2729 a0 += 4;
2730 __m256 va1 = _mm256_broadcast_ps((const __m128*) a1);
2731 a1 += 4;
2732 __m256 va2 = _mm256_broadcast_ps((const __m128*) a2);
2733 a2 += 4;
2734 __m256 va3 = _mm256_broadcast_ps((const __m128*) a3);
2735 a3 += 4;
2736
2737
2738 const __m256 vb01234567c0 = _mm256_load_ps(w + 0);
2739 const __m256 vb89ABCDEFc0 = _mm256_load_ps(w + 8);
2740
2741 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c0, vacc0x01234567);
2742 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567c0, vacc1x01234567);
2743 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567c0, vacc2x01234567);
2744 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567c0, vacc3x01234567);
2745 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc0, vacc0x89ABCDEF);
2746 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEFc0, vacc1x89ABCDEF);
2747 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEFc0, vacc2x89ABCDEF);
2748 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEFc0, vacc3x89ABCDEF);
2749
2750 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2751 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
2752 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
2753 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
2754
2755 const __m256 vb01234567c1 = _mm256_load_ps(w + 16);
2756 const __m256 vb89ABCDEFc1 = _mm256_load_ps(w + 24);
2757
2758 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c1, vacc0x01234567);
2759 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567c1, vacc1x01234567);
2760 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567c1, vacc2x01234567);
2761 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567c1, vacc3x01234567);
2762 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc1, vacc0x89ABCDEF);
2763 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEFc1, vacc1x89ABCDEF);
2764 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEFc1, vacc2x89ABCDEF);
2765 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEFc1, vacc3x89ABCDEF);
2766
2767 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2768 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
2769 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
2770 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
2771
2772 const __m256 vb01234567c2 = _mm256_load_ps(w + 32);
2773 const __m256 vb89ABCDEFc2 = _mm256_load_ps(w + 40);
2774
2775 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c2, vacc0x01234567);
2776 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567c2, vacc1x01234567);
2777 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567c2, vacc2x01234567);
2778 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567c2, vacc3x01234567);
2779 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc2, vacc0x89ABCDEF);
2780 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEFc2, vacc1x89ABCDEF);
2781 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEFc2, vacc2x89ABCDEF);
2782 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEFc2, vacc3x89ABCDEF);
2783
2784 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2785 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
2786 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
2787 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
2788
2789 const __m256 vb01234567c3 = _mm256_load_ps(w + 48);
2790 const __m256 vb89ABCDEFc3 = _mm256_load_ps(w + 56);
2791
2792 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c3, vacc0x01234567);
2793 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567c3, vacc1x01234567);
2794 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567c3, vacc2x01234567);
2795 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567c3, vacc3x01234567);
2796 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc3, vacc0x89ABCDEF);
2797 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEFc3, vacc1x89ABCDEF);
2798 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEFc3, vacc2x89ABCDEF);
2799 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEFc3, vacc3x89ABCDEF);
2800
2801
2802 w += 64;
2803 k -= 4 * sizeof(float);
2804 }
2805 if XNN_UNLIKELY(k != 0) {
2806 __m256 va0 = _mm256_broadcast_ps((const __m128*) a0);
2807 a0 = (const float*) ((uintptr_t) a0 + k);
2808 __m256 va1 = _mm256_broadcast_ps((const __m128*) a1);
2809 a1 = (const float*) ((uintptr_t) a1 + k);
2810 __m256 va2 = _mm256_broadcast_ps((const __m128*) a2);
2811 a2 = (const float*) ((uintptr_t) a2 + k);
2812 __m256 va3 = _mm256_broadcast_ps((const __m128*) a3);
2813 a3 = (const float*) ((uintptr_t) a3 + k);
2814
2815 const __m256 vzero = _mm256_setzero_ps();
2816
2817 const __m256 vb01234567c0 = _mm256_load_ps(w + 0);
2818 const __m256 vb89ABCDEFc0 = _mm256_load_ps(w + 8);
2819
2820 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc0x01234567);
2821 vacc1x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc1x01234567);
2822 vacc2x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc2x01234567);
2823 vacc3x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc3x01234567);
2824 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc0x89ABCDEF);
2825 vacc1x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc1x89ABCDEF);
2826 vacc2x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc2x89ABCDEF);
2827 vacc3x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc3x89ABCDEF);
2828
2829 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2830 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
2831 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
2832 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
2833
2834 const __m256 vb01234567c1 = _mm256_load_ps(w + 16);
2835 const __m256 vb89ABCDEFc1 = _mm256_load_ps(w + 24);
2836
2837 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc0x01234567);
2838 vacc1x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc1x01234567);
2839 vacc2x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc2x01234567);
2840 vacc3x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc3x01234567);
2841 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc0x89ABCDEF);
2842 vacc1x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc1x89ABCDEF);
2843 vacc2x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc2x89ABCDEF);
2844 vacc3x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc3x89ABCDEF);
2845
2846 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2847 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
2848 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
2849 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
2850
2851 const __m256 vb01234567c2 = _mm256_load_ps(w + 32);
2852 const __m256 vb89ABCDEFc2 = _mm256_load_ps(w + 40);
2853
2854 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc0x01234567);
2855 vacc1x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc1x01234567);
2856 vacc2x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc2x01234567);
2857 vacc3x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc3x01234567);
2858 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc0x89ABCDEF);
2859 vacc1x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc1x89ABCDEF);
2860 vacc2x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc2x89ABCDEF);
2861 vacc3x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc3x89ABCDEF);
2862
2863 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2864 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
2865 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
2866 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
2867
2868 const __m256 vb01234567c3 = _mm256_load_ps(w + 48);
2869 const __m256 vb89ABCDEFc3 = _mm256_load_ps(w + 56);
2870
2871 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc0x01234567);
2872 vacc1x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc1x01234567);
2873 vacc2x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc2x01234567);
2874 vacc3x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc3x01234567);
2875 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc0x89ABCDEF);
2876 vacc1x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc1x89ABCDEF);
2877 vacc2x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc2x89ABCDEF);
2878 vacc3x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc3x89ABCDEF);
2879
2880
2881 w += 64;
2882 }
2883
2884 const __m256 vmin = _mm256_load_ps(params->avx.min);
2885 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
2886 vacc1x01234567 = _mm256_max_ps(vacc1x01234567, vmin);
2887 vacc2x01234567 = _mm256_max_ps(vacc2x01234567, vmin);
2888 vacc3x01234567 = _mm256_max_ps(vacc3x01234567, vmin);
2889 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
2890 vacc1x89ABCDEF = _mm256_max_ps(vacc1x89ABCDEF, vmin);
2891 vacc2x89ABCDEF = _mm256_max_ps(vacc2x89ABCDEF, vmin);
2892 vacc3x89ABCDEF = _mm256_max_ps(vacc3x89ABCDEF, vmin);
2893
2894 const __m256 vmax = _mm256_load_ps(params->avx.max);
2895 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
2896 vacc1x01234567 = _mm256_min_ps(vacc1x01234567, vmax);
2897 vacc2x01234567 = _mm256_min_ps(vacc2x01234567, vmax);
2898 vacc3x01234567 = _mm256_min_ps(vacc3x01234567, vmax);
2899 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
2900 vacc1x89ABCDEF = _mm256_min_ps(vacc1x89ABCDEF, vmax);
2901 vacc2x89ABCDEF = _mm256_min_ps(vacc2x89ABCDEF, vmax);
2902 vacc3x89ABCDEF = _mm256_min_ps(vacc3x89ABCDEF, vmax);
2903
2904 if XNN_LIKELY(nc >= 16) {
2905 _mm256_storeu_ps(c3, vacc3x01234567);
2906 _mm256_storeu_ps(c3 + 8, vacc3x89ABCDEF);
2907 c3 = (float*) ((uintptr_t) c3 + cn_stride);
2908 _mm256_storeu_ps(c2, vacc2x01234567);
2909 _mm256_storeu_ps(c2 + 8, vacc2x89ABCDEF);
2910 c2 = (float*) ((uintptr_t) c2 + cn_stride);
2911 _mm256_storeu_ps(c1, vacc1x01234567);
2912 _mm256_storeu_ps(c1 + 8, vacc1x89ABCDEF);
2913 c1 = (float*) ((uintptr_t) c1 + cn_stride);
2914 _mm256_storeu_ps(c0, vacc0x01234567);
2915 _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
2916 c0 = (float*) ((uintptr_t) c0 + cn_stride);
2917
2918 a3 = (const float*) ((uintptr_t) a3 - kc);
2919 a2 = (const float*) ((uintptr_t) a2 - kc);
2920 a1 = (const float*) ((uintptr_t) a1 - kc);
2921 a0 = (const float*) ((uintptr_t) a0 - kc);
2922
2923 nc -= 16;
2924 } else {
2925 if (nc & 8) {
2926 _mm256_storeu_ps(c3, vacc3x01234567);
2927 _mm256_storeu_ps(c2, vacc2x01234567);
2928 _mm256_storeu_ps(c1, vacc1x01234567);
2929 _mm256_storeu_ps(c0, vacc0x01234567);
2930
2931 vacc3x01234567 = vacc3x89ABCDEF;
2932 vacc2x01234567 = vacc2x89ABCDEF;
2933 vacc1x01234567 = vacc1x89ABCDEF;
2934 vacc0x01234567 = vacc0x89ABCDEF;
2935
2936 c3 += 8;
2937 c2 += 8;
2938 c1 += 8;
2939 c0 += 8;
2940 }
2941 __m128 vacc3x0123 = _mm256_castps256_ps128(vacc3x01234567);
2942 __m128 vacc2x0123 = _mm256_castps256_ps128(vacc2x01234567);
2943 __m128 vacc1x0123 = _mm256_castps256_ps128(vacc1x01234567);
2944 __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
2945 if (nc & 4) {
2946 _mm_storeu_ps(c3, vacc3x0123);
2947 _mm_storeu_ps(c2, vacc2x0123);
2948 _mm_storeu_ps(c1, vacc1x0123);
2949 _mm_storeu_ps(c0, vacc0x0123);
2950
2951 vacc3x0123 = _mm256_extractf128_ps(vacc3x01234567, 1);
2952 vacc2x0123 = _mm256_extractf128_ps(vacc2x01234567, 1);
2953 vacc1x0123 = _mm256_extractf128_ps(vacc1x01234567, 1);
2954 vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
2955
2956 c3 += 4;
2957 c2 += 4;
2958 c1 += 4;
2959 c0 += 4;
2960 }
2961 if (nc & 2) {
2962 _mm_storel_pi((__m64*) c3, vacc3x0123);
2963 _mm_storel_pi((__m64*) c2, vacc2x0123);
2964 _mm_storel_pi((__m64*) c1, vacc1x0123);
2965 _mm_storel_pi((__m64*) c0, vacc0x0123);
2966
2967 vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
2968 vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
2969 vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
2970 vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
2971
2972 c3 += 2;
2973 c2 += 2;
2974 c1 += 2;
2975 c0 += 2;
2976 }
2977 if (nc & 1) {
2978 _mm_store_ss(c3, vacc3x0123);
2979 _mm_store_ss(c2, vacc2x0123);
2980 _mm_store_ss(c1, vacc1x0123);
2981 _mm_store_ss(c0, vacc0x0123);
2982 }
2983
2984 nc = 0;
2985 }
2986 } while (nc != 0);
2987 }
2988
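// F32 GEMM microkernel (FMA3, broadcast variant): computes a 5x16 output tile one K value at
// a time, broadcasting a single input element per row, then clamps to [min, max] from params.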
2989 void xnn_f32_gemm_minmax_ukernel_5x16__fma3_broadcast(
2990 size_t mr,
2991 size_t nc,
2992 size_t kc,
2993 const float*restrict a,
2994 size_t a_stride,
2995 const float*restrict w,
2996 float*restrict c,
2997 size_t cm_stride,
2998 size_t cn_stride,
2999 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
3000 {
3001 assert(mr != 0);
3002 assert(mr <= 5);
3003 assert(nc != 0);
3004 assert(kc != 0);
3005 assert(kc % sizeof(float) == 0);
3006 assert(a != NULL);
3007 assert(w != NULL);
3008 assert(c != NULL);
3009
3010 const float* a0 = a;
3011 float* c0 = c;
3012 const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
3013 float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
3014 if XNN_UNPREDICTABLE(mr < 2) {
3015 a1 = a0;
3016 c1 = c0;
3017 }
3018 const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
3019 float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
3020 if XNN_UNPREDICTABLE(mr <= 2) {
3021 a2 = a1;
3022 c2 = c1;
3023 }
3024 const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
3025 float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
3026 if XNN_UNPREDICTABLE(mr < 4) {
3027 a3 = a2;
3028 c3 = c2;
3029 }
3030 const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
3031 float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
3032 if XNN_UNPREDICTABLE(mr <= 4) {
3033 a4 = a3;
3034 c4 = c3;
3035 }
3036
3037 do {
3038 __m256 vacc0x01234567 = _mm256_load_ps(w + 0);
3039 __m256 vacc0x89ABCDEF = _mm256_load_ps(w + 8);
3040 __m256 vacc1x01234567 = vacc0x01234567;
3041 __m256 vacc1x89ABCDEF = vacc0x89ABCDEF;
3042 __m256 vacc2x01234567 = vacc0x01234567;
3043 __m256 vacc2x89ABCDEF = vacc0x89ABCDEF;
3044 __m256 vacc3x01234567 = vacc0x01234567;
3045 __m256 vacc3x89ABCDEF = vacc0x89ABCDEF;
3046 __m256 vacc4x01234567 = vacc0x01234567;
3047 __m256 vacc4x89ABCDEF = vacc0x89ABCDEF;
3048 w += 16;
3049
3050 size_t k = kc;
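      // Per-K inner loop: broadcast one float from each of the 5 input rows and accumulate it
      // against the current pair of 8-wide weight vectors (16 output columns).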
3051 do {
3052 const __m256 va0 = _mm256_broadcast_ss(a0);
3053 a0 += 1;
3054 const __m256 va1 = _mm256_broadcast_ss(a1);
3055 a1 += 1;
3056 const __m256 va2 = _mm256_broadcast_ss(a2);
3057 a2 += 1;
3058 const __m256 va3 = _mm256_broadcast_ss(a3);
3059 a3 += 1;
3060 const __m256 va4 = _mm256_broadcast_ss(a4);
3061 a4 += 1;
3062
3063 const __m256 vb01234567 = _mm256_load_ps(w);
3064 const __m256 vb89ABCDEF = _mm256_load_ps(w + 8);
3065 w += 16;
3066
3067 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567, vacc0x01234567);
3068 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567, vacc1x01234567);
3069 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567, vacc2x01234567);
3070 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567, vacc3x01234567);
3071 vacc4x01234567 = _mm256_fmadd_ps(va4, vb01234567, vacc4x01234567);
3072 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEF, vacc0x89ABCDEF);
3073 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEF, vacc1x89ABCDEF);
3074 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEF, vacc2x89ABCDEF);
3075 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEF, vacc3x89ABCDEF);
3076 vacc4x89ABCDEF = _mm256_fmadd_ps(va4, vb89ABCDEF, vacc4x89ABCDEF);
3077
3078 k -= sizeof(float);
3079 } while (k != 0);
3080
3081 const __m256 vmin = _mm256_load_ps(params->avx.min);
3082 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
3083 vacc1x01234567 = _mm256_max_ps(vacc1x01234567, vmin);
3084 vacc2x01234567 = _mm256_max_ps(vacc2x01234567, vmin);
3085 vacc3x01234567 = _mm256_max_ps(vacc3x01234567, vmin);
3086 vacc4x01234567 = _mm256_max_ps(vacc4x01234567, vmin);
3087 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
3088 vacc1x89ABCDEF = _mm256_max_ps(vacc1x89ABCDEF, vmin);
3089 vacc2x89ABCDEF = _mm256_max_ps(vacc2x89ABCDEF, vmin);
3090 vacc3x89ABCDEF = _mm256_max_ps(vacc3x89ABCDEF, vmin);
3091 vacc4x89ABCDEF = _mm256_max_ps(vacc4x89ABCDEF, vmin);
3092
3093 const __m256 vmax = _mm256_load_ps(params->avx.max);
3094 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
3095 vacc1x01234567 = _mm256_min_ps(vacc1x01234567, vmax);
3096 vacc2x01234567 = _mm256_min_ps(vacc2x01234567, vmax);
3097 vacc3x01234567 = _mm256_min_ps(vacc3x01234567, vmax);
3098 vacc4x01234567 = _mm256_min_ps(vacc4x01234567, vmax);
3099 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
3100 vacc1x89ABCDEF = _mm256_min_ps(vacc1x89ABCDEF, vmax);
3101 vacc2x89ABCDEF = _mm256_min_ps(vacc2x89ABCDEF, vmax);
3102 vacc3x89ABCDEF = _mm256_min_ps(vacc3x89ABCDEF, vmax);
3103 vacc4x89ABCDEF = _mm256_min_ps(vacc4x89ABCDEF, vmax);
3104
3105 if XNN_LIKELY(nc >= 16) {
3106 _mm256_storeu_ps(c4, vacc4x01234567);
3107 _mm256_storeu_ps(c4 + 8, vacc4x89ABCDEF);
3108 c4 = (float*) ((uintptr_t) c4 + cn_stride);
3109 _mm256_storeu_ps(c3, vacc3x01234567);
3110 _mm256_storeu_ps(c3 + 8, vacc3x89ABCDEF);
3111 c3 = (float*) ((uintptr_t) c3 + cn_stride);
3112 _mm256_storeu_ps(c2, vacc2x01234567);
3113 _mm256_storeu_ps(c2 + 8, vacc2x89ABCDEF);
3114 c2 = (float*) ((uintptr_t) c2 + cn_stride);
3115 _mm256_storeu_ps(c1, vacc1x01234567);
3116 _mm256_storeu_ps(c1 + 8, vacc1x89ABCDEF);
3117 c1 = (float*) ((uintptr_t) c1 + cn_stride);
3118 _mm256_storeu_ps(c0, vacc0x01234567);
3119 _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
3120 c0 = (float*) ((uintptr_t) c0 + cn_stride);
3121
3122 a4 = (const float*) ((uintptr_t) a4 - kc);
3123 a3 = (const float*) ((uintptr_t) a3 - kc);
3124 a2 = (const float*) ((uintptr_t) a2 - kc);
3125 a1 = (const float*) ((uintptr_t) a1 - kc);
3126 a0 = (const float*) ((uintptr_t) a0 - kc);
3127
3128 nc -= 16;
3129 } else {
3130 if (nc & 8) {
3131 _mm256_storeu_ps(c4, vacc4x01234567);
3132 _mm256_storeu_ps(c3, vacc3x01234567);
3133 _mm256_storeu_ps(c2, vacc2x01234567);
3134 _mm256_storeu_ps(c1, vacc1x01234567);
3135 _mm256_storeu_ps(c0, vacc0x01234567);
3136
3137 vacc4x01234567 = vacc4x89ABCDEF;
3138 vacc3x01234567 = vacc3x89ABCDEF;
3139 vacc2x01234567 = vacc2x89ABCDEF;
3140 vacc1x01234567 = vacc1x89ABCDEF;
3141 vacc0x01234567 = vacc0x89ABCDEF;
3142
3143 c4 += 8;
3144 c3 += 8;
3145 c2 += 8;
3146 c1 += 8;
3147 c0 += 8;
3148 }
3149 __m128 vacc4x0123 = _mm256_castps256_ps128(vacc4x01234567);
3150 __m128 vacc3x0123 = _mm256_castps256_ps128(vacc3x01234567);
3151 __m128 vacc2x0123 = _mm256_castps256_ps128(vacc2x01234567);
3152 __m128 vacc1x0123 = _mm256_castps256_ps128(vacc1x01234567);
3153 __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
3154 if (nc & 4) {
3155 _mm_storeu_ps(c4, vacc4x0123);
3156 _mm_storeu_ps(c3, vacc3x0123);
3157 _mm_storeu_ps(c2, vacc2x0123);
3158 _mm_storeu_ps(c1, vacc1x0123);
3159 _mm_storeu_ps(c0, vacc0x0123);
3160
3161 vacc4x0123 = _mm256_extractf128_ps(vacc4x01234567, 1);
3162 vacc3x0123 = _mm256_extractf128_ps(vacc3x01234567, 1);
3163 vacc2x0123 = _mm256_extractf128_ps(vacc2x01234567, 1);
3164 vacc1x0123 = _mm256_extractf128_ps(vacc1x01234567, 1);
3165 vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
3166
3167 c4 += 4;
3168 c3 += 4;
3169 c2 += 4;
3170 c1 += 4;
3171 c0 += 4;
3172 }
3173 if (nc & 2) {
3174 _mm_storel_pi((__m64*) c4, vacc4x0123);
3175 _mm_storel_pi((__m64*) c3, vacc3x0123);
3176 _mm_storel_pi((__m64*) c2, vacc2x0123);
3177 _mm_storel_pi((__m64*) c1, vacc1x0123);
3178 _mm_storel_pi((__m64*) c0, vacc0x0123);
3179
3180 vacc4x0123 = _mm_movehl_ps(vacc4x0123, vacc4x0123);
3181 vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
3182 vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
3183 vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
3184 vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
3185
3186 c4 += 2;
3187 c3 += 2;
3188 c2 += 2;
3189 c1 += 2;
3190 c0 += 2;
3191 }
3192 if (nc & 1) {
3193 _mm_store_ss(c4, vacc4x0123);
3194 _mm_store_ss(c3, vacc3x0123);
3195 _mm_store_ss(c2, vacc2x0123);
3196 _mm_store_ss(c1, vacc1x0123);
3197 _mm_store_ss(c0, vacc0x0123);
3198 }
3199
3200 nc = 0;
3201 }
3202 } while (nc != 0);
3203 }
3204
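// F32 IGEMM microkernel (indirect GEMM): input rows are fetched through the pointer array `a`
// (advanced by ks bytes per output tile), and rows equal to `zero` skip the a_offset
// adjustment. Computes a 1x16 tile with [min, max] clamping.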
3205 void xnn_f32_igemm_minmax_ukernel_1x16__fma3_broadcast(
3206 size_t mr,
3207 size_t nc,
3208 size_t kc,
3209 size_t ks,
3210 const float**restrict a,
3211 const float*restrict w,
3212 float*restrict c,
3213 size_t cm_stride,
3214 size_t cn_stride,
3215 size_t a_offset,
3216 const float* zero,
3217 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
3218 {
3219 assert(mr != 0);
3220 assert(mr <= 1);
3221 assert(nc != 0);
3222 assert(kc != 0);
3223 assert(kc % sizeof(float) == 0);
3224 assert(ks != 0);
3225 assert(ks % (1 * sizeof(void*)) == 0);
3226 assert(a_offset % sizeof(float) == 0);
3227 assert(a != NULL);
3228 assert(w != NULL);
3229 assert(c != NULL);
3230
3231 float* c0 = c;
3232
3233 do {
3234 __m256 vacc0x01234567 = _mm256_load_ps(w);
3235 __m256 vacc0x89ABCDEF = _mm256_load_ps(w + 8);
3236 w += 16;
3237
3238 size_t p = ks;
3239 do {
3240 const float* restrict a0 = a[0];
3241 assert(a0 != NULL);
3242 if XNN_UNPREDICTABLE(a0 != zero) {
3243 a0 = (const float*) ((uintptr_t) a0 + a_offset);
3244 }
3245 a += 1;
3246
3247 size_t k = kc;
3248 do {
3249 const __m256 vb01234567 = _mm256_load_ps(w);
3250 const __m256 vb89ABCDEF = _mm256_load_ps(w + 8);
3251 w += 16;
3252
3253 const __m256 va0 = _mm256_broadcast_ss(a0);
3254 a0 += 1;
3255
3256 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567, vacc0x01234567);
3257 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEF, vacc0x89ABCDEF);
3258 k -= sizeof(float);
3259 } while (k != 0);
3260 p -= 1 * sizeof(void*);
3261 } while (p != 0);
3262
3263 const __m256 vmin = _mm256_load_ps(params->avx.min);
3264 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
3265 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
3266
3267 const __m256 vmax = _mm256_load_ps(params->avx.max);
3268 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
3269 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
3270
3271 if XNN_LIKELY(nc >= 16) {
3272 _mm256_storeu_ps(c0, vacc0x01234567);
3273 _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
3274 c0 = (float*) ((uintptr_t) c0 + cn_stride);
3275
3276 a = (const float**restrict) ((uintptr_t) a - ks);
3277 nc -= 16;
3278 } else {
3279 if (nc & 8) {
3280 _mm256_storeu_ps(c0, vacc0x01234567);
3281
3282 vacc0x01234567 = vacc0x89ABCDEF;
3283
3284 c0 += 8;
3285 }
3286 __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
3287 if (nc & 4) {
3288 _mm_storeu_ps(c0, vacc0x0123);
3289
3290 vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
3291
3292 c0 += 4;
3293 }
3294 if (nc & 2) {
3295 _mm_storel_pi((__m64*) c0, vacc0x0123);
3296
3297 vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
3298
3299 c0 += 2;
3300 }
3301 if (nc & 1) {
3302 _mm_store_ss(c0, vacc0x0123);
3303 }
3304
3305 nc = 0;
3306 }
3307 } while (nc != 0);
3308 }
3309
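// F32 IGEMM microkernel, 1x16 tile, s4 variant: same indirect-input scheme as above, but the
// inner loop consumes 4 K values per iteration using the rotated-broadcast trick from the s4
// GEMM kernels, with zero-weight masking for the K remainder.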
3310 void xnn_f32_igemm_minmax_ukernel_1x16s4__fma3_broadcast(
3311 size_t mr,
3312 size_t nc,
3313 size_t kc,
3314 size_t ks,
3315 const float**restrict a,
3316 const float*restrict w,
3317 float*restrict c,
3318 size_t cm_stride,
3319 size_t cn_stride,
3320 size_t a_offset,
3321 const float* zero,
3322 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
3323 {
3324 assert(mr != 0);
3325 assert(mr <= 1);
3326 assert(nc != 0);
3327 assert(kc != 0);
3328 assert(kc % sizeof(float) == 0);
3329 assert(ks != 0);
3330 assert(ks % (1 * sizeof(void*)) == 0);
3331 assert(a_offset % sizeof(float) == 0);
3332 assert(a != NULL);
3333 assert(w != NULL);
3334 assert(c != NULL);
3335
3336 float* c0 = c;
3337
3338 do {
3339 __m256 vacc0x01234567 = _mm256_load_ps(w);
3340 __m256 vacc0x89ABCDEF = _mm256_load_ps(w + 8);
3341 w += 16;
3342
3343 size_t p = ks;
3344 do {
3345 const float* restrict a0 = a[0];
3346 assert(a0 != NULL);
3347 if XNN_UNPREDICTABLE(a0 != zero) {
3348 a0 = (const float*) ((uintptr_t) a0 + a_offset);
3349 }
3350 a += 1;
3351
3352 size_t k = kc;
3353 while (k >= 4 * sizeof(float)) {
3354 __m256 va0 = _mm256_broadcast_ps((const __m128*) a0);
3355 a0 += 4;
3356
3357
3358 const __m256 vb01234567c0 = _mm256_load_ps(w + 0);
3359 const __m256 vb89ABCDEFc0 = _mm256_load_ps(w + 8);
3360
3361 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c0, vacc0x01234567);
3362 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc0, vacc0x89ABCDEF);
3363
3364 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3365
3366 const __m256 vb01234567c1 = _mm256_load_ps(w + 16);
3367 const __m256 vb89ABCDEFc1 = _mm256_load_ps(w + 24);
3368
3369 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c1, vacc0x01234567);
3370 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc1, vacc0x89ABCDEF);
3371
3372 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3373
3374 const __m256 vb01234567c2 = _mm256_load_ps(w + 32);
3375 const __m256 vb89ABCDEFc2 = _mm256_load_ps(w + 40);
3376
3377 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c2, vacc0x01234567);
3378 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc2, vacc0x89ABCDEF);
3379
3380 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3381
3382 const __m256 vb01234567c3 = _mm256_load_ps(w + 48);
3383 const __m256 vb89ABCDEFc3 = _mm256_load_ps(w + 56);
3384
3385 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c3, vacc0x01234567);
3386 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc3, vacc0x89ABCDEF);
3387
3388
3389 w += 64;
3390 k -= 4 * sizeof(float);
3391 }
3392 if XNN_UNLIKELY(k != 0) {
3393 __m256 va0 = _mm256_broadcast_ps((const __m128*) a0);
3394 a0 = (const float*) ((uintptr_t) a0 + k);
3395
3396 const __m256 vzero = _mm256_setzero_ps();
3397
3398 const __m256 vb01234567c0 = _mm256_load_ps(w + 0);
3399 const __m256 vb89ABCDEFc0 = _mm256_load_ps(w + 8);
3400
3401 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc0x01234567);
3402 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc0x89ABCDEF);
3403
3404 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3405
3406 const __m256 vb01234567c1 = _mm256_load_ps(w + 16);
3407 const __m256 vb89ABCDEFc1 = _mm256_load_ps(w + 24);
3408
3409 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc0x01234567);
3410 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc0x89ABCDEF);
3411
3412 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3413
3414 const __m256 vb01234567c2 = _mm256_load_ps(w + 32);
3415 const __m256 vb89ABCDEFc2 = _mm256_load_ps(w + 40);
3416
3417 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc0x01234567);
3418 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc0x89ABCDEF);
3419
3420 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3421
3422 const __m256 vb01234567c3 = _mm256_load_ps(w + 48);
3423 const __m256 vb89ABCDEFc3 = _mm256_load_ps(w + 56);
3424
3425 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc0x01234567);
3426 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc0x89ABCDEF);
3427
3428
3429 w += 64;
3430 }
3431 p -= 1 * sizeof(void*);
3432 } while (p != 0);
3433
3434 const __m256 vmin = _mm256_load_ps(params->avx.min);
3435 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
3436 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
3437
3438 const __m256 vmax = _mm256_load_ps(params->avx.max);
3439 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
3440 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
3441
3442 if XNN_LIKELY(nc >= 16) {
3443 _mm256_storeu_ps(c0, vacc0x01234567);
3444 _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
3445 c0 = (float*) ((uintptr_t) c0 + cn_stride);
3446
3447 a = (const float**restrict) ((uintptr_t) a - ks);
3448 nc -= 16;
3449 } else {
3450 if (nc & 8) {
3451 _mm256_storeu_ps(c0, vacc0x01234567);
3452
3453 vacc0x01234567 = vacc0x89ABCDEF;
3454
3455 c0 += 8;
3456 }
3457 __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
3458 if (nc & 4) {
3459 _mm_storeu_ps(c0, vacc0x0123);
3460
3461 vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
3462
3463 c0 += 4;
3464 }
3465 if (nc & 2) {
3466 _mm_storel_pi((__m64*) c0, vacc0x0123);
3467
3468 vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
3469
3470 c0 += 2;
3471 }
3472 if (nc & 1) {
3473 _mm_store_ss(c0, vacc0x0123);
3474 }
3475
3476 nc = 0;
3477 }
3478 } while (nc != 0);
3479 }
3480
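// F32 IGEMM microkernel, 4x16 tile, s4 variant: four indirection pointers are consumed per
// iteration of the ks loop, and the K remainder reuses the zero-weight masking of the s4
// GEMM kernels.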
3481 void xnn_f32_igemm_minmax_ukernel_4x16s4__fma3_broadcast(
3482 size_t mr,
3483 size_t nc,
3484 size_t kc,
3485 size_t ks,
3486 const float**restrict a,
3487 const float*restrict w,
3488 float*restrict c,
3489 size_t cm_stride,
3490 size_t cn_stride,
3491 size_t a_offset,
3492 const float* zero,
3493 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
3494 {
3495 assert(mr != 0);
3496 assert(mr <= 4);
3497 assert(nc != 0);
3498 assert(kc != 0);
3499 assert(kc % sizeof(float) == 0);
3500 assert(ks != 0);
3501 assert(ks % (4 * sizeof(void*)) == 0);
3502 assert(a_offset % sizeof(float) == 0);
3503 assert(a != NULL);
3504 assert(w != NULL);
3505 assert(c != NULL);
3506
3507 float* c0 = c;
3508 float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
3509 if XNN_UNPREDICTABLE(mr < 2) {
3510 c1 = c0;
3511 }
3512 float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
3513 if XNN_UNPREDICTABLE(mr <= 2) {
3514 c2 = c1;
3515 }
3516 float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
3517 if XNN_UNPREDICTABLE(mr != 4) {
3518 c3 = c2;
3519 }
3520
3521 do {
3522 __m256 vacc0x01234567 = _mm256_load_ps(w);
3523 __m256 vacc0x89ABCDEF = _mm256_load_ps(w + 8);
3524 __m256 vacc1x01234567 = vacc0x01234567;
3525 __m256 vacc1x89ABCDEF = vacc0x89ABCDEF;
3526 __m256 vacc2x01234567 = vacc0x01234567;
3527 __m256 vacc2x89ABCDEF = vacc0x89ABCDEF;
3528 __m256 vacc3x01234567 = vacc0x01234567;
3529 __m256 vacc3x89ABCDEF = vacc0x89ABCDEF;
3530 w += 16;
3531
3532 size_t p = ks;
3533 do {
3534 const float* restrict a0 = a[0];
3535 assert(a0 != NULL);
3536 if XNN_UNPREDICTABLE(a0 != zero) {
3537 a0 = (const float*) ((uintptr_t) a0 + a_offset);
3538 }
3539 const float* restrict a1 = a[1];
3540 assert(a1 != NULL);
3541 if XNN_UNPREDICTABLE(a1 != zero) {
3542 a1 = (const float*) ((uintptr_t) a1 + a_offset);
3543 }
3544 const float* restrict a2 = a[2];
3545 assert(a2 != NULL);
3546 if XNN_UNPREDICTABLE(a2 != zero) {
3547 a2 = (const float*) ((uintptr_t) a2 + a_offset);
3548 }
3549 const float* restrict a3 = a[3];
3550 assert(a3 != NULL);
3551 if XNN_UNPREDICTABLE(a3 != zero) {
3552 a3 = (const float*) ((uintptr_t) a3 + a_offset);
3553 }
3554 a += 4;
3555
3556 size_t k = kc;
3557 while (k >= 4 * sizeof(float)) {
3558 __m256 va0 = _mm256_broadcast_ps((const __m128*) a0);
3559 a0 += 4;
3560 __m256 va1 = _mm256_broadcast_ps((const __m128*) a1);
3561 a1 += 4;
3562 __m256 va2 = _mm256_broadcast_ps((const __m128*) a2);
3563 a2 += 4;
3564 __m256 va3 = _mm256_broadcast_ps((const __m128*) a3);
3565 a3 += 4;
3566
3567
3568 const __m256 vb01234567c0 = _mm256_load_ps(w + 0);
3569 const __m256 vb89ABCDEFc0 = _mm256_load_ps(w + 8);
3570
3571 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c0, vacc0x01234567);
3572 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567c0, vacc1x01234567);
3573 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567c0, vacc2x01234567);
3574 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567c0, vacc3x01234567);
3575 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc0, vacc0x89ABCDEF);
3576 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEFc0, vacc1x89ABCDEF);
3577 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEFc0, vacc2x89ABCDEF);
3578 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEFc0, vacc3x89ABCDEF);
3579
3580 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3581 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
3582 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
3583 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
3584
3585 const __m256 vb01234567c1 = _mm256_load_ps(w + 16);
3586 const __m256 vb89ABCDEFc1 = _mm256_load_ps(w + 24);
3587
3588 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c1, vacc0x01234567);
3589 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567c1, vacc1x01234567);
3590 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567c1, vacc2x01234567);
3591 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567c1, vacc3x01234567);
3592 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc1, vacc0x89ABCDEF);
3593 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEFc1, vacc1x89ABCDEF);
3594 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEFc1, vacc2x89ABCDEF);
3595 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEFc1, vacc3x89ABCDEF);
3596
3597 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3598 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
3599 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
3600 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
3601
3602 const __m256 vb01234567c2 = _mm256_load_ps(w + 32);
3603 const __m256 vb89ABCDEFc2 = _mm256_load_ps(w + 40);
3604
3605 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c2, vacc0x01234567);
3606 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567c2, vacc1x01234567);
3607 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567c2, vacc2x01234567);
3608 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567c2, vacc3x01234567);
3609 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc2, vacc0x89ABCDEF);
3610 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEFc2, vacc1x89ABCDEF);
3611 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEFc2, vacc2x89ABCDEF);
3612 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEFc2, vacc3x89ABCDEF);
3613
3614 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3615 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
3616 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
3617 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
3618
3619 const __m256 vb01234567c3 = _mm256_load_ps(w + 48);
3620 const __m256 vb89ABCDEFc3 = _mm256_load_ps(w + 56);
3621
3622 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c3, vacc0x01234567);
3623 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567c3, vacc1x01234567);
3624 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567c3, vacc2x01234567);
3625 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567c3, vacc3x01234567);
3626 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc3, vacc0x89ABCDEF);
3627 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEFc3, vacc1x89ABCDEF);
3628 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEFc3, vacc2x89ABCDEF);
3629 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEFc3, vacc3x89ABCDEF);
3630
3631
3632 w += 64;
3633 k -= 4 * sizeof(float);
3634 }
3635 if XNN_UNLIKELY(k != 0) {
3636 __m256 va0 = _mm256_broadcast_ps((const __m128*) a0);
3637 a0 = (const float*) ((uintptr_t) a0 + k);
3638 __m256 va1 = _mm256_broadcast_ps((const __m128*) a1);
3639 a1 = (const float*) ((uintptr_t) a1 + k);
3640 __m256 va2 = _mm256_broadcast_ps((const __m128*) a2);
3641 a2 = (const float*) ((uintptr_t) a2 + k);
3642 __m256 va3 = _mm256_broadcast_ps((const __m128*) a3);
3643 a3 = (const float*) ((uintptr_t) a3 + k);
3644
3645 const __m256 vzero = _mm256_setzero_ps();
3646
3647 const __m256 vb01234567c0 = _mm256_load_ps(w + 0);
3648 const __m256 vb89ABCDEFc0 = _mm256_load_ps(w + 8);
3649
3650 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc0x01234567);
3651 vacc1x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc1x01234567);
3652 vacc2x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc2x01234567);
3653 vacc3x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc3x01234567);
3654 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc0x89ABCDEF);
3655 vacc1x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc1x89ABCDEF);
3656 vacc2x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc2x89ABCDEF);
3657 vacc3x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc3x89ABCDEF);
3658
3659 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3660 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
3661 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
3662 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
3663
3664 const __m256 vb01234567c1 = _mm256_load_ps(w + 16);
3665 const __m256 vb89ABCDEFc1 = _mm256_load_ps(w + 24);
3666
3667 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc0x01234567);
3668 vacc1x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc1x01234567);
3669 vacc2x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc2x01234567);
3670 vacc3x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc3x01234567);
3671 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc0x89ABCDEF);
3672 vacc1x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc1x89ABCDEF);
3673 vacc2x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc2x89ABCDEF);
3674 vacc3x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc3x89ABCDEF);
3675
3676 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3677 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
3678 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
3679 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
3680
3681 const __m256 vb01234567c2 = _mm256_load_ps(w + 32);
3682 const __m256 vb89ABCDEFc2 = _mm256_load_ps(w + 40);
3683
3684 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc0x01234567);
3685 vacc1x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc1x01234567);
3686 vacc2x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc2x01234567);
3687 vacc3x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc3x01234567);
3688 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc0x89ABCDEF);
3689 vacc1x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc1x89ABCDEF);
3690 vacc2x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc2x89ABCDEF);
3691 vacc3x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc3x89ABCDEF);
3692
3693 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3694 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
3695 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
3696 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
3697
3698 const __m256 vb01234567c3 = _mm256_load_ps(w + 48);
3699 const __m256 vb89ABCDEFc3 = _mm256_load_ps(w + 56);
3700
3701 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc0x01234567);
3702 vacc1x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc1x01234567);
3703 vacc2x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc2x01234567);
3704 vacc3x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc3x01234567);
3705 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc0x89ABCDEF);
3706 vacc1x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc1x89ABCDEF);
3707 vacc2x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc2x89ABCDEF);
3708 vacc3x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc3x89ABCDEF);
3709
3710
3711 w += 64;
3712 }
3713 p -= 4 * sizeof(void*);
3714 } while (p != 0);
3715
3716 const __m256 vmin = _mm256_load_ps(params->avx.min);
3717 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
3718 vacc1x01234567 = _mm256_max_ps(vacc1x01234567, vmin);
3719 vacc2x01234567 = _mm256_max_ps(vacc2x01234567, vmin);
3720 vacc3x01234567 = _mm256_max_ps(vacc3x01234567, vmin);
3721 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
3722 vacc1x89ABCDEF = _mm256_max_ps(vacc1x89ABCDEF, vmin);
3723 vacc2x89ABCDEF = _mm256_max_ps(vacc2x89ABCDEF, vmin);
3724 vacc3x89ABCDEF = _mm256_max_ps(vacc3x89ABCDEF, vmin);
3725
3726 const __m256 vmax = _mm256_load_ps(params->avx.max);
3727 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
3728 vacc1x01234567 = _mm256_min_ps(vacc1x01234567, vmax);
3729 vacc2x01234567 = _mm256_min_ps(vacc2x01234567, vmax);
3730 vacc3x01234567 = _mm256_min_ps(vacc3x01234567, vmax);
3731 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
3732 vacc1x89ABCDEF = _mm256_min_ps(vacc1x89ABCDEF, vmax);
3733 vacc2x89ABCDEF = _mm256_min_ps(vacc2x89ABCDEF, vmax);
3734 vacc3x89ABCDEF = _mm256_min_ps(vacc3x89ABCDEF, vmax);
3735
3736 if XNN_LIKELY(nc >= 16) {
3737 _mm256_storeu_ps(c3, vacc3x01234567);
3738 _mm256_storeu_ps(c3 + 8, vacc3x89ABCDEF);
3739 c3 = (float*) ((uintptr_t) c3 + cn_stride);
3740 _mm256_storeu_ps(c2, vacc2x01234567);
3741 _mm256_storeu_ps(c2 + 8, vacc2x89ABCDEF);
3742 c2 = (float*) ((uintptr_t) c2 + cn_stride);
3743 _mm256_storeu_ps(c1, vacc1x01234567);
3744 _mm256_storeu_ps(c1 + 8, vacc1x89ABCDEF);
3745 c1 = (float*) ((uintptr_t) c1 + cn_stride);
3746 _mm256_storeu_ps(c0, vacc0x01234567);
3747 _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
3748 c0 = (float*) ((uintptr_t) c0 + cn_stride);
3749
3750 a = (const float**restrict) ((uintptr_t) a - ks);
3751 nc -= 16;
3752 } else {
3753 if (nc & 8) {
3754 _mm256_storeu_ps(c3, vacc3x01234567);
3755 _mm256_storeu_ps(c2, vacc2x01234567);
3756 _mm256_storeu_ps(c1, vacc1x01234567);
3757 _mm256_storeu_ps(c0, vacc0x01234567);
3758
3759 vacc3x01234567 = vacc3x89ABCDEF;
3760 vacc2x01234567 = vacc2x89ABCDEF;
3761 vacc1x01234567 = vacc1x89ABCDEF;
3762 vacc0x01234567 = vacc0x89ABCDEF;
3763
3764 c3 += 8;
3765 c2 += 8;
3766 c1 += 8;
3767 c0 += 8;
3768 }
3769 __m128 vacc3x0123 = _mm256_castps256_ps128(vacc3x01234567);
3770 __m128 vacc2x0123 = _mm256_castps256_ps128(vacc2x01234567);
3771 __m128 vacc1x0123 = _mm256_castps256_ps128(vacc1x01234567);
3772 __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
3773 if (nc & 4) {
3774 _mm_storeu_ps(c3, vacc3x0123);
3775 _mm_storeu_ps(c2, vacc2x0123);
3776 _mm_storeu_ps(c1, vacc1x0123);
3777 _mm_storeu_ps(c0, vacc0x0123);
3778
3779 vacc3x0123 = _mm256_extractf128_ps(vacc3x01234567, 1);
3780 vacc2x0123 = _mm256_extractf128_ps(vacc2x01234567, 1);
3781 vacc1x0123 = _mm256_extractf128_ps(vacc1x01234567, 1);
3782 vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
3783
3784 c3 += 4;
3785 c2 += 4;
3786 c1 += 4;
3787 c0 += 4;
3788 }
3789 if (nc & 2) {
3790 _mm_storel_pi((__m64*) c3, vacc3x0123);
3791 _mm_storel_pi((__m64*) c2, vacc2x0123);
3792 _mm_storel_pi((__m64*) c1, vacc1x0123);
3793 _mm_storel_pi((__m64*) c0, vacc0x0123);
3794
3795 vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
3796 vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
3797 vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
3798 vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
3799
3800 c3 += 2;
3801 c2 += 2;
3802 c1 += 2;
3803 c0 += 2;
3804 }
3805 if (nc & 1) {
3806 _mm_store_ss(c3, vacc3x0123);
3807 _mm_store_ss(c2, vacc2x0123);
3808 _mm_store_ss(c1, vacc1x0123);
3809 _mm_store_ss(c0, vacc0x0123);
3810 }
3811
3812 nc = 0;
3813 }
3814 } while (nc != 0);
3815 }
3816
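// F32 IGEMM microkernel, 5x16 tile, broadcast variant: five indirection pointers per ks
// iteration, one K element per inner-loop iteration.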
3817 void xnn_f32_igemm_minmax_ukernel_5x16__fma3_broadcast(
3818 size_t mr,
3819 size_t nc,
3820 size_t kc,
3821 size_t ks,
3822 const float**restrict a,
3823 const float*restrict w,
3824 float*restrict c,
3825 size_t cm_stride,
3826 size_t cn_stride,
3827 size_t a_offset,
3828 const float* zero,
3829 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
3830 {
3831 assert(mr != 0);
3832 assert(mr <= 5);
3833 assert(nc != 0);
3834 assert(kc != 0);
3835 assert(kc % sizeof(float) == 0);
3836 assert(ks != 0);
3837 assert(ks % (5 * sizeof(void*)) == 0);
3838 assert(a_offset % sizeof(float) == 0);
3839 assert(a != NULL);
3840 assert(w != NULL);
3841 assert(c != NULL);
3842
3843 float* c0 = c;
3844 float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
3845 if XNN_UNPREDICTABLE(mr < 2) {
3846 c1 = c0;
3847 }
3848 float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
3849 if XNN_UNPREDICTABLE(mr <= 2) {
3850 c2 = c1;
3851 }
3852 float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
3853 if XNN_UNPREDICTABLE(mr < 4) {
3854 c3 = c2;
3855 }
3856 float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
3857 if XNN_UNPREDICTABLE(mr <= 4) {
3858 c4 = c3;
3859 }
3860
3861 do {
3862 __m256 vacc0x01234567 = _mm256_load_ps(w);
3863 __m256 vacc0x89ABCDEF = _mm256_load_ps(w + 8);
3864 __m256 vacc1x01234567 = vacc0x01234567;
3865 __m256 vacc1x89ABCDEF = vacc0x89ABCDEF;
3866 __m256 vacc2x01234567 = vacc0x01234567;
3867 __m256 vacc2x89ABCDEF = vacc0x89ABCDEF;
3868 __m256 vacc3x01234567 = vacc0x01234567;
3869 __m256 vacc3x89ABCDEF = vacc0x89ABCDEF;
3870 __m256 vacc4x01234567 = vacc0x01234567;
3871 __m256 vacc4x89ABCDEF = vacc0x89ABCDEF;
3872 w += 16;
3873
3874 size_t p = ks;
3875 do {
3876 const float* restrict a0 = a[0];
3877 assert(a0 != NULL);
3878 if XNN_UNPREDICTABLE(a0 != zero) {
3879 a0 = (const float*) ((uintptr_t) a0 + a_offset);
3880 }
3881 const float* restrict a1 = a[1];
3882 assert(a1 != NULL);
3883 if XNN_UNPREDICTABLE(a1 != zero) {
3884 a1 = (const float*) ((uintptr_t) a1 + a_offset);
3885 }
3886 const float* restrict a2 = a[2];
3887 assert(a2 != NULL);
3888 if XNN_UNPREDICTABLE(a2 != zero) {
3889 a2 = (const float*) ((uintptr_t) a2 + a_offset);
3890 }
3891 const float* restrict a3 = a[3];
3892 assert(a3 != NULL);
3893 if XNN_UNPREDICTABLE(a3 != zero) {
3894 a3 = (const float*) ((uintptr_t) a3 + a_offset);
3895 }
3896 const float* restrict a4 = a[4];
3897 assert(a4 != NULL);
3898 if XNN_UNPREDICTABLE(a4 != zero) {
3899 a4 = (const float*) ((uintptr_t) a4 + a_offset);
3900 }
3901 a += 5;
3902
3903 size_t k = kc;
3904 do {
3905 const __m256 vb01234567 = _mm256_load_ps(w);
3906 const __m256 vb89ABCDEF = _mm256_load_ps(w + 8);
3907 w += 16;
3908
3909 const __m256 va0 = _mm256_broadcast_ss(a0);
3910 a0 += 1;
3911 const __m256 va1 = _mm256_broadcast_ss(a1);
3912 a1 += 1;
3913 const __m256 va2 = _mm256_broadcast_ss(a2);
3914 a2 += 1;
3915 const __m256 va3 = _mm256_broadcast_ss(a3);
3916 a3 += 1;
3917 const __m256 va4 = _mm256_broadcast_ss(a4);
3918 a4 += 1;
3919
3920 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567, vacc0x01234567);
3921 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEF, vacc0x89ABCDEF);
3922 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567, vacc1x01234567);
3923 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEF, vacc1x89ABCDEF);
3924 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567, vacc2x01234567);
3925 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEF, vacc2x89ABCDEF);
3926 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567, vacc3x01234567);
3927 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEF, vacc3x89ABCDEF);
3928 vacc4x01234567 = _mm256_fmadd_ps(va4, vb01234567, vacc4x01234567);
3929 vacc4x89ABCDEF = _mm256_fmadd_ps(va4, vb89ABCDEF, vacc4x89ABCDEF);
3930 k -= sizeof(float);
3931 } while (k != 0);
3932 p -= 5 * sizeof(void*);
3933 } while (p != 0);
3934
3935 const __m256 vmin = _mm256_load_ps(params->avx.min);
3936 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
3937 vacc1x01234567 = _mm256_max_ps(vacc1x01234567, vmin);
3938 vacc2x01234567 = _mm256_max_ps(vacc2x01234567, vmin);
3939 vacc3x01234567 = _mm256_max_ps(vacc3x01234567, vmin);
3940 vacc4x01234567 = _mm256_max_ps(vacc4x01234567, vmin);
3941 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
3942 vacc1x89ABCDEF = _mm256_max_ps(vacc1x89ABCDEF, vmin);
3943 vacc2x89ABCDEF = _mm256_max_ps(vacc2x89ABCDEF, vmin);
3944 vacc3x89ABCDEF = _mm256_max_ps(vacc3x89ABCDEF, vmin);
3945 vacc4x89ABCDEF = _mm256_max_ps(vacc4x89ABCDEF, vmin);
3946
3947 const __m256 vmax = _mm256_load_ps(params->avx.max);
3948 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
3949 vacc1x01234567 = _mm256_min_ps(vacc1x01234567, vmax);
3950 vacc2x01234567 = _mm256_min_ps(vacc2x01234567, vmax);
3951 vacc3x01234567 = _mm256_min_ps(vacc3x01234567, vmax);
3952 vacc4x01234567 = _mm256_min_ps(vacc4x01234567, vmax);
3953 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
3954 vacc1x89ABCDEF = _mm256_min_ps(vacc1x89ABCDEF, vmax);
3955 vacc2x89ABCDEF = _mm256_min_ps(vacc2x89ABCDEF, vmax);
3956 vacc3x89ABCDEF = _mm256_min_ps(vacc3x89ABCDEF, vmax);
3957 vacc4x89ABCDEF = _mm256_min_ps(vacc4x89ABCDEF, vmax);
3958
3959 if XNN_LIKELY(nc >= 16) {
3960 _mm256_storeu_ps(c4, vacc4x01234567);
3961 _mm256_storeu_ps(c4 + 8, vacc4x89ABCDEF);
3962 c4 = (float*) ((uintptr_t) c4 + cn_stride);
3963 _mm256_storeu_ps(c3, vacc3x01234567);
3964 _mm256_storeu_ps(c3 + 8, vacc3x89ABCDEF);
3965 c3 = (float*) ((uintptr_t) c3 + cn_stride);
3966 _mm256_storeu_ps(c2, vacc2x01234567);
3967 _mm256_storeu_ps(c2 + 8, vacc2x89ABCDEF);
3968 c2 = (float*) ((uintptr_t) c2 + cn_stride);
3969 _mm256_storeu_ps(c1, vacc1x01234567);
3970 _mm256_storeu_ps(c1 + 8, vacc1x89ABCDEF);
3971 c1 = (float*) ((uintptr_t) c1 + cn_stride);
3972 _mm256_storeu_ps(c0, vacc0x01234567);
3973 _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
3974 c0 = (float*) ((uintptr_t) c0 + cn_stride);
3975
3976 a = (const float**restrict) ((uintptr_t) a - ks);
3977 nc -= 16;
3978 } else {
3979 if (nc & 8) {
3980 _mm256_storeu_ps(c4, vacc4x01234567);
3981 _mm256_storeu_ps(c3, vacc3x01234567);
3982 _mm256_storeu_ps(c2, vacc2x01234567);
3983 _mm256_storeu_ps(c1, vacc1x01234567);
3984 _mm256_storeu_ps(c0, vacc0x01234567);
3985
3986 vacc4x01234567 = vacc4x89ABCDEF;
3987 vacc3x01234567 = vacc3x89ABCDEF;
3988 vacc2x01234567 = vacc2x89ABCDEF;
3989 vacc1x01234567 = vacc1x89ABCDEF;
3990 vacc0x01234567 = vacc0x89ABCDEF;
3991
3992 c4 += 8;
3993 c3 += 8;
3994 c2 += 8;
3995 c1 += 8;
3996 c0 += 8;
3997 }
3998 __m128 vacc4x0123 = _mm256_castps256_ps128(vacc4x01234567);
3999 __m128 vacc3x0123 = _mm256_castps256_ps128(vacc3x01234567);
4000 __m128 vacc2x0123 = _mm256_castps256_ps128(vacc2x01234567);
4001 __m128 vacc1x0123 = _mm256_castps256_ps128(vacc1x01234567);
4002 __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
4003 if (nc & 4) {
4004 _mm_storeu_ps(c4, vacc4x0123);
4005 _mm_storeu_ps(c3, vacc3x0123);
4006 _mm_storeu_ps(c2, vacc2x0123);
4007 _mm_storeu_ps(c1, vacc1x0123);
4008 _mm_storeu_ps(c0, vacc0x0123);
4009
4010 vacc4x0123 = _mm256_extractf128_ps(vacc4x01234567, 1);
4011 vacc3x0123 = _mm256_extractf128_ps(vacc3x01234567, 1);
4012 vacc2x0123 = _mm256_extractf128_ps(vacc2x01234567, 1);
4013 vacc1x0123 = _mm256_extractf128_ps(vacc1x01234567, 1);
4014 vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
4015
4016 c4 += 4;
4017 c3 += 4;
4018 c2 += 4;
4019 c1 += 4;
4020 c0 += 4;
4021 }
4022 if (nc & 2) {
4023 _mm_storel_pi((__m64*) c4, vacc4x0123);
4024 _mm_storel_pi((__m64*) c3, vacc3x0123);
4025 _mm_storel_pi((__m64*) c2, vacc2x0123);
4026 _mm_storel_pi((__m64*) c1, vacc1x0123);
4027 _mm_storel_pi((__m64*) c0, vacc0x0123);
4028
4029 vacc4x0123 = _mm_movehl_ps(vacc4x0123, vacc4x0123);
4030 vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
4031 vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
4032 vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
4033 vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
4034
4035 c4 += 2;
4036 c3 += 2;
4037 c2 += 2;
4038 c1 += 2;
4039 c0 += 2;
4040 }
4041 if (nc & 1) {
4042 _mm_store_ss(c4, vacc4x0123);
4043 _mm_store_ss(c3, vacc3x0123);
4044 _mm_store_ss(c2, vacc2x0123);
4045 _mm_store_ss(c1, vacc1x0123);
4046 _mm_store_ss(c0, vacc0x0123);
4047 }
4048
4049 nc = 0;
4050 }
4051 } while (nc != 0);
4052 }
4053
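// HardSwish: y = x * min(max(x/6 + 0.5, 0), 1), computed 16 floats at a time with AVX + FMA3.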
4054 void xnn_f32_vhswish_ukernel__fma3_x16(
4055 size_t n,
4056 const float* x,
4057 float* y,
4058 const union xnn_f32_hswish_params params[restrict XNN_MIN_ELEMENTS(1)])
4059 {
4060 assert(n != 0);
4061 assert(n % sizeof(float) == 0);
4062
4063 const __m256 vsixth = _mm256_load_ps(params->avx.sixth);
4064 const __m256 vhalf = _mm256_load_ps(params->avx.half);
4065 const __m256 vone = _mm256_load_ps(params->avx.one);
4066 const __m256 vzero = _mm256_setzero_ps();
4067
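// Main loop: 16 floats per iteration in two 8-wide AVX registers.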
4068 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
4069 const __m256 vx01234567 = _mm256_loadu_ps(x);
4070 const __m256 vx89ABCDEF = _mm256_loadu_ps(x + 8);
4071 x += 16;
4072
4073 __m256 vacc01234567 = _mm256_fmadd_ps(vx01234567, vsixth, vhalf);
4074 __m256 vacc89ABCDEF = _mm256_fmadd_ps(vx89ABCDEF, vsixth, vhalf);
4075
4076 vacc01234567 = _mm256_max_ps(vacc01234567, vzero);
4077 vacc89ABCDEF = _mm256_max_ps(vacc89ABCDEF, vzero);
4078
4079 vacc01234567 = _mm256_min_ps(vacc01234567, vone);
4080 vacc89ABCDEF = _mm256_min_ps(vacc89ABCDEF, vone);
4081
4082 vacc01234567 = _mm256_mul_ps(vacc01234567, vx01234567);
4083 vacc89ABCDEF = _mm256_mul_ps(vacc89ABCDEF, vx89ABCDEF);
4084
4085 _mm256_storeu_ps(y, vacc01234567);
4086 _mm256_storeu_ps(y + 8, vacc89ABCDEF);
4087 y += 16;
4088 }
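// Process a remaining block of 8 floats, if present.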
4089 for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
4090 const __m256 vx = _mm256_loadu_ps(x);
4091 x += 8;
4092 __m256 vacc = _mm256_fmadd_ps(vx, vsixth, vhalf);
4093 vacc = _mm256_max_ps(vacc, vzero);
4094 vacc = _mm256_min_ps(vacc, vone);
4095 vacc = _mm256_mul_ps(vacc, vx);
4096 _mm256_storeu_ps(y, vacc);
4097 y += 8;
4098 }
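// Remainder of 1-7 floats: masked load via params->avx.mask_table, clamped and scaled as above,
// then stored with 4/2/1-element tail stores.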
4099 if XNN_UNLIKELY(n != 0) {
4100 assert(n >= 1 * sizeof(float));
4101 assert(n <= 7 * sizeof(float));
4102 const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx.mask_table[7] - n));
4103
4104 const __m256 vx = _mm256_maskload_ps(x, vmask);
4105 __m256 vacc = _mm256_fmadd_ps(vx, vsixth, vhalf);
4106 vacc = _mm256_max_ps(vacc, vzero);
4107 vacc = _mm256_min_ps(vacc, vone);
4108 vacc = _mm256_mul_ps(vacc, vx);
4109
4110 __m128 vacc_lo = _mm256_castps256_ps128(vacc);
4111 if (n & (4 * sizeof(float))) {
4112 _mm_storeu_ps(y, vacc_lo);
4113 vacc_lo = _mm256_extractf128_ps(vacc, 1);
4114 y += 4;
4115 }
4116 if (n & (2 * sizeof(float))) {
4117 _mm_storel_pi((__m64*) y, vacc_lo);
4118 vacc_lo = _mm_movehl_ps(vacc_lo, vacc_lo);
4119 y += 2;
4120 }
4121 if (n & (1 * sizeof(float))) {
4122 _mm_store_ss(y, vacc_lo);
4123 }
4124 }
4125 }
4126