// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/argmaxpool.h>

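// Multipass argmax pooling microkernel. For each output pixel, it reduces
// pooling_elements input rows (pointers taken from *input, each displaced by
// input_offset) to a per-channel maximum and the pooling-element index at
// which that maximum occurs. The first pass covers 9 rows; each intermediate
// pass folds 8 more through accumulation_buffer/index_buffer; the last pass
// (1-8 rows) writes the results to output and index. Channels are processed
// 4 at a time with WAsm SIMD; XNN_OOB_READS flags that the channel tail may
// read (but never write) up to a full vector past the end of a row.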
void xnn_f32_argmaxpool_ukernel_9p8x__wasmsimd_c4(
    size_t output_pixels,
    size_t pooling_elements,
    size_t channels,
    const float** input,
    size_t input_offset,
    float* accumulation_buffer,
    uint32_t* index_buffer,
    float* output,
    uint32_t* index,
    size_t input_increment,
    size_t output_increment) XNN_OOB_READS
{
  assert(output_pixels != 0);
  assert(pooling_elements != 0);
  assert(pooling_elements > 9);
  assert(channels != 0);

  do {
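    // First pass: reduce pooling elements 0-8 into the accumulation buffer
    // (running per-channel maximum) and the index buffer (running argmax).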
    {
      float* ab = accumulation_buffer;
      uint32_t* ib = index_buffer;

      const float* i0 = *input++;
      const float* i1 = *input++;
      const float* i2 = *input++;
      const float* i3 = *input++;
      const float* i4 = *input++;
      const float* i5 = *input++;
      const float* i6 = *input++;
      const float* i7 = *input++;
      const float* i8 = *input++;
      i0 = (const float*) ((uintptr_t) i0 + input_offset);
      i1 = (const float*) ((uintptr_t) i1 + input_offset);
      i2 = (const float*) ((uintptr_t) i2 + input_offset);
      i3 = (const float*) ((uintptr_t) i3 + input_offset);
      i4 = (const float*) ((uintptr_t) i4 + input_offset);
      i5 = (const float*) ((uintptr_t) i5 + input_offset);
      i6 = (const float*) ((uintptr_t) i6 + input_offset);
      i7 = (const float*) ((uintptr_t) i7 + input_offset);
      i8 = (const float*) ((uintptr_t) i8 + input_offset);

      for (size_t c = 0; c < channels; c += 4) {
        const v128_t vi0 = wasm_v128_load(i0);
        i0 += 4;
        const v128_t vi1 = wasm_v128_load(i1);
        i1 += 4;
        const v128_t vi2 = wasm_v128_load(i2);
        i2 += 4;
        const v128_t vi3 = wasm_v128_load(i3);
        i3 += 4;
        const v128_t vi4 = wasm_v128_load(i4);
        i4 += 4;
        const v128_t vi5 = wasm_v128_load(i5);
        i5 += 4;
        const v128_t vi6 = wasm_v128_load(i6);
        i6 += 4;
        const v128_t vi7 = wasm_v128_load(i7);
        i7 += 4;
        const v128_t vi8 = wasm_v128_load(i8);
        i8 += 4;

        v128_t vmax = vi0;
        v128_t vidx = wasm_i32x4_const_splat(0);

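        // Lane-wise argmax step: vm1 is all-ones in lanes where vi1 beats the
        // running maximum; bitselect takes its first operand in those lanes,
        // updating the maximum and recording pooling-element index 1.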
        const v128_t vm1 = wasm_f32x4_gt(vi1, vmax);
        vmax = wasm_v128_bitselect(vi1, vmax, vm1);
        vidx = wasm_v128_bitselect(wasm_i32x4_const_splat(1), vidx, vm1);

        const v128_t vm2 = wasm_f32x4_gt(vi2, vmax);
        vmax = wasm_v128_bitselect(vi2, vmax, vm2);
        vidx = wasm_v128_bitselect(wasm_i32x4_const_splat(2), vidx, vm2);

        const v128_t vm3 = wasm_f32x4_gt(vi3, vmax);
        vmax = wasm_v128_bitselect(vi3, vmax, vm3);
        vidx = wasm_v128_bitselect(wasm_i32x4_const_splat(3), vidx, vm3);

        const v128_t vm4 = wasm_f32x4_gt(vi4, vmax);
        vmax = wasm_v128_bitselect(vi4, vmax, vm4);
        vidx = wasm_v128_bitselect(wasm_i32x4_const_splat(4), vidx, vm4);

        const v128_t vm5 = wasm_f32x4_gt(vi5, vmax);
        vmax = wasm_v128_bitselect(vi5, vmax, vm5);
        vidx = wasm_v128_bitselect(wasm_i32x4_const_splat(5), vidx, vm5);

        const v128_t vm6 = wasm_f32x4_gt(vi6, vmax);
        vmax = wasm_v128_bitselect(vi6, vmax, vm6);
        vidx = wasm_v128_bitselect(wasm_i32x4_const_splat(6), vidx, vm6);

        const v128_t vm7 = wasm_f32x4_gt(vi7, vmax);
        vmax = wasm_v128_bitselect(vi7, vmax, vm7);
        vidx = wasm_v128_bitselect(wasm_i32x4_const_splat(7), vidx, vm7);

        const v128_t vm8 = wasm_f32x4_gt(vi8, vmax);
        vmax = wasm_v128_bitselect(vi8, vmax, vm8);
        vidx = wasm_v128_bitselect(wasm_i32x4_const_splat(8), vidx, vm8);

        wasm_v128_store(ab, vmax);
        ab += 4;
        wasm_v128_store(ib, vidx);
        ib += 4;
      }
    }
    const v128_t v1 = wasm_i32x4_const_splat(1);
    const v128_t v8 = wasm_i32x4_const_splat(8);
    v128_t vidx0 = wasm_i32x4_add(v1, v8);
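    // vidx0 holds the pooling-element index of row i0 for the next pass:
    // 9 (= 1 + 8) after the first pass, stepped by 8 per intermediate pass.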

    size_t k = pooling_elements;
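    // Intermediate passes: while more than 8 pooling elements remain, fold
    // another 8 rows into the accumulation and index buffers.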
    for (k -= 9; k > 8; k -= 8) {
      const float* i0 = *input++;
      const float* i1 = *input++;
      const float* i2 = *input++;
      const float* i3 = *input++;
      const float* i4 = *input++;
      const float* i5 = *input++;
      const float* i6 = *input++;
      const float* i7 = *input++;
      i0 = (const float*) ((uintptr_t) i0 + input_offset);
      i1 = (const float*) ((uintptr_t) i1 + input_offset);
      i2 = (const float*) ((uintptr_t) i2 + input_offset);
      i3 = (const float*) ((uintptr_t) i3 + input_offset);
      i4 = (const float*) ((uintptr_t) i4 + input_offset);
      i5 = (const float*) ((uintptr_t) i5 + input_offset);
      i6 = (const float*) ((uintptr_t) i6 + input_offset);
      i7 = (const float*) ((uintptr_t) i7 + input_offset);

      float* ab = accumulation_buffer;
      uint32_t* ib = index_buffer;

      for (size_t c = 0; c < channels; c += 4) {
        const v128_t vi0 = wasm_v128_load(i0);
        i0 += 4;
        const v128_t vi1 = wasm_v128_load(i1);
        i1 += 4;
        const v128_t vi2 = wasm_v128_load(i2);
        i2 += 4;
        const v128_t vi3 = wasm_v128_load(i3);
        i3 += 4;
        const v128_t vi4 = wasm_v128_load(i4);
        i4 += 4;
        const v128_t vi5 = wasm_v128_load(i5);
        i5 += 4;
        const v128_t vi6 = wasm_v128_load(i6);
        i6 += 4;
        const v128_t vi7 = wasm_v128_load(i7);
        i7 += 4;

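        // Resume from the running maximum and argmax accumulated so far.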
        v128_t vmax = wasm_v128_load(ab);
        v128_t vidx = wasm_v128_load(ib);

        const v128_t vm0 = wasm_f32x4_gt(vi0, vmax);
        vmax = wasm_v128_bitselect(vi0, vmax, vm0);
        vidx = wasm_v128_bitselect(vidx0, vidx, vm0);

        const v128_t vm1 = wasm_f32x4_gt(vi1, vmax);
        const v128_t vidx1 = wasm_i32x4_add(vidx0, v1);
        vmax = wasm_v128_bitselect(vi1, vmax, vm1);
        vidx = wasm_v128_bitselect(vidx1, vidx, vm1);

        const v128_t vm2 = wasm_f32x4_gt(vi2, vmax);
        const v128_t vidx2 = wasm_i32x4_add(vidx1, v1);
        vmax = wasm_v128_bitselect(vi2, vmax, vm2);
        vidx = wasm_v128_bitselect(vidx2, vidx, vm2);

        const v128_t vm3 = wasm_f32x4_gt(vi3, vmax);
        const v128_t vidx3 = wasm_i32x4_add(vidx2, v1);
        vmax = wasm_v128_bitselect(vi3, vmax, vm3);
        vidx = wasm_v128_bitselect(vidx3, vidx, vm3);

        const v128_t vm4 = wasm_f32x4_gt(vi4, vmax);
        const v128_t vidx4 = wasm_i32x4_add(vidx3, v1);
        vmax = wasm_v128_bitselect(vi4, vmax, vm4);
        vidx = wasm_v128_bitselect(vidx4, vidx, vm4);

        const v128_t vm5 = wasm_f32x4_gt(vi5, vmax);
        const v128_t vidx5 = wasm_i32x4_add(vidx4, v1);
        vmax = wasm_v128_bitselect(vi5, vmax, vm5);
        vidx = wasm_v128_bitselect(vidx5, vidx, vm5);

        const v128_t vm6 = wasm_f32x4_gt(vi6, vmax);
        const v128_t vidx6 = wasm_i32x4_add(vidx5, v1);
        vmax = wasm_v128_bitselect(vi6, vmax, vm6);
        vidx = wasm_v128_bitselect(vidx6, vidx, vm6);

        const v128_t vm7 = wasm_f32x4_gt(vi7, vmax);
        const v128_t vidx7 = wasm_i32x4_add(vidx6, v1);
        vmax = wasm_v128_bitselect(vi7, vmax, vm7);
        vidx = wasm_v128_bitselect(vidx7, vidx, vm7);

        wasm_v128_store(ab, vmax);
        ab += 4;
        wasm_v128_store(ib, vidx);
        ib += 4;
      }
      vidx0 = wasm_i32x4_add(vidx0, v8);
    }

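    // Last pass: between 1 and 8 pooling elements remain (the loop above
    // leaves k in 1..8); results go directly to the output and index arrays.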
    float* o = output;
    float* i = (float*) index;
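    // Note: index is written through a float*, apparently so the scalar tail
    // below can store the raw index bits via wasm_f32x4_extract_lane.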
    {
      const float* i0 = input[0];
      const float* i1 = input[1];
      const float* i2 = input[2];
      const float* i3 = input[3];
      const float* i4 = input[4];
      const float* i5 = input[5];
      const float* i6 = input[6];
      const float* i7 = input[7];
      i0 = (const float*) ((uintptr_t) i0 + input_offset);
      i1 = (const float*) ((uintptr_t) i1 + input_offset);
      i2 = (const float*) ((uintptr_t) i2 + input_offset);
      i3 = (const float*) ((uintptr_t) i3 + input_offset);
      i4 = (const float*) ((uintptr_t) i4 + input_offset);
      i5 = (const float*) ((uintptr_t) i5 + input_offset);
      i6 = (const float*) ((uintptr_t) i6 + input_offset);
      i7 = (const float*) ((uintptr_t) i7 + input_offset);
      input = (const float**) ((uintptr_t) input + input_increment);
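      // Rows beyond the k remaining elements alias i0; vi0 is folded in first
      // and the comparison is strict (>), so the duplicates can never win.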
      if (k < 2) {
        i1 = i0;
      }
      if (k <= 2) {
        i2 = i0;
      }
      if (k < 4) {
        i3 = i0;
      }
      if (k <= 4) {
        i4 = i0;
      }
      if (k < 6) {
        i5 = i0;
      }
      if (k <= 6) {
        i6 = i0;
      }
      if (k != 8) {
        i7 = i0;
      }

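      // Fold the last rows into the buffered state, 4 channels at a time.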
      size_t c = channels;
      float* ab = accumulation_buffer;
      uint32_t* ib = index_buffer;
      for (; c >= 4; c -= 4) {
        const v128_t vi0 = wasm_v128_load(i0);
        i0 += 4;
        const v128_t vi1 = wasm_v128_load(i1);
        i1 += 4;
        const v128_t vi2 = wasm_v128_load(i2);
        i2 += 4;
        const v128_t vi3 = wasm_v128_load(i3);
        i3 += 4;
        const v128_t vi4 = wasm_v128_load(i4);
        i4 += 4;
        const v128_t vi5 = wasm_v128_load(i5);
        i5 += 4;
        const v128_t vi6 = wasm_v128_load(i6);
        i6 += 4;
        const v128_t vi7 = wasm_v128_load(i7);
        i7 += 4;

        v128_t vmax = wasm_v128_load(ab);
        ab += 4;
        v128_t vidx = wasm_v128_load(ib);
        ib += 4;

        const v128_t vm0 = wasm_f32x4_gt(vi0, vmax);
        vmax = wasm_v128_bitselect(vi0, vmax, vm0);
        vidx = wasm_v128_bitselect(vidx0, vidx, vm0);

        const v128_t vm1 = wasm_f32x4_gt(vi1, vmax);
        const v128_t vidx1 = wasm_i32x4_add(vidx0, v1);
        vmax = wasm_v128_bitselect(vi1, vmax, vm1);
        vidx = wasm_v128_bitselect(vidx1, vidx, vm1);

        const v128_t vm2 = wasm_f32x4_gt(vi2, vmax);
        const v128_t vidx2 = wasm_i32x4_add(vidx1, v1);
        vmax = wasm_v128_bitselect(vi2, vmax, vm2);
        vidx = wasm_v128_bitselect(vidx2, vidx, vm2);

        const v128_t vm3 = wasm_f32x4_gt(vi3, vmax);
        const v128_t vidx3 = wasm_i32x4_add(vidx2, v1);
        vmax = wasm_v128_bitselect(vi3, vmax, vm3);
        vidx = wasm_v128_bitselect(vidx3, vidx, vm3);

        const v128_t vm4 = wasm_f32x4_gt(vi4, vmax);
        const v128_t vidx4 = wasm_i32x4_add(vidx3, v1);
        vmax = wasm_v128_bitselect(vi4, vmax, vm4);
        vidx = wasm_v128_bitselect(vidx4, vidx, vm4);

        const v128_t vm5 = wasm_f32x4_gt(vi5, vmax);
        const v128_t vidx5 = wasm_i32x4_add(vidx4, v1);
        vmax = wasm_v128_bitselect(vi5, vmax, vm5);
        vidx = wasm_v128_bitselect(vidx5, vidx, vm5);

        const v128_t vm6 = wasm_f32x4_gt(vi6, vmax);
        const v128_t vidx6 = wasm_i32x4_add(vidx5, v1);
        vmax = wasm_v128_bitselect(vi6, vmax, vm6);
        vidx = wasm_v128_bitselect(vidx6, vidx, vm6);

        const v128_t vm7 = wasm_f32x4_gt(vi7, vmax);
        const v128_t vidx7 = wasm_i32x4_add(vidx6, v1);
        vmax = wasm_v128_bitselect(vi7, vmax, vm7);
        vidx = wasm_v128_bitselect(vidx7, vidx, vm7);

        wasm_v128_store(o, vmax);
        o += 4;
        wasm_v128_store(i, vidx);
        i += 4;
      }
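      // Channel tail (1-3 channels left): full vectors are loaded (this is
      // where the XNN_OOB_READS reads past the end happen), but only the
      // valid lanes are stored below.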
      if (c != 0) {
        const v128_t vi0 = wasm_v128_load(i0);
        const v128_t vi1 = wasm_v128_load(i1);
        const v128_t vi2 = wasm_v128_load(i2);
        const v128_t vi3 = wasm_v128_load(i3);
        const v128_t vi4 = wasm_v128_load(i4);
        const v128_t vi5 = wasm_v128_load(i5);
        const v128_t vi6 = wasm_v128_load(i6);
        const v128_t vi7 = wasm_v128_load(i7);

        v128_t vmax = wasm_v128_load(ab);
        v128_t vidx = wasm_v128_load(ib);

        const v128_t vm0 = wasm_f32x4_gt(vi0, vmax);
        vmax = wasm_v128_bitselect(vi0, vmax, vm0);
        vidx = wasm_v128_bitselect(vidx0, vidx, vm0);

        const v128_t vm1 = wasm_f32x4_gt(vi1, vmax);
        const v128_t vidx1 = wasm_i32x4_add(vidx0, v1);
        vmax = wasm_v128_bitselect(vi1, vmax, vm1);
        vidx = wasm_v128_bitselect(vidx1, vidx, vm1);

        const v128_t vm2 = wasm_f32x4_gt(vi2, vmax);
        const v128_t vidx2 = wasm_i32x4_add(vidx1, v1);
        vmax = wasm_v128_bitselect(vi2, vmax, vm2);
        vidx = wasm_v128_bitselect(vidx2, vidx, vm2);

        const v128_t vm3 = wasm_f32x4_gt(vi3, vmax);
        const v128_t vidx3 = wasm_i32x4_add(vidx2, v1);
        vmax = wasm_v128_bitselect(vi3, vmax, vm3);
        vidx = wasm_v128_bitselect(vidx3, vidx, vm3);

        const v128_t vm4 = wasm_f32x4_gt(vi4, vmax);
        const v128_t vidx4 = wasm_i32x4_add(vidx3, v1);
        vmax = wasm_v128_bitselect(vi4, vmax, vm4);
        vidx = wasm_v128_bitselect(vidx4, vidx, vm4);

        const v128_t vm5 = wasm_f32x4_gt(vi5, vmax);
        const v128_t vidx5 = wasm_i32x4_add(vidx4, v1);
        vmax = wasm_v128_bitselect(vi5, vmax, vm5);
        vidx = wasm_v128_bitselect(vidx5, vidx, vm5);

        const v128_t vm6 = wasm_f32x4_gt(vi6, vmax);
        const v128_t vidx6 = wasm_i32x4_add(vidx5, v1);
        vmax = wasm_v128_bitselect(vi6, vmax, vm6);
        vidx = wasm_v128_bitselect(vidx6, vidx, vm6);

        const v128_t vm7 = wasm_f32x4_gt(vi7, vmax);
        const v128_t vidx7 = wasm_i32x4_add(vidx6, v1);
        vmax = wasm_v128_bitselect(vi7, vmax, vm7);
        vidx = wasm_v128_bitselect(vidx7, vidx, vm7);

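        // Store only the valid lanes: a 64-bit pair via an f64 lane extract,
        // then a single 32-bit lane, shifting the upper lanes down in between.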
        if (c & 2) {
          *((double*) o) = wasm_f64x2_extract_lane(vmax, 0);
          *((double*) i) = wasm_f64x2_extract_lane(vidx, 0);
          vmax = wasm_v32x4_shuffle(vmax, vmax, 2, 3, 2, 3);
          vidx = wasm_v32x4_shuffle(vidx, vidx, 2, 3, 2, 3);
          o += 2;
          i += 2;
        }
        if (c & 1) {
          *o++ = wasm_f32x4_extract_lane(vmax, 0);
          *i++ = wasm_f32x4_extract_lane(vidx, 0);
        }
      }
    }

    output = (float*) ((uintptr_t) o + output_increment);
    index = (uint32_t*) i;
  } while (--output_pixels != 0);
}