// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/gavgpool.h>
#include <xnnpack/math.h>

// f32 channel-wise global average pooling micro-kernel (NEON, 4 channels per
// outer iteration).
//
// Each channel occupies a contiguous run of `elements` bytes of float data;
// channel c begins at input + c * elements (as bytes).  The kernel reduces
// every run to a single value: sum * params->neon.multiplier, clamped to
// [output_min, output_max], and stores one float per channel to `output`.
//
//   elements - byte size of one channel's data; non-zero multiple of sizeof(float)
//   channels - number of channels to reduce; non-zero
//   input    - base pointer of channel 0
//   output   - destination, one float per channel
//   params   - lane mask for the partial tail vector, the reciprocal-count
//              multiplier, and the clamping bounds
//
// Tagged XNN_OOB_READS: the tail handling loads a full 16-byte vector that may
// extend past the channel's valid data; the lane mask then zeroes the
// out-of-range lanes, so the extra bytes never affect the result.
void xnn_f32_gavgpool_cw_ukernel__neon_x4(
    size_t elements,
    size_t channels,
    const float* input,
    float* output,
    const union xnn_f32_gavgpool_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(elements != 0);
  assert(elements % sizeof(float) == 0);
  assert(channels != 0);

  // Four read cursors, one per channel of the current group.
  const float* row0 = input;
  const float* row1 = (const float*) ((uintptr_t) row0 + elements);
  const float* row2 = (const float*) ((uintptr_t) row1 + elements);
  const float* row3 = (const float*) ((uintptr_t) row2 + elements);

  const uint32x4_t vmask = vld1q_u32(params->neon.mask);
  const float32x4_t vscale = vld1q_dup_f32(&params->neon.multiplier);
  const float32x4_t vout_min = vld1q_dup_f32(&params->neon.output_min);
  const float32x4_t vout_max = vld1q_dup_f32(&params->neon.output_max);

  // Main path: reduce four channels per iteration.
  while (channels >= 4) {
    float32x4_t vacc0 = vmovq_n_f32(0.0f);
    float32x4_t vacc1 = vmovq_n_f32(0.0f);
    float32x4_t vacc2 = vmovq_n_f32(0.0f);
    float32x4_t vacc3 = vmovq_n_f32(0.0f);

    size_t remaining = elements;
    for (; remaining >= 4 * sizeof(float); remaining -= 4 * sizeof(float)) {
      const float32x4_t vx0 = vld1q_f32(row0); row0 += 4;
      const float32x4_t vx1 = vld1q_f32(row1); row1 += 4;
      const float32x4_t vx2 = vld1q_f32(row2); row2 += 4;
      const float32x4_t vx3 = vld1q_f32(row3); row3 += 4;

      vacc0 = vaddq_f32(vacc0, vx0);
      vacc1 = vaddq_f32(vacc1, vx1);
      vacc2 = vaddq_f32(vacc2, vx2);
      vacc3 = vaddq_f32(vacc3, vx3);
    }

    if XNN_UNLIKELY(remaining != 0) {
      // Partial tail vector: load a full vector (may read past the channel
      // end — covered by XNN_OOB_READS) and zero the invalid lanes.
      float32x4_t vx0 = vld1q_f32(row0); row0 = (const float*) ((uintptr_t) row0 + remaining);
      float32x4_t vx1 = vld1q_f32(row1); row1 = (const float*) ((uintptr_t) row1 + remaining);
      float32x4_t vx2 = vld1q_f32(row2); row2 = (const float*) ((uintptr_t) row2 + remaining);
      float32x4_t vx3 = vld1q_f32(row3); row3 = (const float*) ((uintptr_t) row3 + remaining);

      vx0 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vx0)));
      vx1 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vx1)));
      vx2 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vx2)));
      vx3 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vx3)));

      vacc0 = vaddq_f32(vacc0, vx0);
      vacc1 = vaddq_f32(vacc1, vx1);
      vacc2 = vaddq_f32(vacc2, vx2);
      vacc3 = vaddq_f32(vacc3, vx3);
    }

    // Horizontal reduction: with exactly 4 accumulators the per-channel totals
    // land in the 4 lanes of a single vector.
#if XNN_ARCH_ARM64
    const float32x4_t vpair01 = vpaddq_f32(vacc0, vacc1);
    const float32x4_t vpair23 = vpaddq_f32(vacc2, vacc3);
    const float32x4_t vtotals = vpaddq_f32(vpair01, vpair23);
#else
    // AArch32 has no 128-bit pairwise add; fold halves with 64-bit ops.
    const float32x4_t vpair01 = vcombine_f32(vadd_f32(vget_low_f32(vacc0), vget_high_f32(vacc0)),
                                             vadd_f32(vget_low_f32(vacc1), vget_high_f32(vacc1)));
    const float32x4_t vpair23 = vcombine_f32(vadd_f32(vget_low_f32(vacc2), vget_high_f32(vacc2)),
                                             vadd_f32(vget_low_f32(vacc3), vget_high_f32(vacc3)));
    const float32x4_t vtotals = vcombine_f32(vpadd_f32(vget_low_f32(vpair01), vget_high_f32(vpair01)),
                                             vpadd_f32(vget_low_f32(vpair23), vget_high_f32(vpair23)));
#endif

    // Average (multiply by reciprocal count), then clamp.
    float32x4_t vresult = vmulq_f32(vtotals, vscale);
    vresult = vmaxq_f32(vresult, vout_min);
    vresult = vminq_f32(vresult, vout_max);

    vst1q_f32(output, vresult); output += 4;

    // Advance to the next group of four channels: row3 already stopped at the
    // start of the next channel's data.
    row0 = row3;
    row1 = (const float*) ((uintptr_t) row0 + elements);
    row2 = (const float*) ((uintptr_t) row1 + elements);
    row3 = (const float*) ((uintptr_t) row2 + elements);
    channels -= 4;
  }

  // Remainder path: one channel at a time.
  while (channels != 0) {
    float32x4_t vacc = vmovq_n_f32(0.0f);

    size_t remaining = elements;
    for (; remaining >= 4 * sizeof(float); remaining -= 4 * sizeof(float)) {
      const float32x4_t vx = vld1q_f32(row0); row0 += 4;
      vacc = vaddq_f32(vacc, vx);
    }

    if XNN_UNLIKELY(remaining != 0) {
      // Masked partial tail load, as in the main path.
      float32x4_t vx = vld1q_f32(row0); row0 = (const float*) ((uintptr_t) row0 + remaining);
      vx = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vx)));
      vacc = vaddq_f32(vacc, vx);
    }

    // Reduce 4 lanes to a scalar total in lane 0.
    float32x2_t vtotal = vadd_f32(vget_low_f32(vacc), vget_high_f32(vacc));
    vtotal = vpadd_f32(vtotal, vtotal);

    float32x2_t vresult = vmul_f32(vtotal, vget_low_f32(vscale));
    vresult = vmax_f32(vresult, vget_low_f32(vout_min));
    vresult = vmin_f32(vresult, vget_low_f32(vout_max));

    vst1_lane_f32(output, vresult, 0); output += 1;
    channels -= 1;
  }
}
126