xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/16x9p8q-neon.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <assert.h>
10 
11 #include <arm_neon.h>
12 
13 #include <qnnpack/u8maxpool.h>
14 
/*
 * Lane-wise maximum of nine 16-byte vectors.
 *
 * The pairing is a balanced tree (presumably to shorten the dependency chain
 * between vmaxq_u8 operations); since unsigned max is associative and
 * commutative, the grouping does not affect the result.
 */
static inline uint8x16_t u8maxpool_16x9p8q_vmax9(
    uint8x16_t v0,
    uint8x16_t v1,
    uint8x16_t v2,
    uint8x16_t v3,
    uint8x16_t v4,
    uint8x16_t v5,
    uint8x16_t v6,
    uint8x16_t v7,
    uint8x16_t v8) {
  const uint8x16_t vmax018 = vmaxq_u8(vmaxq_u8(v0, v1), v8);
  const uint8x16_t vmax23 = vmaxq_u8(v2, v3);
  const uint8x16_t vmax45 = vmaxq_u8(v4, v5);
  const uint8x16_t vmax67 = vmaxq_u8(v6, v7);

  const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45);
  const uint8x16_t vmax01678 = vmaxq_u8(vmax018, vmax67);
  return vmaxq_u8(vmax2345, vmax01678);
}

/* Clamps every lane of v: first to at most voutput_max, then to at least
 * voutput_min. */
static inline uint8x16_t u8maxpool_16x9p8q_vclamp(
    uint8x16_t v,
    uint8x16_t voutput_min,
    uint8x16_t voutput_max) {
  return vmaxq_u8(vminq_u8(v, voutput_max), voutput_min);
}

/*
 * NEON uint8 max-pooling microkernel: 16 channels per vector, 9 window
 * elements in the first pass and 8 in each subsequent pass.
 *
 * For each of the n output positions, computes the element-wise maximum over
 * ks input rows of kc bytes (rows are addressed through the `input` pointer
 * array), clamps it to [params->neon.output_min, params->neon.output_max],
 * and writes kc bytes to the output.
 *
 * The first pass reads 9 row pointers and writes the output. Each further
 * pass reads 8 more pointers and folds them into the already-stored output.
 * Pointer slots past ks within a pass are still read from `input` but are
 * replaced by that pass's first pointer (repeating a row does not change the
 * max). So per output position the pointer array must provide 9 entries when
 * ks <= 9, and 9 + 8 * ceil((ks - 9) / 8) entries otherwise.
 *
 * Clamping after every pass gives the same result as clamping once at the
 * end, because the clamp is monotonic (assuming output_min <= output_max).
 *
 * n                 number of output positions; must be non-zero.
 * ks                number of pooling-window elements; must be non-zero.
 * kc                bytes (channels) per row; must be >= 16, because a
 *                   channel remainder is handled by re-processing the last
 *                   16 bytes with an overlapping vector.
 * input             row-pointer array. It is advanced past the pointers
 *                   consumed for a position, then by input_increment bytes.
 * output            destination of the first output position.
 * input_increment   extra bytes added to `input` after each position.
 * output_increment  bytes from the end of one position's kc outputs to the
 *                   start of the next position's outputs.
 * params            clamping bounds (neon.output_min / neon.output_max).
 */
