// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/gavgpool.h>
#include <xnnpack/math.h>


xnn_f32_gavgpool_cw_ukernel__neon_x4(size_t elements,size_t channels,const float * input,float * output,const union xnn_f32_gavgpool_params params[restrict XNN_MIN_ELEMENTS (1)])14 void xnn_f32_gavgpool_cw_ukernel__neon_x4(
15 size_t elements,
16 size_t channels,
17 const float* input,
18 float* output,
19 const union xnn_f32_gavgpool_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
20 {
21 assert(elements != 0);
22 assert(elements % sizeof(float) == 0);
23 assert(channels != 0);
24
25 const float* i0 = input;
26 const float* i1 = (const float*) ((uintptr_t) i0 + elements);
27 const float* i2 = (const float*) ((uintptr_t) i1 + elements);
28 const float* i3 = (const float*) ((uintptr_t) i2 + elements);
29
30 const uint32x4_t vmask = vld1q_u32(params->neon.mask);
31 const float32x4_t vmultiplier = vld1q_dup_f32(¶ms->neon.multiplier);
32 const float32x4_t voutput_min = vld1q_dup_f32(¶ms->neon.output_min);
33 const float32x4_t voutput_max = vld1q_dup_f32(¶ms->neon.output_max);
34
35 while (channels >= 4) {
36 float32x4_t vsum0 = vmovq_n_f32(0.0f);
37 float32x4_t vsum1 = vmovq_n_f32(0.0f);
38 float32x4_t vsum2 = vmovq_n_f32(0.0f);
39 float32x4_t vsum3 = vmovq_n_f32(0.0f);
40 size_t n = elements;
41 while (n >= 4 * sizeof(float)) {
42 const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
43 const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
44 const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
45 const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
46
47 vsum0 = vaddq_f32(vsum0, vi0);
48 vsum1 = vaddq_f32(vsum1, vi1);
49 vsum2 = vaddq_f32(vsum2, vi2);
50 vsum3 = vaddq_f32(vsum3, vi3);
51 n -= 4 * sizeof(float);
52 }
53
54 if XNN_UNLIKELY(n != 0) {
55 float32x4_t vi0 = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + n);
56 float32x4_t vi1 = vld1q_f32(i1); i1 = (const float*) ((uintptr_t) i1 + n);
57 float32x4_t vi2 = vld1q_f32(i2); i2 = (const float*) ((uintptr_t) i2 + n);
58 float32x4_t vi3 = vld1q_f32(i3); i3 = (const float*) ((uintptr_t) i3 + n);
59
60 vi0 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi0)));
61 vi1 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi1)));
62 vi2 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi2)));
63 vi3 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi3)));
64
65 vsum0 = vaddq_f32(vsum0, vi0);
66 vsum1 = vaddq_f32(vsum1, vi1);
67 vsum2 = vaddq_f32(vsum2, vi2);
68 vsum3 = vaddq_f32(vsum3, vi3);
69 }
70
71 // Having exactly 4 rows makes this work out nicely as we end up with
72 // the 4 totals in 4 different lanes of the same vector.
73 #if XNN_ARCH_ARM64
74 const float32x4_t vsum01 = vpaddq_f32(vsum0, vsum1);
75 const float32x4_t vsum23 = vpaddq_f32(vsum2, vsum3);
76 const float32x4_t vsum = vpaddq_f32(vsum01, vsum23);
77 #else
78 const float32x4_t vsum01 = vcombine_f32(vadd_f32(vget_low_f32(vsum0), vget_high_f32(vsum0)),
79 vadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)));
80 const float32x4_t vsum23 = vcombine_f32(vadd_f32(vget_low_f32(vsum2), vget_high_f32(vsum2)),
81 vadd_f32(vget_low_f32(vsum3), vget_high_f32(vsum3)));
82 const float32x4_t vsum = vcombine_f32(vpadd_f32(vget_low_f32(vsum01), vget_high_f32(vsum01)),
83 vpadd_f32(vget_low_f32(vsum23), vget_high_f32(vsum23)));
84 #endif
85
86 float32x4_t vout = vmulq_f32(vsum, vmultiplier);
87
88 vout = vmaxq_f32(vout, voutput_min);
89 vout = vminq_f32(vout, voutput_max);
90
91 vst1q_f32(output, vout); output += 4;
92 i0 = i3;
93 i1 = (const float*) ((uintptr_t) i0 + elements);
94 i2 = (const float*) ((uintptr_t) i1 + elements);
95 i3 = (const float*) ((uintptr_t) i2 + elements);
96 channels -= 4;
97 }
98
99 while (channels != 0) {
100 float32x4_t vsum0 = vmovq_n_f32(0.0f);
101 size_t n = elements;
102 while (n >= 4 * sizeof(float)) {
103 const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
104 vsum0 = vaddq_f32(vsum0, vi0);
105 n -= 4 * sizeof(float);
106 }
107
108 if XNN_UNLIKELY(n != 0) {
109 float32x4_t vi0 = vld1q_f32(i0); i0 = (const float*) ((uintptr_t) i0 + n);
110 vi0 = vreinterpretq_f32_u32(vandq_u32(vmask, vreinterpretq_u32_f32(vi0)));
111 vsum0 = vaddq_f32(vsum0, vi0);
112 }
113
114 float32x2_t vsum = vadd_f32(vget_low_f32(vsum0), vget_high_f32(vsum0));
115 vsum = vpadd_f32(vsum, vsum);
116
117 float32x2_t vout = vmul_f32(vsum, vget_low_f32(vmultiplier));
118
119 vout = vmax_f32(vout, vget_low_f32(voutput_min));
120 vout = vmin_f32(vout, vget_low_f32(voutput_max));
121
122 vst1_lane_f32(output, vout, 0); output += 1;
123 channels -= 1;
124 }
125 }