// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <immintrin.h>

#include <xnnpack/avgpool.h>
#include <xnnpack/intrinsics-polyfill.h>


void xnn_f16_avgpool_minmax_ukernel_9p8x__f16c_c8(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const void** input,
    size_t input_offset,
    const void* zero,
    void* buffer,
    void* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(output_pixels != 0);
  assert(kernel_elements > 9);
  assert(channels != 0);

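  // Load the averaging scale and the output clamping bounds from the packed microkernel parameters.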
  const __m256 vscale = _mm256_load_ps(params->avx.scale);
  const __m256 vmin = _mm256_load_ps(params->avx.min);
  const __m256 vmax = _mm256_load_ps(params->avx.max);

  uint16_t* o = (uint16_t*) output;
  do {
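    // First pass: sum the first 9 input rows and store the partial sums to the buffer.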
    {
      const uint16_t* i0 = *input++;
      assert(i0 != NULL);
      if XNN_UNPREDICTABLE(i0 != zero) {
        i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
      }
      const uint16_t* i1 = *input++;
      assert(i1 != NULL);
      if XNN_UNPREDICTABLE(i1 != zero) {
        i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
      }
      const uint16_t* i2 = *input++;
      assert(i2 != NULL);
      if XNN_UNPREDICTABLE(i2 != zero) {
        i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
      }
      const uint16_t* i3 = *input++;
      assert(i3 != NULL);
      if XNN_UNPREDICTABLE(i3 != zero) {
        i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
      }
      const uint16_t* i4 = *input++;
      assert(i4 != NULL);
      if XNN_UNPREDICTABLE(i4 != zero) {
        i4 = (const uint16_t*) ((uintptr_t) i4 + input_offset);
      }
      const uint16_t* i5 = *input++;
      assert(i5 != NULL);
      if XNN_UNPREDICTABLE(i5 != zero) {
        i5 = (const uint16_t*) ((uintptr_t) i5 + input_offset);
      }
      const uint16_t* i6 = *input++;
      assert(i6 != NULL);
      if XNN_UNPREDICTABLE(i6 != zero) {
        i6 = (const uint16_t*) ((uintptr_t) i6 + input_offset);
      }
      const uint16_t* i7 = *input++;
      assert(i7 != NULL);
      if XNN_UNPREDICTABLE(i7 != zero) {
        i7 = (const uint16_t*) ((uintptr_t) i7 + input_offset);
      }
      const uint16_t* i8 = *input++;
      assert(i8 != NULL);
      if XNN_UNPREDICTABLE(i8 != zero) {
        i8 = (const uint16_t*) ((uintptr_t) i8 + input_offset);
      }

      uint16_t* b = (uint16_t*) buffer;
      for (size_t c = 0; c < channels; c += 8) {
        const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
        i0 += 8;
        const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
        i1 += 8;
        const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
        i2 += 8;
        const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
        i3 += 8;
        const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
        i4 += 8;
        const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
        i5 += 8;
        const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
        i6 += 8;
        const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
        i7 += 8;
        const __m256 vi8 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
        i8 += 8;

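        // Add the 9 inputs pairwise; each intermediate sum is rounded back to fp16 to match native fp16 accumulation.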
        const __m256 vsum01 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi0, vi1), _MM_FROUND_NO_EXC));
        const __m256 vsum23 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi2, vi3), _MM_FROUND_NO_EXC));
        const __m256 vsum45 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi4, vi5), _MM_FROUND_NO_EXC));
        const __m256 vsum67 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi6, vi7), _MM_FROUND_NO_EXC));
        const __m256 vsum018 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01, vi8), _MM_FROUND_NO_EXC));
        const __m256 vsum2345 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum23, vsum45), _MM_FROUND_NO_EXC));
        const __m256 vsum01678 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum018, vsum67), _MM_FROUND_NO_EXC));
        const __m128i vsum = _mm256_cvtps_ph(_mm256_add_ps(vsum2345, vsum01678), _MM_FROUND_NO_EXC);

        _mm_storeu_si128((__m128i*) b, vsum);
        b += 8;
      }
    }

    size_t k = kernel_elements;
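    // Intermediate passes: accumulate 8 more input rows per pass on top of the buffered partial sums.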
    for (k -= 9; k > 8; k -= 8) {
      const uint16_t* i0 = (const uint16_t*) *input++;
      assert(i0 != NULL);
      if XNN_UNPREDICTABLE(i0 != zero) {
        i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
      }
      const uint16_t* i1 = (const uint16_t*) *input++;
      assert(i1 != NULL);
      if XNN_UNPREDICTABLE(i1 != zero) {
        i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
      }
      const uint16_t* i2 = (const uint16_t*) *input++;
      assert(i2 != NULL);
      if XNN_UNPREDICTABLE(i2 != zero) {
        i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
      }
      const uint16_t* i3 = (const uint16_t*) *input++;
      assert(i3 != NULL);
      if XNN_UNPREDICTABLE(i3 != zero) {
        i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
      }
      const uint16_t* i4 = (const uint16_t*) *input++;
      assert(i4 != NULL);
      if XNN_UNPREDICTABLE(i4 != zero) {
        i4 = (const uint16_t*) ((uintptr_t) i4 + input_offset);
      }
      const uint16_t* i5 = (const uint16_t*) *input++;
      assert(i5 != NULL);
      if XNN_UNPREDICTABLE(i5 != zero) {
        i5 = (const uint16_t*) ((uintptr_t) i5 + input_offset);
      }
      const uint16_t* i6 = (const uint16_t*) *input++;
      assert(i6 != NULL);
      if XNN_UNPREDICTABLE(i6 != zero) {
        i6 = (const uint16_t*) ((uintptr_t) i6 + input_offset);
      }
      const uint16_t* i7 = (const uint16_t*) *input++;
      assert(i7 != NULL);
      if XNN_UNPREDICTABLE(i7 != zero) {
        i7 = (const uint16_t*) ((uintptr_t) i7 + input_offset);
      }

      uint16_t* b = (uint16_t*) buffer;
      for (size_t c = 0; c < channels; c += 8) {
        const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
        i0 += 8;
        const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
        i1 += 8;
        const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
        i2 += 8;
        const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
        i3 += 8;
        const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
        i4 += 8;
        const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
        i5 += 8;
        const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
        i6 += 8;
        const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
        i7 += 8;
        const __m256 vacc = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) b));

        const __m256 vsum01 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi0, vi1), _MM_FROUND_NO_EXC));
        const __m256 vsum23 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi2, vi3), _MM_FROUND_NO_EXC));
        const __m256 vsum45 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi4, vi5), _MM_FROUND_NO_EXC));
        const __m256 vsum67 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi6, vi7), _MM_FROUND_NO_EXC));
        const __m256 vsum01a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01, vacc), _MM_FROUND_NO_EXC));
        const __m256 vsum2345 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum23, vsum45), _MM_FROUND_NO_EXC));
        const __m256 vsum0167a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01a, vsum67), _MM_FROUND_NO_EXC));
        const __m128i vsum = _mm256_cvtps_ph(_mm256_add_ps(vsum2345, vsum0167a), _MM_FROUND_NO_EXC);

        _mm_storeu_si128((__m128i*) b, vsum);
        b += 8;
      }
    }

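    // Final pass: add the remaining 1-8 input rows, then scale, clamp, and write the averaged output.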
    assert(k >= 1);
    {
      const uint16_t* i0 = (const uint16_t*) input[0];
      assert(i0 != NULL);
      const uint16_t* i1 = (const uint16_t*) input[1];
      const uint16_t* i2 = (const uint16_t*) input[2];
      const uint16_t* i3 = (const uint16_t*) input[3];
      const uint16_t* i4 = (const uint16_t*) input[4];
      const uint16_t* i5 = (const uint16_t*) input[5];
      const uint16_t* i6 = (const uint16_t*) input[6];
      const uint16_t* i7 = (const uint16_t*) input[7];
      input = (const void**) ((uintptr_t) input + input_increment);
      if (k < 2) {
        i1 = (const uint16_t*) zero;
      }
      assert(i1 != NULL);
      if (k <= 2) {
        i2 = (const uint16_t*) zero;
      }
      assert(i2 != NULL);
      if (k < 4) {
        i3 = (const uint16_t*) zero;
      }
      assert(i3 != NULL);
      if (k <= 4) {
        i4 = (const uint16_t*) zero;
      }
      assert(i4 != NULL);
      if (k < 6) {
        i5 = (const uint16_t*) zero;
      }
      assert(i5 != NULL);
      if (k <= 6) {
        i6 = (const uint16_t*) zero;
      }
      assert(i6 != NULL);
      if (k < 8) {
        i7 = (const uint16_t*) zero;
      }
      assert(i7 != NULL);
      if XNN_UNPREDICTABLE(i0 != zero) {
        i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
      }
      if XNN_UNPREDICTABLE(i1 != zero) {
        i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
      }
      if XNN_UNPREDICTABLE(i2 != zero) {
        i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
      }
      if XNN_UNPREDICTABLE(i3 != zero) {
        i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
      }
      if XNN_UNPREDICTABLE(i4 != zero) {
        i4 = (const uint16_t*) ((uintptr_t) i4 + input_offset);
      }
      if XNN_UNPREDICTABLE(i5 != zero) {
        i5 = (const uint16_t*) ((uintptr_t) i5 + input_offset);
      }
      if XNN_UNPREDICTABLE(i6 != zero) {
        i6 = (const uint16_t*) ((uintptr_t) i6 + input_offset);
      }
      if XNN_UNPREDICTABLE(i7 != zero) {
        i7 = (const uint16_t*) ((uintptr_t) i7 + input_offset);
      }

      size_t c = channels;
      uint16_t* b = (uint16_t*) buffer;
      while (c >= 8) {
        const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
        i0 += 8;
        const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
        i1 += 8;
        const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
        i2 += 8;
        const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
        i3 += 8;
        const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
        i4 += 8;
        const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
        i5 += 8;
        const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
        i6 += 8;
        const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
        i7 += 8;
        const __m256 vacc = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) b));
        b += 8;

        const __m256 vsum01 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi0, vi1), _MM_FROUND_NO_EXC));
        const __m256 vsum23 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi2, vi3), _MM_FROUND_NO_EXC));
        const __m256 vsum45 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi4, vi5), _MM_FROUND_NO_EXC));
        const __m256 vsum67 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi6, vi7), _MM_FROUND_NO_EXC));
        const __m256 vsum01a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01, vacc), _MM_FROUND_NO_EXC));
        const __m256 vsum2345 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum23, vsum45), _MM_FROUND_NO_EXC));
        const __m256 vsum0167a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01a, vsum67), _MM_FROUND_NO_EXC));
        const __m256 vsum = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum2345, vsum0167a), _MM_FROUND_NO_EXC));

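        // Scale the accumulated sum by the averaging factor and clamp to the output range.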
        __m256 vout = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vsum, vscale), _MM_FROUND_NO_EXC));
        vout = _mm256_max_ps(vout, vmin);
        vout = _mm256_min_ps(vout, vmax);

        _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vout, _MM_FROUND_NO_EXC));
        o += 8;

        c -= 8;
      }
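      // Handle the remaining 1-7 channels with a partial store.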
      if (c != 0) {
        const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
        const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
        const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
        const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
        const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
        const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
        const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
        const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
        const __m256 vacc = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) b));

        const __m256 vsum01 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi0, vi1), _MM_FROUND_NO_EXC));
        const __m256 vsum23 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi2, vi3), _MM_FROUND_NO_EXC));
        const __m256 vsum45 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi4, vi5), _MM_FROUND_NO_EXC));
        const __m256 vsum67 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vi6, vi7), _MM_FROUND_NO_EXC));
        const __m256 vsum01a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01, vacc), _MM_FROUND_NO_EXC));
        const __m256 vsum2345 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum23, vsum45), _MM_FROUND_NO_EXC));
        const __m256 vsum0167a = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum01a, vsum67), _MM_FROUND_NO_EXC));
        const __m256 vsum = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vsum2345, vsum0167a), _MM_FROUND_NO_EXC));

        __m256 vout = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vsum, vscale), _MM_FROUND_NO_EXC));
        vout = _mm256_max_ps(vout, vmin);
        vout = _mm256_min_ps(vout, vmax);

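        // Store 4, 2, and then 1 half-precision values as needed to cover the channel remainder.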
        __m128i vh = _mm256_cvtps_ph(vout, _MM_FROUND_NO_EXC);
        if (c & 4) {
          _mm_storel_epi64((__m128i*) o, vh);
          vh = _mm_unpackhi_epi64(vh, vh);
          o += 4;
        }
        if (c & 2) {
          _mm_storeu_si32(o, vh);
          vh = _mm_srli_epi64(vh, 32);
          o += 2;
        }
        if (c & 1) {
          *o = (uint16_t) _mm_extract_epi16(vh, 0);
          o += 1;
        }
      }
    }
    o = (uint16_t*) ((uintptr_t) o + output_increment);
  } while (--output_pixels != 0);
}