/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9 #include <assert.h>
10
11 #include <arm_neon.h>
12
13 #include <qnnpack/u8maxpool.h>
14
pytorch_u8maxpool_ukernel_16x9p8q__neon(size_t n,size_t ks,size_t kc,const uint8_t ** input,uint8_t * output,size_t input_increment,size_t output_increment,const union pytorch_qnnp_u8_clamping_params params[restrict static1])15 void pytorch_u8maxpool_ukernel_16x9p8q__neon(
16 size_t n,
17 size_t ks,
18 size_t kc,
19 const uint8_t** input,
20 uint8_t* output,
21 size_t input_increment,
22 size_t output_increment,
23 const union pytorch_qnnp_u8_clamping_params params[restrict static 1]) {
24 assert(n != 0);
25 assert(ks != 0);
26 assert(kc >= 16);
27
28 const uint8x16_t voutput_max = vld1q_dup_u8(¶ms->neon.output_max);
29 const uint8x16_t voutput_min = vld1q_dup_u8(¶ms->neon.output_min);
30 do {
31 uint8_t* o = output;
32 {
33 const uint8_t* i0 = *input++;
34 const uint8_t* i1 = *input++;
35 const uint8_t* i2 = *input++;
36 const uint8_t* i3 = *input++;
37 const uint8_t* i4 = *input++;
38 const uint8_t* i5 = *input++;
39 const uint8_t* i6 = *input++;
40 const uint8_t* i7 = *input++;
41 const uint8_t* i8 = *input++;
42 if (ks < 2) {
43 i1 = i0;
44 }
45 if (ks <= 2) {
46 i2 = i0;
47 }
48 if (ks < 4) {
49 i3 = i0;
50 }
51 if (ks <= 4) {
52 i4 = i0;
53 }
54 if (ks < 6) {
55 i5 = i0;
56 }
57 if (ks <= 6) {
58 i6 = i0;
59 }
60 if (ks < 8) {
61 i7 = i0;
62 }
63 if (ks <= 8) {
64 i8 = i0;
65 }
66
67 size_t k = kc;
68 while (k >= 16) {
69 const uint8x16_t vi0 = vld1q_u8(i0);
70 i0 += 16;
71 const uint8x16_t vi1 = vld1q_u8(i1);
72 i1 += 16;
73 const uint8x16_t vi2 = vld1q_u8(i2);
74 i2 += 16;
75 const uint8x16_t vi3 = vld1q_u8(i3);
76 i3 += 16;
77 const uint8x16_t vi4 = vld1q_u8(i4);
78 i4 += 16;
79 const uint8x16_t vi5 = vld1q_u8(i5);
80 i5 += 16;
81 const uint8x16_t vi6 = vld1q_u8(i6);
82 i6 += 16;
83 const uint8x16_t vi7 = vld1q_u8(i7);
84 i7 += 16;
85 const uint8x16_t vi8 = vld1q_u8(i8);
86 i8 += 16;
87
88 const uint8x16_t vmax018 = vmaxq_u8(vmaxq_u8(vi0, vi1), vi8);
89 const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3);
90 const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5);
91 const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7);
92
93 const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45);
94 const uint8x16_t vmax01678 = vmaxq_u8(vmax018, vmax67);
95 const uint8x16_t vmax = vmaxq_u8(vmax2345, vmax01678);
96 const uint8x16_t vout =
97 vmaxq_u8(vminq_u8(vmax, voutput_max), voutput_min);
98
99 vst1q_u8(o, vout);
100 o += 16;
101
102 k -= 16;
103 }
104 if (k != 0) {
105 const size_t address_increment = k - 16;
106 i0 = (const uint8_t*)((uintptr_t)i0 + address_increment);
107 i1 = (const uint8_t*)((uintptr_t)i1 + address_increment);
108 i2 = (const uint8_t*)((uintptr_t)i2 + address_increment);
109 i3 = (const uint8_t*)((uintptr_t)i3 + address_increment);
110 i4 = (const uint8_t*)((uintptr_t)i4 + address_increment);
111 i5 = (const uint8_t*)((uintptr_t)i5 + address_increment);
112 i6 = (const uint8_t*)((uintptr_t)i6 + address_increment);
113 i7 = (const uint8_t*)((uintptr_t)i7 + address_increment);
114 i8 = (const uint8_t*)((uintptr_t)i8 + address_increment);
115 o = (uint8_t*)((uintptr_t)o + address_increment);
116
117 const uint8x16_t vi0 = vld1q_u8(i0);
118 const uint8x16_t vi1 = vld1q_u8(i1);
119 const uint8x16_t vi2 = vld1q_u8(i2);
120 const uint8x16_t vi3 = vld1q_u8(i3);
121 const uint8x16_t vi4 = vld1q_u8(i4);
122 const uint8x16_t vi5 = vld1q_u8(i5);
123 const uint8x16_t vi6 = vld1q_u8(i6);
124 const uint8x16_t vi7 = vld1q_u8(i7);
125 const uint8x16_t vi8 = vld1q_u8(i8);
126
127 const uint8x16_t vmax018 = vmaxq_u8(vmaxq_u8(vi0, vi1), vi8);
128 const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3);
129 const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5);
130 const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7);
131
132 const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45);
133 const uint8x16_t vmax01678 = vmaxq_u8(vmax018, vmax67);
134 const uint8x16_t vmax = vmaxq_u8(vmax2345, vmax01678);
135 const uint8x16_t vout =
136 vmaxq_u8(vminq_u8(vmax, voutput_max), voutput_min);
137
138 vst1q_u8(o, vout);
139 o += 16;
140 }
141 }
142
143 for (ptrdiff_t m = (ptrdiff_t)ks - 9; m > 0; m -= 8) {
144 const uint8_t* i0 = *input++;
145 const uint8_t* i1 = *input++;
146 const uint8_t* i2 = *input++;
147 const uint8_t* i3 = *input++;
148 const uint8_t* i4 = *input++;
149 const uint8_t* i5 = *input++;
150 const uint8_t* i6 = *input++;
151 const uint8_t* i7 = *input++;
152 if (m < 2) {
153 i1 = i0;
154 }
155 if (m <= 2) {
156 i2 = i0;
157 }
158 if (m < 4) {
159 i3 = i0;
160 }
161 if (m <= 4) {
162 i4 = i0;
163 }
164 if (m < 6) {
165 i5 = i0;
166 }
167 if (m <= 6) {
168 i6 = i0;
169 }
170 if (m < 8) {
171 i7 = i0;
172 }
173
174 o = output;
175 size_t k = kc;
176 while (k >= 16) {
177 const uint8x16_t vi0 = vld1q_u8(i0);
178 i0 += 16;
179 const uint8x16_t vi1 = vld1q_u8(i1);
180 i1 += 16;
181 const uint8x16_t vi2 = vld1q_u8(i2);
182 i2 += 16;
183 const uint8x16_t vi3 = vld1q_u8(i3);
184 i3 += 16;
185 const uint8x16_t vi4 = vld1q_u8(i4);
186 i4 += 16;
187 const uint8x16_t vi5 = vld1q_u8(i5);
188 i5 += 16;
189 const uint8x16_t vi6 = vld1q_u8(i6);
190 i6 += 16;
191 const uint8x16_t vi7 = vld1q_u8(i7);
192 i7 += 16;
193 const uint8x16_t vo = vld1q_u8(o);
194
195 const uint8x16_t vmax01 = vmaxq_u8(vmaxq_u8(vi0, vi1), vo);
196 const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3);
197 const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5);
198 const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7);
199
200 const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45);
201 const uint8x16_t vmax0167 = vmaxq_u8(vmax01, vmax67);
202 const uint8x16_t vmax = vmaxq_u8(vmax2345, vmax0167);
203 const uint8x16_t vout =
204 vmaxq_u8(vminq_u8(vmax, voutput_max), voutput_min);
205
206 vst1q_u8(o, vout);
207 o += 16;
208
209 k -= 16;
210 }
211 if (k != 0) {
212 const size_t address_increment = k - 16;
213 i0 = (const uint8_t*)((uintptr_t)i0 + address_increment);
214 i1 = (const uint8_t*)((uintptr_t)i1 + address_increment);
215 i2 = (const uint8_t*)((uintptr_t)i2 + address_increment);
216 i3 = (const uint8_t*)((uintptr_t)i3 + address_increment);
217 i4 = (const uint8_t*)((uintptr_t)i4 + address_increment);
218 i5 = (const uint8_t*)((uintptr_t)i5 + address_increment);
219 i6 = (const uint8_t*)((uintptr_t)i6 + address_increment);
220 i7 = (const uint8_t*)((uintptr_t)i7 + address_increment);
221 o = (uint8_t*)((uintptr_t)o + address_increment);
222
223 const uint8x16_t vi0 = vld1q_u8(i0);
224 const uint8x16_t vi1 = vld1q_u8(i1);
225 const uint8x16_t vi2 = vld1q_u8(i2);
226 const uint8x16_t vi3 = vld1q_u8(i3);
227 const uint8x16_t vi4 = vld1q_u8(i4);
228 const uint8x16_t vi5 = vld1q_u8(i5);
229 const uint8x16_t vi6 = vld1q_u8(i6);
230 const uint8x16_t vi7 = vld1q_u8(i7);
231 const uint8x16_t vo = vld1q_u8(o);
232
233 const uint8x16_t vmax01 = vmaxq_u8(vmaxq_u8(vi0, vi1), vo);
234 const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3);
235 const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5);
236 const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7);
237
238 const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45);
239 const uint8x16_t vmax0167 = vmaxq_u8(vmax01, vmax67);
240 const uint8x16_t vmax = vmaxq_u8(vmax2345, vmax0167);
241 const uint8x16_t vout =
242 vmaxq_u8(vminq_u8(vmax, voutput_max), voutput_min);
243
244 vst1q_u8(o, vout);
245 o += 16;
246 }
247 }
248 input = (const uint8_t**)((uintptr_t)input + input_increment);
249 output = (uint8_t*)((uintptr_t)o + output_increment);
250 } while (--n != 0);
251 }
252