void pytorch_u8maxpool_ukernel_16x9p8q__neon(
    size_t n,
    size_t ks,
    size_t kc,
    const uint8_t** input,
    uint8_t* output,
    size_t input_increment,
    size_t output_increment,
    const union pytorch_qnnp_u8_clamping_params params[restrict static 1]) {
  assert(n != 0);
  assert(ks != 0);
  assert(kc >= 16);

  const uint8x16_t voutput_max = vld1q_dup_u8(&params->neon.output_max);
  const uint8x16_t voutput_min = vld1q_dup_u8(&params->neon.output_min);
  do {
    uint8_t* o = output;

    /* First pass: window elements 0..8; overwrites the output row. */
    {
      const uint8_t* i0 = *input++;
      const uint8_t* i1 = *input++;
      const uint8_t* i2 = *input++;
      const uint8_t* i3 = *input++;
      const uint8_t* i4 = *input++;
      const uint8_t* i5 = *input++;
      const uint8_t* i6 = *input++;
      const uint8_t* i7 = *input++;
      const uint8_t* i8 = *input++;
      /* Slots at or beyond ks are not part of the window: alias them to i0. */
      if (ks < 2) {
        i1 = i0;
      }
      if (ks <= 2) {
        i2 = i0;
      }
      if (ks < 4) {
        i3 = i0;
      }
      if (ks <= 4) {
        i4 = i0;
      }
      if (ks < 6) {
        i5 = i0;
      }
      if (ks <= 6) {
        i6 = i0;
      }
      if (ks < 8) {
        i7 = i0;
      }
      if (ks <= 8) {
        i8 = i0;
      }

      size_t k = kc;
      for (; k >= 16; k -= 16) {
        const uint8x16_t vmax = u8maxpool_16x9p8q_vmax9(
            vld1q_u8(i0), vld1q_u8(i1), vld1q_u8(i2),
            vld1q_u8(i3), vld1q_u8(i4), vld1q_u8(i5),
            vld1q_u8(i6), vld1q_u8(i7), vld1q_u8(i8));
        i0 += 16;
        i1 += 16;
        i2 += 16;
        i3 += 16;
        i4 += 16;
        i5 += 16;
        i6 += 16;
        i7 += 16;
        i8 += 16;

        vst1q_u8(o, u8maxpool_16x9p8q_vclamp(vmax, voutput_min, voutput_max));
        o += 16;
      }
      if (k != 0) {
        /*
         * Step back so the final 16-byte vector ends exactly at channel kc.
         * The (16 - k) overlapping channels are recomputed from the same
         * inputs and so are rewritten with identical values. kc >= 16 keeps
         * every rewound pointer inside its row.
         */
        const size_t rewind = 16 - k;
        i0 -= rewind;
        i1 -= rewind;
        i2 -= rewind;
        i3 -= rewind;
        i4 -= rewind;
        i5 -= rewind;
        i6 -= rewind;
        i7 -= rewind;
        i8 -= rewind;
        o -= rewind;

        const uint8x16_t vmax = u8maxpool_16x9p8q_vmax9(
            vld1q_u8(i0), vld1q_u8(i1), vld1q_u8(i2),
            vld1q_u8(i3), vld1q_u8(i4), vld1q_u8(i5),
            vld1q_u8(i6), vld1q_u8(i7), vld1q_u8(i8));

        vst1q_u8(o, u8maxpool_16x9p8q_vclamp(vmax, voutput_min, voutput_max));
        o += 16;
      }
    }

    /* Subsequent passes: 8 more window elements each, max-ed into output. */
    for (ptrdiff_t m = (ptrdiff_t)ks - 9; m > 0; m -= 8) {
      const uint8_t* i0 = *input++;
      const uint8_t* i1 = *input++;
      const uint8_t* i2 = *input++;
      const uint8_t* i3 = *input++;
      const uint8_t* i4 = *input++;
      const uint8_t* i5 = *input++;
      const uint8_t* i6 = *input++;
      const uint8_t* i7 = *input++;
      /* m = window elements still unprocessed; alias slots >= m to i0. */
      if (m < 2) {
        i1 = i0;
      }
      if (m <= 2) {
        i2 = i0;
      }
      if (m < 4) {
        i3 = i0;
      }
      if (m <= 4) {
        i4 = i0;
      }
      if (m < 6) {
        i5 = i0;
      }
      if (m <= 6) {
        i6 = i0;
      }
      if (m < 8) {
        i7 = i0;
      }

      o = output;
      size_t k = kc;
      for (; k >= 16; k -= 16) {
        /* The partial result already in the output takes the 9th slot. */
        const uint8x16_t vmax = u8maxpool_16x9p8q_vmax9(
            vld1q_u8(i0), vld1q_u8(i1), vld1q_u8(i2),
            vld1q_u8(i3), vld1q_u8(i4), vld1q_u8(i5),
            vld1q_u8(i6), vld1q_u8(i7), vld1q_u8(o));
        i0 += 16;
        i1 += 16;
        i2 += 16;
        i3 += 16;
        i4 += 16;
        i5 += 16;
        i6 += 16;
        i7 += 16;

        vst1q_u8(o, u8maxpool_16x9p8q_vclamp(vmax, voutput_min, voutput_max));
        o += 16;
      }
      if (k != 0) {
        /*
         * Same overlapping tail as the first pass. Re-max-ing already-updated
         * channels with the same inputs is idempotent.
         */
        const size_t rewind = 16 - k;
        i0 -= rewind;
        i1 -= rewind;
        i2 -= rewind;
        i3 -= rewind;
        i4 -= rewind;
        i5 -= rewind;
        i6 -= rewind;
        i7 -= rewind;
        o -= rewind;

        const uint8x16_t vmax = u8maxpool_16x9p8q_vmax9(
            vld1q_u8(i0), vld1q_u8(i1), vld1q_u8(i2),
            vld1q_u8(i3), vld1q_u8(i4), vld1q_u8(i5),
            vld1q_u8(i6), vld1q_u8(i7), vld1q_u8(o));

        vst1q_u8(o, u8maxpool_16x9p8q_vclamp(vmax, voutput_min, voutput_max));
        o += 16;
      }
    }
    /*
     * Both increments are byte offsets (and `input` is a pointer-to-pointer),
     * hence the integer round-trip instead of typed pointer arithmetic.
     * At this point o == output + kc.
     */
    input = (const uint8_t**)((uintptr_t)input + input_increment);
    output = (uint8_t*)((uintptr_t)o + output_increment);
  } while (--n != 0);
}
252