xref: /aosp_15_r20/external/XNNPACK/src/f32-gavgpool/7p7x-minmax-sse-c4.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 
8 #include <xmmintrin.h>
9 
10 #include <xnnpack/gavgpool.h>
11 #include <xnnpack/math.h>
12 
13 
xnn_f32_gavgpool_minmax_ukernel_7p7x__sse_c4(size_t rows,size_t channels,const float * input,size_t input_stride,const float * zero,float * buffer,float * output,const union xnn_f32_scaleminmax_params params[restrict XNN_MIN_ELEMENTS (1)])14 void xnn_f32_gavgpool_minmax_ukernel_7p7x__sse_c4(
15     size_t rows,
16     size_t channels,
17     const float* input,
18     size_t input_stride,
19     const float* zero,
20     float* buffer,
21     float* output,
22     const union xnn_f32_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
23 {
24   assert(rows > 7);
25   assert(channels != 0);
26 
27   const float* i0 = input;
28   const float* i1 = (const float*) ((uintptr_t) i0 + input_stride);
29   const float* i2 = (const float*) ((uintptr_t) i1 + input_stride);
30   const float* i3 = (const float*) ((uintptr_t) i2 + input_stride);
31   const float* i4 = (const float*) ((uintptr_t) i3 + input_stride);
32   const float* i5 = (const float*) ((uintptr_t) i4 + input_stride);
33   const float* i6 = (const float*) ((uintptr_t) i5 + input_stride);
34   const size_t packed_channels = round_up_po2(channels, 4);
35   const size_t input_increment = 7 * input_stride - packed_channels * sizeof(float);
36 
37   float* b = buffer;
38   for (size_t c = 0; c < channels; c += 4) {
39     const __m128 vi0 = _mm_loadu_ps(i0);
40     i0 += 4;
41     const __m128 vi1 = _mm_loadu_ps(i1);
42     i1 += 4;
43     const __m128 vi2 = _mm_loadu_ps(i2);
44     i2 += 4;
45     const __m128 vi3 = _mm_loadu_ps(i3);
46     i3 += 4;
47     const __m128 vi4 = _mm_loadu_ps(i4);
48     i4 += 4;
49     const __m128 vi5 = _mm_loadu_ps(i5);
50     i5 += 4;
51     const __m128 vi6 = _mm_loadu_ps(i6);
52     i6 += 4;
53 
54     const __m128 vsum01 = _mm_add_ps(vi0, vi1);
55     const __m128 vsum23 = _mm_add_ps(vi2, vi3);
56     const __m128 vsum45 = _mm_add_ps(vi4, vi5);
57 
58     const __m128 vsum016 = _mm_add_ps(vsum01, vi6);
59     const __m128 vsum2345 = _mm_add_ps(vsum23, vsum45);
60 
61     const __m128 vsum = _mm_add_ps(vsum016, vsum2345);
62 
63     _mm_store_ps(b, vsum); b += 4;
64   }
65   for (rows -= 7; rows > 7; rows -= 7) {
66     b = buffer;
67 
68     i0 = (const float*) ((uintptr_t) i0 + input_increment);
69     i1 = (const float*) ((uintptr_t) i1 + input_increment);
70     i2 = (const float*) ((uintptr_t) i2 + input_increment);
71     i3 = (const float*) ((uintptr_t) i3 + input_increment);
72     i4 = (const float*) ((uintptr_t) i4 + input_increment);
73     i5 = (const float*) ((uintptr_t) i5 + input_increment);
74     i6 = (const float*) ((uintptr_t) i6 + input_increment);
75 
76     for (size_t c = 0; c < channels; c += 4) {
77       const __m128 vi0 = _mm_loadu_ps(i0);
78       i0 += 4;
79       const __m128 vi1 = _mm_loadu_ps(i1);
80       i1 += 4;
81       const __m128 vi2 = _mm_loadu_ps(i2);
82       i2 += 4;
83       const __m128 vi3 = _mm_loadu_ps(i3);
84       i3 += 4;
85       const __m128 vi4 = _mm_loadu_ps(i4);
86       i4 += 4;
87       const __m128 vi5 = _mm_loadu_ps(i5);
88       i5 += 4;
89       const __m128 vi6 = _mm_loadu_ps(i6);
90       i6 += 4;
91       const __m128 vacc = _mm_load_ps(b);
92 
93       const __m128 vsum01 = _mm_add_ps(vi0, vi1);
94       const __m128 vsum23 = _mm_add_ps(vi2, vi3);
95       const __m128 vsum45 = _mm_add_ps(vi4, vi5);
96       const __m128 vsum6a = _mm_add_ps(vi6, vacc);
97 
98       const __m128 vsum0123 = _mm_add_ps(vsum01, vsum23);
99       const __m128 vsum456a = _mm_add_ps(vsum45, vsum6a);
100 
101       const __m128 vsum = _mm_add_ps(vsum0123, vsum456a);
102 
103       _mm_store_ps(b, vsum); b += 4;
104     }
105   }
106 
107   i0 = (const float*) ((uintptr_t) i0 + input_increment);
108   i1 = (const float*) ((uintptr_t) i1 + input_increment);
109   if (rows < 2) {
110     i1 = zero;
111   }
112   i2 = (const float*) ((uintptr_t) i2 + input_increment);
113   if (rows <= 2) {
114     i2 = zero;
115   }
116   i3 = (const float*) ((uintptr_t) i3 + input_increment);
117   if (rows < 4) {
118     i3 = zero;
119   }
120   i4 = (const float*) ((uintptr_t) i4 + input_increment);
121   if (rows <= 4) {
122     i4 = zero;
123   }
124   i5 = (const float*) ((uintptr_t) i5 + input_increment);
125   if (rows < 6) {
126     i5 = zero;
127   }
128   i6 = (const float*) ((uintptr_t) i6 + input_increment);
129   if (rows <= 6) {
130     i6 = zero;
131   }
132   const __m128 vscale = _mm_load_ps(params->sse.scale);
133   const __m128 vmin = _mm_load_ps(params->sse.min);
134   const __m128 vmax = _mm_load_ps(params->sse.max);
135 
136   b = buffer;
137   while (channels >= 4) {
138     const __m128 vi0 = _mm_loadu_ps(i0);
139     i0 += 4;
140     const __m128 vi1 = _mm_loadu_ps(i1);
141     i1 += 4;
142     const __m128 vi2 = _mm_loadu_ps(i2);
143     i2 += 4;
144     const __m128 vi3 = _mm_loadu_ps(i3);
145     i3 += 4;
146     const __m128 vi4 = _mm_loadu_ps(i4);
147     i4 += 4;
148     const __m128 vi5 = _mm_loadu_ps(i5);
149     i5 += 4;
150     const __m128 vi6 = _mm_loadu_ps(i6);
151     i6 += 4;
152     const __m128 vacc = _mm_load_ps(b);
153     b += 4;
154 
155     const __m128 vsum01 = _mm_add_ps(vi0, vi1);
156     const __m128 vsum23 = _mm_add_ps(vi2, vi3);
157     const __m128 vsum45 = _mm_add_ps(vi4, vi5);
158     const __m128 vsum6a = _mm_add_ps(vi6, vacc);
159 
160     const __m128 vsum0123 = _mm_add_ps(vsum01, vsum23);
161     const __m128 vsum456a = _mm_add_ps(vsum45, vsum6a);
162 
163     const __m128 vsum = _mm_add_ps(vsum0123, vsum456a);
164 
165     __m128 vout = _mm_mul_ps(vsum, vscale);
166     vout = _mm_max_ps(vout, vmin);
167     vout = _mm_min_ps(vout, vmax);
168 
169     _mm_storeu_ps(output, vout);
170     output += 4;
171 
172     channels -= 4;
173   }
174   if (channels != 0) {
175     const __m128 vi0 = _mm_loadu_ps(i0);
176     const __m128 vi1 = _mm_loadu_ps(i1);
177     const __m128 vi2 = _mm_loadu_ps(i2);
178     const __m128 vi3 = _mm_loadu_ps(i3);
179     const __m128 vi4 = _mm_loadu_ps(i4);
180     const __m128 vi5 = _mm_loadu_ps(i5);
181     const __m128 vi6 = _mm_loadu_ps(i6);
182     const __m128 vacc = _mm_loadu_ps(b);
183 
184     const __m128 vsum01 = _mm_add_ps(vi0, vi1);
185     const __m128 vsum23 = _mm_add_ps(vi2, vi3);
186     const __m128 vsum45 = _mm_add_ps(vi4, vi5);
187     const __m128 vsum6a = _mm_add_ps(vi6, vacc);
188 
189     const __m128 vsum0123 = _mm_add_ps(vsum01, vsum23);
190     const __m128 vsum456a = _mm_add_ps(vsum45, vsum6a);
191 
192     const __m128 vsum = _mm_add_ps(vsum0123, vsum456a);
193 
194     __m128 vout = _mm_mul_ps(vsum, vscale);
195     vout = _mm_max_ps(vout, vmin);
196     vout = _mm_min_ps(vout, vmax);
197 
198     if (channels & 2) {
199       _mm_storel_pi((__m64*) output, vout);
200       vout = _mm_movehl_ps(vout, vout);
201       output += 2;
202     }
203     if (channels & 1) {
204       _mm_store_ss(output, vout);
205     }
206   }
207 }
208