xref: /aosp_15_r20/external/XNNPACK/src/qu8-gavgpool/gen/7p7x-minmax-fp32-wasmsimd-c16.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Auto-generated file. Do not edit!
//   Template: src/qs8-gavgpool/multipass-wasmsimd.c.in
//   Generator: tools/xngen
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/gavgpool.h>
#include <xnnpack/math.h>

xnn_qu8_gavgpool_minmax_fp32_ukernel_7p7x__wasmsimd_c16(size_t rows,size_t channels,const uint8_t * input,size_t input_stride,const uint8_t * zero,int32_t * buffer,uint8_t * output,const union xnn_qu8_avgpool_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])18 void xnn_qu8_gavgpool_minmax_fp32_ukernel_7p7x__wasmsimd_c16(
19     size_t rows,
20     size_t channels,
21     const uint8_t* input,
22     size_t input_stride,
23     const uint8_t* zero,
24     int32_t* buffer,
25     uint8_t* output,
26     const union xnn_qu8_avgpool_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
27 {
28   assert(rows > 7);
29   assert(channels != 0);
30 
31   const uint8_t* i0 = input;
32   const uint8_t* i1 = (const uint8_t*) ((uintptr_t) i0 + input_stride);
33   const uint8_t* i2 = (const uint8_t*) ((uintptr_t) i1 + input_stride);
34   const uint8_t* i3 = (const uint8_t*) ((uintptr_t) i2 + input_stride);
35   const uint8_t* i4 = (const uint8_t*) ((uintptr_t) i3 + input_stride);
36   const uint8_t* i5 = (const uint8_t*) ((uintptr_t) i4 + input_stride);
37   const uint8_t* i6 = (const uint8_t*) ((uintptr_t) i5 + input_stride);
38   const size_t input_increment = 7 * input_stride - round_up_po2(channels, 16) * sizeof(uint8_t);
39 
40   const v128_t vinit_bias = wasm_v128_load64_splat(params->fp32_wasmsimd.init_bias);
41   int32_t* b = buffer;
42   size_t c = channels;
43   for (; c != 0; c = doz(c, 16)) {
44     const v128_t vxi0x01234567 = wasm_u16x8_load8x8(i0);
45     const v128_t vxi0x89ABCDEF = wasm_u16x8_load8x8(i0 + 8);
46     i0 += 16;
47     const v128_t vxi1x01234567 = wasm_u16x8_load8x8(i1);
48     const v128_t vxi1x89ABCDEF = wasm_u16x8_load8x8(i1 + 8);
49     i1 += 16;
50 
51     v128_t vacc01234567 = wasm_i16x8_add(vxi0x01234567, vxi1x01234567);
52     const v128_t vxi2x01234567 = wasm_u16x8_load8x8(i2);
53     v128_t vacc89ABCDEF = wasm_i16x8_add(vxi0x89ABCDEF, vxi1x89ABCDEF);
54     const v128_t vxi2x89ABCDEF = wasm_u16x8_load8x8(i2 + 8);
55     i2 += 16;
56 
57     vacc01234567 = wasm_i16x8_add(vacc01234567, vxi2x01234567);
58     const v128_t vxi3x01234567 = wasm_u16x8_load8x8(i3);
59     vacc89ABCDEF = wasm_i16x8_add(vacc89ABCDEF, vxi2x89ABCDEF);
60     const v128_t vxi3x89ABCDEF = wasm_u16x8_load8x8(i3 + 8);
61     i3 += 16;
62     vacc01234567 = wasm_i16x8_add(vacc01234567, vxi3x01234567);
63     const v128_t vxi4x01234567 = wasm_u16x8_load8x8(i4);
64     vacc89ABCDEF = wasm_i16x8_add(vacc89ABCDEF, vxi3x89ABCDEF);
65     const v128_t vxi4x89ABCDEF = wasm_u16x8_load8x8(i4 + 8);
66     i4 += 16;
67     vacc01234567 = wasm_i16x8_add(vacc01234567, vxi4x01234567);
68     const v128_t vxi5x01234567 = wasm_u16x8_load8x8(i5);
69     vacc89ABCDEF = wasm_i16x8_add(vacc89ABCDEF, vxi4x89ABCDEF);
70     const v128_t vxi5x89ABCDEF = wasm_u16x8_load8x8(i5 + 8);
71     i5 += 16;
72     vacc01234567 = wasm_i16x8_add(vacc01234567, vxi5x01234567);
73     const v128_t vxi6x01234567 = wasm_u16x8_load8x8(i6);
74     vacc89ABCDEF = wasm_i16x8_add(vacc89ABCDEF, vxi5x89ABCDEF);
75     const v128_t vxi6x89ABCDEF = wasm_u16x8_load8x8(i6 + 8);
76     i6 += 16;
77 
78     vacc01234567 = wasm_i16x8_add(vacc01234567, vxi6x01234567);
79     vacc89ABCDEF = wasm_i16x8_add(vacc89ABCDEF, vxi6x89ABCDEF);
80 
81     const v128_t vacc0123 = wasm_i32x4_add(vinit_bias, wasm_u32x4_extend_low_u16x8(vacc01234567));
82     const v128_t vacc4567 = wasm_i32x4_add(vinit_bias, wasm_u32x4_extend_high_u16x8(vacc01234567));
83     const v128_t vacc89AB = wasm_i32x4_add(vinit_bias, wasm_u32x4_extend_low_u16x8(vacc89ABCDEF));
84     const v128_t vaccCDEF = wasm_i32x4_add(vinit_bias, wasm_u32x4_extend_high_u16x8(vacc89ABCDEF));
85 
86     wasm_v128_store(b, vacc0123);
87     wasm_v128_store(b + 4, vacc4567);
88     wasm_v128_store(b + 8, vacc89AB);
89     wasm_v128_store(b + 12, vaccCDEF);
90     b += 16;
91   }
92 
93   for (rows -= 7; rows > 7; rows -= 7) {
94     i0 = (const uint8_t*) ((uintptr_t) i0 + input_increment);
95     i1 = (const uint8_t*) ((uintptr_t) i1 + input_increment);
96     i2 = (const uint8_t*) ((uintptr_t) i2 + input_increment);
97     i3 = (const uint8_t*) ((uintptr_t) i3 + input_increment);
98     i4 = (const uint8_t*) ((uintptr_t) i4 + input_increment);
99     i5 = (const uint8_t*) ((uintptr_t) i5 + input_increment);
100     i6 = (const uint8_t*) ((uintptr_t) i6 + input_increment);
101 
102     int32_t* b = buffer;
103     size_t c = channels;
104     for (; c != 0; c = doz(c, 16)) {
105       const v128_t vxi0x01234567 = wasm_u16x8_load8x8(i0);
106       const v128_t vxi0x89ABCDEF = wasm_u16x8_load8x8(i0 + 8);
107       i0 += 16;
108       const v128_t vxi1x01234567 = wasm_u16x8_load8x8(i1);
109       const v128_t vxi1x89ABCDEF = wasm_u16x8_load8x8(i1 + 8);
110       i1 += 16;
111 
112       v128_t vacc01234567 = wasm_i16x8_add(vxi0x01234567, vxi1x01234567);
113       const v128_t vxi2x01234567 = wasm_u16x8_load8x8(i2);
114       v128_t vacc89ABCDEF = wasm_i16x8_add(vxi0x89ABCDEF, vxi1x89ABCDEF);
115       const v128_t vxi2x89ABCDEF = wasm_u16x8_load8x8(i2 + 8);
116       i2 += 16;
117 
118       vacc01234567 = wasm_i16x8_add(vacc01234567, vxi2x01234567);
119       const v128_t vxi3x01234567 = wasm_u16x8_load8x8(i3);
120       vacc89ABCDEF = wasm_i16x8_add(vacc89ABCDEF, vxi2x89ABCDEF);
121       const v128_t vxi3x89ABCDEF = wasm_u16x8_load8x8(i3 + 8);
122       i3 += 16;
123       vacc01234567 = wasm_i16x8_add(vacc01234567, vxi3x01234567);
124       const v128_t vxi4x01234567 = wasm_u16x8_load8x8(i4);
125       vacc89ABCDEF = wasm_i16x8_add(vacc89ABCDEF, vxi3x89ABCDEF);
126       const v128_t vxi4x89ABCDEF = wasm_u16x8_load8x8(i4 + 8);
127       i4 += 16;
128       vacc01234567 = wasm_i16x8_add(vacc01234567, vxi4x01234567);
129       const v128_t vxi5x01234567 = wasm_u16x8_load8x8(i5);
130       vacc89ABCDEF = wasm_i16x8_add(vacc89ABCDEF, vxi4x89ABCDEF);
131       const v128_t vxi5x89ABCDEF = wasm_u16x8_load8x8(i5 + 8);
132       i5 += 16;
133       vacc01234567 = wasm_i16x8_add(vacc01234567, vxi5x01234567);
134       const v128_t vxi6x01234567 = wasm_u16x8_load8x8(i6);
135       vacc89ABCDEF = wasm_i16x8_add(vacc89ABCDEF, vxi5x89ABCDEF);
136       const v128_t vxi6x89ABCDEF = wasm_u16x8_load8x8(i6 + 8);
137       i6 += 16;
138 
139       vacc01234567 = wasm_i16x8_add(vacc01234567, vxi6x01234567);
140       vacc89ABCDEF = wasm_i16x8_add(vacc89ABCDEF, vxi6x89ABCDEF);
141 
142       v128_t vacc0123 = wasm_v128_load(b);
143       v128_t vacc4567 = wasm_v128_load(b + 4);
144       v128_t vacc89AB = wasm_v128_load(b + 8);
145       v128_t vaccCDEF = wasm_v128_load(b + 12);
146 
147       vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vacc01234567));
148       vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vacc01234567));
149       vacc89AB = wasm_i32x4_add(vacc89AB, wasm_u32x4_extend_low_u16x8(vacc89ABCDEF));
150       vaccCDEF = wasm_i32x4_add(vaccCDEF, wasm_u32x4_extend_high_u16x8(vacc89ABCDEF));
151 
152       wasm_v128_store(b, vacc0123);
153       wasm_v128_store(b + 4, vacc4567);
154       wasm_v128_store(b + 8, vacc89AB);
155       wasm_v128_store(b + 12, vaccCDEF);
156       b += 16;
157     }
158   }
159 
160   i0 = (const uint8_t*) ((uintptr_t) i0 + input_increment);
161   i1 = (const uint8_t*) ((uintptr_t) i1 + input_increment);
162   if XNN_UNPREDICTABLE(rows < 2) {
163     i1 = zero;
164   }
165   i2 = (const uint8_t*) ((uintptr_t) i2 + input_increment);
166   if XNN_UNPREDICTABLE(rows <= 2) {
167     i2 = zero;
168   }
169   i3 = (const uint8_t*) ((uintptr_t) i3 + input_increment);
170   if XNN_UNPREDICTABLE(rows < 4) {
171     i3 = zero;
172   }
173   i4 = (const uint8_t*) ((uintptr_t) i4 + input_increment);
174   if XNN_UNPREDICTABLE(rows <= 4) {
175     i4 = zero;
176   }
177   i5 = (const uint8_t*) ((uintptr_t) i5 + input_increment);
178   if XNN_UNPREDICTABLE(rows < 6) {
179     i5 = zero;
180   }
181   i6 = (const uint8_t*) ((uintptr_t) i6 + input_increment);
182   if XNN_UNPREDICTABLE(rows <= 6) {
183     i6 = zero;
184   }
185 
186   const v128_t vscale = wasm_v128_load64_splat(params->fp32_wasmsimd.scale);
187   const v128_t vmagic_bias = wasm_v128_load64_splat(params->fp32_wasmsimd.magic_bias);
188   const v128_t vmagic_min = wasm_v128_load64_splat(params->fp32_wasmsimd.magic_min);
189   const v128_t vmagic_bias_less_output_zero_point = wasm_v128_load64_splat(params->fp32_wasmsimd.magic_bias_less_output_zero_point);
190   const v128_t voutput_max = wasm_v128_load64_splat(params->fp32_wasmsimd.output_max);
191   for (; channels >= 16; channels -= 16) {
192     const v128_t vxi0x01234567 = wasm_u16x8_load8x8(i0);
193     const v128_t vxi0x89ABCDEF = wasm_u16x8_load8x8(i0 + 8);
194     i0 += 16;
195     const v128_t vxi1x01234567 = wasm_u16x8_load8x8(i1);
196     const v128_t vxi1x89ABCDEF = wasm_u16x8_load8x8(i1 + 8);
197     i1 += 16;
198 
199     v128_t vacc01234567 = wasm_i16x8_add(vxi0x01234567, vxi1x01234567);
200     const v128_t vxi2x01234567 = wasm_u16x8_load8x8(i2);
201     v128_t vacc89ABCDEF = wasm_i16x8_add(vxi0x89ABCDEF, vxi1x89ABCDEF);
202     const v128_t vxi2x89ABCDEF = wasm_u16x8_load8x8(i2 + 8);
203     i2 += 16;
204 
205     vacc01234567 = wasm_i16x8_add(vacc01234567, vxi2x01234567);
206     const v128_t vxi3x01234567 = wasm_u16x8_load8x8(i3);
207     vacc89ABCDEF = wasm_i16x8_add(vacc89ABCDEF, vxi2x89ABCDEF);
208     const v128_t vxi3x89ABCDEF = wasm_u16x8_load8x8(i3 + 8);
209     i3 += 16;
210     vacc01234567 = wasm_i16x8_add(vacc01234567, vxi3x01234567);
211     const v128_t vxi4x01234567 = wasm_u16x8_load8x8(i4);
212     vacc89ABCDEF = wasm_i16x8_add(vacc89ABCDEF, vxi3x89ABCDEF);
213     const v128_t vxi4x89ABCDEF = wasm_u16x8_load8x8(i4 + 8);
214     i4 += 16;
215     vacc01234567 = wasm_i16x8_add(vacc01234567, vxi4x01234567);
216     const v128_t vxi5x01234567 = wasm_u16x8_load8x8(i5);
217     vacc89ABCDEF = wasm_i16x8_add(vacc89ABCDEF, vxi4x89ABCDEF);
218     const v128_t vxi5x89ABCDEF = wasm_u16x8_load8x8(i5 + 8);
219     i5 += 16;
220     vacc01234567 = wasm_i16x8_add(vacc01234567, vxi5x01234567);
221     const v128_t vxi6x01234567 = wasm_u16x8_load8x8(i6);
222     vacc89ABCDEF = wasm_i16x8_add(vacc89ABCDEF, vxi5x89ABCDEF);
223     const v128_t vxi6x89ABCDEF = wasm_u16x8_load8x8(i6 + 8);
224     i6 += 16;
225 
226     vacc01234567 = wasm_i16x8_add(vacc01234567, vxi6x01234567);
227     vacc89ABCDEF = wasm_i16x8_add(vacc89ABCDEF, vxi6x89ABCDEF);
228 
229     v128_t vacc0123 = wasm_v128_load(buffer);
230     v128_t vacc4567 = wasm_v128_load(buffer + 4);
231     v128_t vacc89AB = wasm_v128_load(buffer + 8);
232     v128_t vaccCDEF = wasm_v128_load(buffer + 12);
233     buffer += 16;
234 
235     vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vacc01234567));
236     vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vacc01234567));
237     vacc89AB = wasm_i32x4_add(vacc89AB, wasm_u32x4_extend_low_u16x8(vacc89ABCDEF));
238     vaccCDEF = wasm_i32x4_add(vaccCDEF, wasm_u32x4_extend_high_u16x8(vacc89ABCDEF));
239 
240     vacc0123 = wasm_f32x4_convert_i32x4(vacc0123);
241     vacc4567 = wasm_f32x4_convert_i32x4(vacc4567);
242     vacc89AB = wasm_f32x4_convert_i32x4(vacc89AB);
243     vaccCDEF = wasm_f32x4_convert_i32x4(vaccCDEF);
244 
245     vacc0123 = wasm_f32x4_mul(vacc0123, vscale);
246     vacc4567 = wasm_f32x4_mul(vacc4567, vscale);
247     vacc89AB = wasm_f32x4_mul(vacc89AB, vscale);
248     vaccCDEF = wasm_f32x4_mul(vaccCDEF, vscale);
249 
250     vacc0123 = wasm_f32x4_add(vacc0123, vmagic_bias);
251     vacc4567 = wasm_f32x4_add(vacc4567, vmagic_bias);
252     vacc89AB = wasm_f32x4_add(vacc89AB, vmagic_bias);
253     vaccCDEF = wasm_f32x4_add(vaccCDEF, vmagic_bias);
254 
255     vacc0123 = wasm_i32x4_max(vacc0123, vmagic_min);
256     vacc4567 = wasm_i32x4_max(vacc4567, vmagic_min);
257     vacc89AB = wasm_i32x4_max(vacc89AB, vmagic_min);
258     vaccCDEF = wasm_i32x4_max(vaccCDEF, vmagic_min);
259 
260     vacc0123 = wasm_i32x4_sub(vacc0123, vmagic_bias_less_output_zero_point);
261     vacc4567 = wasm_i32x4_sub(vacc4567, vmagic_bias_less_output_zero_point);
262     vacc89AB = wasm_i32x4_sub(vacc89AB, vmagic_bias_less_output_zero_point);
263     vaccCDEF = wasm_i32x4_sub(vaccCDEF, vmagic_bias_less_output_zero_point);
264 
265     v128_t vout01234567 = wasm_i16x8_narrow_i32x4(vacc0123, vacc4567);
266     v128_t vout89ABCDEF = wasm_i16x8_narrow_i32x4(vacc89AB, vaccCDEF);
267 
268     v128_t vout0123456789ABCDEF = wasm_u8x16_narrow_i16x8(vout01234567, vout89ABCDEF);
269 
270     vout0123456789ABCDEF = wasm_u8x16_min(vout0123456789ABCDEF, voutput_max);
271 
272     wasm_v128_store(output, vout0123456789ABCDEF);
273     output += 16;
274   }
275   if XNN_UNLIKELY(channels != 0) {
276     do {
277       const v128_t vxi0x01234567 = wasm_u16x8_load8x8(i0);
278       i0 += 8;
279       const v128_t vxi1x01234567 = wasm_u16x8_load8x8(i1);
280       i1 += 8;
281 
282       v128_t vacc01234567 = wasm_i16x8_add(vxi0x01234567, vxi1x01234567);
283       const v128_t vxi2x01234567 = wasm_u16x8_load8x8(i2);
284       i2 += 8;
285 
286       vacc01234567 = wasm_i16x8_add(vacc01234567, vxi2x01234567);
287       const v128_t vxi3x01234567 = wasm_u16x8_load8x8(i3);
288       i3 += 8;
289       vacc01234567 = wasm_i16x8_add(vacc01234567, vxi3x01234567);
290       const v128_t vxi4x01234567 = wasm_u16x8_load8x8(i4);
291       i4 += 8;
292       vacc01234567 = wasm_i16x8_add(vacc01234567, vxi4x01234567);
293       const v128_t vxi5x01234567 = wasm_u16x8_load8x8(i5);
294       i5 += 8;
295       vacc01234567 = wasm_i16x8_add(vacc01234567, vxi5x01234567);
296       const v128_t vxi6x01234567 = wasm_u16x8_load8x8(i6);
297       i6 += 8;
298 
299       vacc01234567 = wasm_i16x8_add(vacc01234567, vxi6x01234567);
300 
301       v128_t vacc0123 = wasm_v128_load(buffer);
302       v128_t vacc4567 = wasm_v128_load(buffer + 4);
303       buffer += 8;
304 
305       vacc0123 = wasm_i32x4_add(vacc0123, wasm_u32x4_extend_low_u16x8(vacc01234567));
306       vacc4567 = wasm_i32x4_add(vacc4567, wasm_u32x4_extend_high_u16x8(vacc01234567));
307 
308       vacc0123 = wasm_f32x4_convert_i32x4(vacc0123);
309       vacc4567 = wasm_f32x4_convert_i32x4(vacc4567);
310 
311       vacc0123 = wasm_f32x4_mul(vacc0123, vscale);
312       vacc4567 = wasm_f32x4_mul(vacc4567, vscale);
313 
314       vacc0123 = wasm_f32x4_add(vacc0123, vmagic_bias);
315       vacc4567 = wasm_f32x4_add(vacc4567, vmagic_bias);
316 
317       vacc0123 = wasm_i32x4_max(vacc0123, vmagic_min);
318       vacc4567 = wasm_i32x4_max(vacc4567, vmagic_min);
319 
320       vacc0123 = wasm_i32x4_sub(vacc0123, vmagic_bias_less_output_zero_point);
321       vacc4567 = wasm_i32x4_sub(vacc4567, vmagic_bias_less_output_zero_point);
322 
323       const v128_t vout01234567 = wasm_i16x8_narrow_i32x4(vacc0123, vacc4567);
324       v128_t vout0123456701234567 = wasm_u8x16_narrow_i16x8(vout01234567, vout01234567);
325       vout0123456701234567 = wasm_u8x16_min(vout0123456701234567, voutput_max);
326 
327       if XNN_LIKELY(channels >= 8) {
328         *((double*) output) = wasm_f64x2_extract_lane(vout0123456701234567, 0);
329         output += 8;
330         channels -= 8;
331       } else {
332         if (channels & 4) {
333           *((float*) output) = wasm_f32x4_extract_lane(vout0123456701234567, 0);
334           vout0123456701234567 = wasm_u64x2_shr(vout0123456701234567, 32);
335           output += 4;
336         }
337         uint32_t vout0123 = wasm_i32x4_extract_lane(vout0123456701234567, 0);
338         if (channels & 2) {
339           *((uint16_t*) output) = (uint16_t) vout0123;
340           vout0123 >>= 16;
341           output += 2;
342         }
343         if (channels & 1) {
344           *output = (uint8_t) vout0123;
345           output += 1;
346         }
347         channels = 0;
348       }
349     } while (channels != 0);
350   }
351 }
352