// xref: /aosp_15_r20/external/XNNPACK/src/f32-argmaxpool/9p8x-scalar-c1.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack/argmaxpool.h>
#include <xnnpack/math.h>


// Multi-pass scalar argmax pooling, one channel per inner-loop iteration.
//
// For every output pixel the kernel walks `pooling_elements` row pointers
// taken from `input`, and for each of the `channels` floats in those rows
// writes the largest value to `output` and its position within the pooling
// window (0-based, in the order the row pointers are consumed) to `index`.
//
// The window is processed in three stages:
//   1. first pass  - rows 0..8 seed `accumulation_buffer` / `index_buffer`;
//   2. middle pass - while more than 8 rows remain, 8 rows at a time are
//                    folded into the two scratch buffers;
//   3. tail pass   - the last 1..8 rows are folded in and the final result is
//                    written to `output` / `index`.
//
// All comparisons are strict (`>`), so when several rows hold the same
// maximum the lowest pooling position is reported.
//
// Parameters:
//   output_pixels       - number of output pixels to produce (must be != 0).
//   pooling_elements    - row pointers consumed per output pixel (must be > 9).
//   channels            - number of floats read from every row (must be != 0).
//   input               - array of row pointers. Within one pixel the pointer
//                         is advanced past every row of the first and middle
//                         passes; after the tail rows have been loaded it is
//                         advanced by `input_increment` BYTES.
//   input_offset        - byte offset added to every row pointer before use.
//   accumulation_buffer - scratch, `channels` floats: running maximum.
//   index_buffer        - scratch, `channels` uint32_t: running argmax.
//   output              - receives `channels` maxima per pixel; after each
//                         pixel it is advanced by a further
//                         `output_increment` BYTES.
//   index               - receives `channels` argmax values per pixel, stored
//                         contiguously from pixel to pixel.
//   input_increment     - see `input`.
//   output_increment    - see `output`.
void xnn_f32_argmaxpool_ukernel_9p8x__scalar_c1(
    size_t output_pixels,
    size_t pooling_elements,
    size_t channels,
    const float** input,
    size_t input_offset,
    float* accumulation_buffer,
    uint32_t* index_buffer,
    float* output,
    uint32_t* index,
    size_t input_increment,
    size_t output_increment)
{
  assert(output_pixels != 0);
  assert(pooling_elements > 9);  // also guarantees pooling_elements != 0
  assert(channels != 0);

  do {
    // ---- First pass: rows 0..8 initialise the scratch buffers. ----
    {
      const float* i0 = (const float*) ((uintptr_t) input[0] + input_offset);
      const float* i1 = (const float*) ((uintptr_t) input[1] + input_offset);
      const float* i2 = (const float*) ((uintptr_t) input[2] + input_offset);
      const float* i3 = (const float*) ((uintptr_t) input[3] + input_offset);
      const float* i4 = (const float*) ((uintptr_t) input[4] + input_offset);
      const float* i5 = (const float*) ((uintptr_t) input[5] + input_offset);
      const float* i6 = (const float*) ((uintptr_t) input[6] + input_offset);
      const float* i7 = (const float*) ((uintptr_t) input[7] + input_offset);
      const float* i8 = (const float*) ((uintptr_t) input[8] + input_offset);
      input += 9;

      float* ab = accumulation_buffer;
      uint32_t* ib = index_buffer;

      size_t c = channels;
      do {
        const float vi0 = *i0++;
        const float vi1 = *i1++;
        const float vi2 = *i2++;
        const float vi3 = *i3++;
        const float vi4 = *i4++;
        const float vi5 = *i5++;
        const float vi6 = *i6++;
        const float vi7 = *i7++;
        const float vi8 = *i8++;

        // Row 0 is the initial candidate; later rows replace it only when
        // strictly greater.
        float vmax = vi0;
        uint32_t vidx = 0;

        if (vi1 > vmax) {
          vmax = vi1;
          vidx = 1;
        }
        if (vi2 > vmax) {
          vmax = vi2;
          vidx = 2;
        }
        if (vi3 > vmax) {
          vmax = vi3;
          vidx = 3;
        }
        if (vi4 > vmax) {
          vmax = vi4;
          vidx = 4;
        }
        if (vi5 > vmax) {
          vmax = vi5;
          vidx = 5;
        }
        if (vi6 > vmax) {
          vmax = vi6;
          vidx = 6;
        }
        if (vi7 > vmax) {
          vmax = vi7;
          vidx = 7;
        }
        if (vi8 > vmax) {
          vmax = vi8;
          vidx = 8;
        }

        *ab++ = vmax;
        *ib++ = vidx;
      } while (--c != 0);
    }

    // Pooling position of row i0 in the pass currently being processed.
    uint32_t vidx0 = 9;

    // ---- Middle passes: 8 rows each while more than 8 rows remain. ----
    // On exit k holds the number of rows left for the tail pass (1..8).
    size_t k = pooling_elements;
    for (k -= 9; k > 8; k -= 8) {
      const float* i0 = (const float*) ((uintptr_t) input[0] + input_offset);
      const float* i1 = (const float*) ((uintptr_t) input[1] + input_offset);
      const float* i2 = (const float*) ((uintptr_t) input[2] + input_offset);
      const float* i3 = (const float*) ((uintptr_t) input[3] + input_offset);
      const float* i4 = (const float*) ((uintptr_t) input[4] + input_offset);
      const float* i5 = (const float*) ((uintptr_t) input[5] + input_offset);
      const float* i6 = (const float*) ((uintptr_t) input[6] + input_offset);
      const float* i7 = (const float*) ((uintptr_t) input[7] + input_offset);
      input += 8;

      float* ab = accumulation_buffer;
      uint32_t* ib = index_buffer;

      size_t c = channels;
      do {
        const float vi0 = *i0++;
        const float vi1 = *i1++;
        const float vi2 = *i2++;
        const float vi3 = *i3++;
        const float vi4 = *i4++;
        const float vi5 = *i5++;
        const float vi6 = *i6++;
        const float vi7 = *i7++;

        // Resume from the running result of the previous passes.
        float vmax = *ab;
        uint32_t vidx = *ib;

        if (vi0 > vmax) {
          vmax = vi0;
          vidx = vidx0;
        }
        if (vi1 > vmax) {
          vmax = vi1;
          vidx = vidx0 + 1;
        }
        if (vi2 > vmax) {
          vmax = vi2;
          vidx = vidx0 + 2;
        }
        if (vi3 > vmax) {
          vmax = vi3;
          vidx = vidx0 + 3;
        }
        if (vi4 > vmax) {
          vmax = vi4;
          vidx = vidx0 + 4;
        }
        if (vi5 > vmax) {
          vmax = vi5;
          vidx = vidx0 + 5;
        }
        if (vi6 > vmax) {
          vmax = vi6;
          vidx = vidx0 + 6;
        }
        if (vi7 > vmax) {
          vmax = vi7;
          vidx = vidx0 + 7;
        }

        *ab++ = vmax;
        *ib++ = vidx;
      } while (--c != 0);
      vidx0 += 8;
    }

    // ---- Tail pass: fold in the last k rows and emit the result. ----
    float* o = output;
    uint32_t* i = index;
    {
      // Only the first k indirection entries belong to this pixel. Rows that
      // do not exist alias row 0: a duplicate of vi0 can never be strictly
      // greater than the running maximum, so it never changes the result.
      // Entries beyond k are therefore not read at all.
      const float* i0 = (const float*) ((uintptr_t) input[0] + input_offset);
      const float* i1 = i0;
      const float* i2 = i0;
      const float* i3 = i0;
      const float* i4 = i0;
      const float* i5 = i0;
      const float* i6 = i0;
      const float* i7 = i0;
      if (k > 1) {
        i1 = (const float*) ((uintptr_t) input[1] + input_offset);
      }
      if (k > 2) {
        i2 = (const float*) ((uintptr_t) input[2] + input_offset);
      }
      if (k > 3) {
        i3 = (const float*) ((uintptr_t) input[3] + input_offset);
      }
      if (k > 4) {
        i4 = (const float*) ((uintptr_t) input[4] + input_offset);
      }
      if (k > 5) {
        i5 = (const float*) ((uintptr_t) input[5] + input_offset);
      }
      if (k > 6) {
        i6 = (const float*) ((uintptr_t) input[6] + input_offset);
      }
      if (k > 7) {
        i7 = (const float*) ((uintptr_t) input[7] + input_offset);
      }
      // Step to the next pixel's row pointers (byte increment chosen by the
      // caller, measured from the first tail entry).
      input = (const float**) ((uintptr_t) input + input_increment);

      size_t c = channels;
      const float* ab = accumulation_buffer;
      const uint32_t* ib = index_buffer;
      do {
        const float vi0 = *i0++;
        const float vi1 = *i1++;
        const float vi2 = *i2++;
        const float vi3 = *i3++;
        const float vi4 = *i4++;
        const float vi5 = *i5++;
        const float vi6 = *i6++;
        const float vi7 = *i7++;

        float vmax = *ab++;
        uint32_t vidx = *ib++;

        if (vi0 > vmax) {
          vmax = vi0;
          vidx = vidx0;
        }
        if (vi1 > vmax) {
          vmax = vi1;
          vidx = vidx0 + 1;
        }
        if (vi2 > vmax) {
          vmax = vi2;
          vidx = vidx0 + 2;
        }
        if (vi3 > vmax) {
          vmax = vi3;
          vidx = vidx0 + 3;
        }
        if (vi4 > vmax) {
          vmax = vi4;
          vidx = vidx0 + 4;
        }
        if (vi5 > vmax) {
          vmax = vi5;
          vidx = vidx0 + 5;
        }
        if (vi6 > vmax) {
          vmax = vi6;
          vidx = vidx0 + 6;
        }
        if (vi7 > vmax) {
          vmax = vi7;
          vidx = vidx0 + 7;
        }

        *o++ = vmax;
        *i++ = vidx;
      } while (--c != 0);
    }

    // Output rows may be strided; index rows are packed back to back.
    output = (float*) ((uintptr_t) o + output_increment);
    index = i;
  } while (--output_pixels != 0);
}