/*
 * Copyright (c) 2019 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef __aarch64__

#include "arm_gemm.hpp"
#include "utils.hpp"

#include <arm_neon.h>

namespace arm_gemm {

namespace {

/* Requantize a block of data, using the requantize parameters in 'qp'.
 *
 * row_bias and col_bias are assumed to be precomputed values which include
 * any externally supplied bias, plus the row/column contribution sums, plus
 * the overall constant offset (A_offset * B_offset * depth).
 *
 * Note that this function works equally well for uint8_t output: just set
 * minval/maxval appropriately and cast the output pointer.  It is the
 * caller's responsibility to ensure that minval/maxval are representable in
 * the target type - the downcast to (u)int8_t is done by simply extracting
 * the LSB.
 *
 * The 'do_shift_correction' template parameter turns on the correction
 * applied to negative values being shifted right to make sure they round
 * properly - if negative values are never output (e.g. fused ReLU) this is
 * unnecessary.
 *
 * The 'per_channel' template parameter selects between per channel and per
 * layer requantization - in the former case we need to load vectors of
 * shifts and multipliers for each column.  A separate vector for each
 * column is set up in any case (and it is hoped that the compiler can elide
 * the needless movs in the per-layer case).
 */
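/* As a rough scalar sketch of what the vector code below computes per
 * element (ignoring saturation and assuming a single layer-wide multiplier
 * and shift), the transformation is approximately:
 *
 *     int32_t acc = input[col] + row_bias[row] + col_bias[col];
 *     if (left_shift) acc <<= left_shift;
 *     acc = (int32_t)(((int64_t)acc * per_layer_mul + (1LL << 30)) >> 31);  // vqrdmulhq_s32
 *     acc = rounding_shift_right(acc, right_shift);                         // vrshlq_s32 with a negative shift
 *     acc = std::min(std::max(acc + c_offset, minval), maxval);
 *     output[col] = (int8_t)acc;                                            // truncate to the LSB
 */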
template<bool do_shift_correction, bool per_channel, bool do_left_shift>
void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height,
                             const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                             const int32_t *row_bias, const int32_t *col_bias, const unsigned int start_col) {
    const int32x4_t v_mul          = vdupq_n_s32(qp.per_layer_mul);
    const int32x4_t v_right_shift  = vdupq_n_s32(qp.per_layer_right_shift);
    const int32x4_t v_left_shift   = vdupq_n_s32(qp.per_layer_left_shift);
    const int32x4_t v_minval       = vdupq_n_s32(qp.minval);
    const int32x4_t v_maxval       = vdupq_n_s32(qp.maxval);
    const int32x4_t v_c_offset     = vdupq_n_s32(qp.c_offset);

    /* To make sure we have plenty of accumulators, compute two rows at a
     * time.  If the number of rows is odd, compute the bottom row twice to
     * avoid needing a duplicate codepath. */
    for (unsigned int row=0; row<height; row+=2) {
        /* Prefer to do 4 vectors (16 values) at once as this collapses
         * neatly to a single vector of output, failing that a vector at a
         * time and then the odd ones out at the end.  */
        unsigned int blocks=(width / 16);
        unsigned int regs=(width % 16) / 4;
        unsigned int odds=(width % 4);
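        /* For example, a width of 23 gives blocks=1, regs=1 and odds=3: one
         * 16-wide block, then one 4-wide vector, then three trailing
         * elements handled a lane at a time. */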

        const int32_t *colptr = col_bias;
        const int32_t *perch_mul_ptr    = qp.per_channel_muls + start_col;
        const int32_t *perch_shift_ptr  = qp.per_channel_right_shifts + start_col;
        const int32_t *perch_shiftl_ptr = qp.per_channel_left_shifts + start_col;

        const int32_t *in_ptr = input + (row * in_stride);
        int8_t *out_ptr = output + (row * out_stride);
        int32_t row_sum = row_bias[row];

        const int32_t *in_ptr1;
        int8_t *out_ptr1;
        int32_t row_sum1;

        if (row == height-1) {
            in_ptr1  = in_ptr;
            out_ptr1 = out_ptr;
            row_sum1 = row_sum;
        } else {
            in_ptr1  = in_ptr + in_stride;
            out_ptr1 = out_ptr + out_stride;
            row_sum1 = row_bias[row+1];
        }

        const int32x4_t v_row_sum  = vdupq_n_s32(row_sum);
        const int32x4_t v_row_sum1 = vdupq_n_s32(row_sum1);

        while (blocks--) {
            int32x4_t v_mul0;
            int32x4_t v_mul1;
            int32x4_t v_mul2;
            int32x4_t v_mul3;

            int32x4_t v_shf0;
            int32x4_t v_shf1;
            int32x4_t v_shf2;
            int32x4_t v_shf3;

            int32x4_t v_shf0l;
            int32x4_t v_shf1l;
            int32x4_t v_shf2l;
            int32x4_t v_shf3l;

            if (per_channel) {
                v_mul0 = vld1q_s32(perch_mul_ptr);
                v_mul1 = vld1q_s32(perch_mul_ptr + 4);
                v_mul2 = vld1q_s32(perch_mul_ptr + 8);
                v_mul3 = vld1q_s32(perch_mul_ptr + 12);
                perch_mul_ptr += 16;

                v_shf0 = vld1q_s32(perch_shift_ptr);
                v_shf1 = vld1q_s32(perch_shift_ptr + 4);
                v_shf2 = vld1q_s32(perch_shift_ptr + 8);
                v_shf3 = vld1q_s32(perch_shift_ptr + 12);
                perch_shift_ptr += 16;

                if (do_left_shift) {
                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
                    v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
                    v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
                    v_shf3l = vld1q_s32(perch_shiftl_ptr + 12);
                    perch_shiftl_ptr += 16;
                }
            } else {
                v_mul0=v_mul1=v_mul2=v_mul3=v_mul;
                v_shf0=v_shf1=v_shf2=v_shf3=v_right_shift;
                v_shf0l=v_shf1l=v_shf2l=v_shf3l=v_left_shift;
            }

            // Load column pointers
            int32x4_t v_col0 = vld1q_s32(colptr);
            int32x4_t v_col1 = vld1q_s32(colptr + 4);
            int32x4_t v_col2 = vld1q_s32(colptr + 8);
            int32x4_t v_col3 = vld1q_s32(colptr + 12);
            colptr += 16;

            // Load input data (row 0);
            int32x4_t v_in00 = vld1q_s32(in_ptr);
            int32x4_t v_in01 = vld1q_s32(in_ptr + 4);
            int32x4_t v_in02 = vld1q_s32(in_ptr + 8);
            int32x4_t v_in03 = vld1q_s32(in_ptr + 12);
            in_ptr += 16;

            // Load input data (row 1);
            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            int32x4_t v_in11 = vld1q_s32(in_ptr1 + 4);
            int32x4_t v_in12 = vld1q_s32(in_ptr1 + 8);
            int32x4_t v_in13 = vld1q_s32(in_ptr1 + 12);
            in_ptr1 += 16;

            // Add on row bias and column bias
            v_in00 = vaddq_s32(v_in00, v_row_sum);
            v_in01 = vaddq_s32(v_in01, v_row_sum);
            v_in02 = vaddq_s32(v_in02, v_row_sum);
            v_in03 = vaddq_s32(v_in03, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);
            v_in11 = vaddq_s32(v_in11, v_row_sum1);
            v_in12 = vaddq_s32(v_in12, v_row_sum1);
            v_in13 = vaddq_s32(v_in13, v_row_sum1);

            v_in00 = vaddq_s32(v_in00, v_col0);
            v_in01 = vaddq_s32(v_in01, v_col1);
            v_in02 = vaddq_s32(v_in02, v_col2);
            v_in03 = vaddq_s32(v_in03, v_col3);

            v_in10 = vaddq_s32(v_in10, v_col0);
            v_in11 = vaddq_s32(v_in11, v_col1);
            v_in12 = vaddq_s32(v_in12, v_col2);
            v_in13 = vaddq_s32(v_in13, v_col3);

            // Quantize

            // If a left shift is needed it needs to happen first.
            if (do_left_shift) {
                v_in00 = vrshlq_s32(v_in00, v_shf0l);
                v_in01 = vrshlq_s32(v_in01, v_shf1l);
                v_in02 = vrshlq_s32(v_in02, v_shf2l);
                v_in03 = vrshlq_s32(v_in03, v_shf3l);

                v_in10 = vrshlq_s32(v_in10, v_shf0l);
                v_in11 = vrshlq_s32(v_in11, v_shf1l);
                v_in12 = vrshlq_s32(v_in12, v_shf2l);
                v_in13 = vrshlq_s32(v_in13, v_shf3l);
            }

            // Multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
            v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
            v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
            v_in03 = vqrdmulhq_s32(v_in03, v_mul3);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
            v_in11 = vqrdmulhq_s32(v_in11, v_mul1);
            v_in12 = vqrdmulhq_s32(v_in12, v_mul2);
            v_in13 = vqrdmulhq_s32(v_in13, v_mul3);

            // Compute and add on corrective offset
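            // (Right shifts are encoded as negative values for vrshlq_s32,
            // so (value & shift) has its sign bit set only when the value
            // is negative and a right shift will actually occur; shifting
            // that AND result right by 31 gives -1 in exactly those lanes,
            // nudging negative values down before the rounding shift.)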
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
                int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1);
                int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2);
                int32x4_t v_temp03 = vandq_s32(v_in03, v_shf3);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
                int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1);
                int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2);
                int32x4_t v_temp13 = vandq_s32(v_in13, v_shf3);

                v_temp00 = vshrq_n_s32(v_temp00, 31);
                v_temp01 = vshrq_n_s32(v_temp01, 31);
                v_temp02 = vshrq_n_s32(v_temp02, 31);
                v_temp03 = vshrq_n_s32(v_temp03, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);
                v_temp11 = vshrq_n_s32(v_temp11, 31);
                v_temp12 = vshrq_n_s32(v_temp12, 31);
                v_temp13 = vshrq_n_s32(v_temp13, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);
                v_in01 = vqaddq_s32(v_in01, v_temp01);
                v_in02 = vqaddq_s32(v_in02, v_temp02);
                v_in03 = vqaddq_s32(v_in03, v_temp03);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
                v_in11 = vqaddq_s32(v_in11, v_temp11);
                v_in12 = vqaddq_s32(v_in12, v_temp12);
                v_in13 = vqaddq_s32(v_in13, v_temp13);
            }

            v_in00 = vrshlq_s32(v_in00, v_shf0);
            v_in01 = vrshlq_s32(v_in01, v_shf1);
            v_in02 = vrshlq_s32(v_in02, v_shf2);
            v_in03 = vrshlq_s32(v_in03, v_shf3);

            v_in10 = vrshlq_s32(v_in10, v_shf0);
            v_in11 = vrshlq_s32(v_in11, v_shf1);
            v_in12 = vrshlq_s32(v_in12, v_shf2);
            v_in13 = vrshlq_s32(v_in13, v_shf3);

            v_in00 = vaddq_s32(v_in00, v_c_offset);
            v_in01 = vaddq_s32(v_in01, v_c_offset);
            v_in02 = vaddq_s32(v_in02, v_c_offset);
            v_in03 = vaddq_s32(v_in03, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);
            v_in11 = vaddq_s32(v_in11, v_c_offset);
            v_in12 = vaddq_s32(v_in12, v_c_offset);
            v_in13 = vaddq_s32(v_in13, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);
            v_in01 = vmaxq_s32(v_in01, v_minval);
            v_in02 = vmaxq_s32(v_in02, v_minval);
            v_in03 = vmaxq_s32(v_in03, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);
            v_in11 = vmaxq_s32(v_in11, v_minval);
            v_in12 = vmaxq_s32(v_in12, v_minval);
            v_in13 = vmaxq_s32(v_in13, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);
            v_in01 = vminq_s32(v_in01, v_maxval);
            v_in02 = vminq_s32(v_in02, v_maxval);
            v_in03 = vminq_s32(v_in03, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);
            v_in11 = vminq_s32(v_in11, v_maxval);
            v_in12 = vminq_s32(v_in12, v_maxval);
            v_in13 = vminq_s32(v_in13, v_maxval);

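            // Narrow 32 -> 8 bits by truncation: the first uzp1 keeps the
            // low 16 bits of each clamped 32-bit lane, the second keeps the
            // low 8 bits of those, matching the "extract the LSB" behaviour
            // described in the comment above.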
            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in01));
            int16x8_t v_uz01 = vuzp1q_s16(vreinterpretq_s16_s32(v_in02), vreinterpretq_s16_s32(v_in03));

            int16x8_t v_uz10 = vuzp1q_s16(vreinterpretq_s16_s32(v_in10), vreinterpretq_s16_s32(v_in11));
            int16x8_t v_uz11 = vuzp1q_s16(vreinterpretq_s16_s32(v_in12), vreinterpretq_s16_s32(v_in13));

            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz01));
            int8x16_t v_uz1 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz10), vreinterpretq_s8_s16(v_uz11));

            vst1q_s8(out_ptr, v_uz0);
            out_ptr += 16;
            vst1q_s8(out_ptr1, v_uz1);
            out_ptr1 += 16;
        }

        // We are often quantizing one block of interleaved kernel output at a time - these are three registers
        // wide.  Special case that here.
        if (regs==3) {
            regs -= 3;

            int32x4_t v_mul0;
            int32x4_t v_mul1;
            int32x4_t v_mul2;

            int32x4_t v_shf0;
            int32x4_t v_shf1;
            int32x4_t v_shf2;

            int32x4_t v_shf0l;
            int32x4_t v_shf1l;
            int32x4_t v_shf2l;

            if (per_channel) {
                v_mul0 = vld1q_s32(perch_mul_ptr);
                v_mul1 = vld1q_s32(perch_mul_ptr + 4);
                v_mul2 = vld1q_s32(perch_mul_ptr + 8);
                perch_mul_ptr += 12;

                v_shf0 = vld1q_s32(perch_shift_ptr);
                v_shf1 = vld1q_s32(perch_shift_ptr + 4);
                v_shf2 = vld1q_s32(perch_shift_ptr + 8);
                perch_shift_ptr += 12;

                if (do_left_shift) {
                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
                    v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
                    v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
                    perch_shiftl_ptr += 12;
                }
            } else {
                v_mul0=v_mul1=v_mul2=v_mul;
                v_shf0=v_shf1=v_shf2=v_right_shift;
                v_shf0l=v_shf1l=v_shf2l=v_left_shift;
            }

            // Load column pointers
            int32x4_t v_col0 = vld1q_s32(colptr);
            int32x4_t v_col1 = vld1q_s32(colptr + 4);
            int32x4_t v_col2 = vld1q_s32(colptr + 8);
            colptr += 12;

            // Load input data (row 0);
            int32x4_t v_in00 = vld1q_s32(in_ptr);
            int32x4_t v_in01 = vld1q_s32(in_ptr + 4);
            int32x4_t v_in02 = vld1q_s32(in_ptr + 8);
            in_ptr += 12;

            // Load input data (row 1);
            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            int32x4_t v_in11 = vld1q_s32(in_ptr1 + 4);
            int32x4_t v_in12 = vld1q_s32(in_ptr1 + 8);
            in_ptr1 += 12;

            // Add on row bias and column bias
            v_in00 = vaddq_s32(v_in00, v_row_sum);
            v_in01 = vaddq_s32(v_in01, v_row_sum);
            v_in02 = vaddq_s32(v_in02, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);
            v_in11 = vaddq_s32(v_in11, v_row_sum1);
            v_in12 = vaddq_s32(v_in12, v_row_sum1);

            v_in00 = vaddq_s32(v_in00, v_col0);
            v_in01 = vaddq_s32(v_in01, v_col1);
            v_in02 = vaddq_s32(v_in02, v_col2);

            v_in10 = vaddq_s32(v_in10, v_col0);
            v_in11 = vaddq_s32(v_in11, v_col1);
            v_in12 = vaddq_s32(v_in12, v_col2);

            // Quantize

            // If a left shift is needed it needs to happen first.
            if (do_left_shift) {
                v_in00 = vrshlq_s32(v_in00, v_shf0l);
                v_in01 = vrshlq_s32(v_in01, v_shf1l);
                v_in02 = vrshlq_s32(v_in02, v_shf2l);

                v_in10 = vrshlq_s32(v_in10, v_shf0l);
                v_in11 = vrshlq_s32(v_in11, v_shf1l);
                v_in12 = vrshlq_s32(v_in12, v_shf2l);
            }

            // Multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
            v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
            v_in02 = vqrdmulhq_s32(v_in02, v_mul2);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
            v_in11 = vqrdmulhq_s32(v_in11, v_mul1);
            v_in12 = vqrdmulhq_s32(v_in12, v_mul2);

            // Compute and add on corrective offset
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
                int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1);
                int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
                int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1);
                int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2);

                v_temp00 = vshrq_n_s32(v_temp00, 31);
                v_temp01 = vshrq_n_s32(v_temp01, 31);
                v_temp02 = vshrq_n_s32(v_temp02, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);
                v_temp11 = vshrq_n_s32(v_temp11, 31);
                v_temp12 = vshrq_n_s32(v_temp12, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);
                v_in01 = vqaddq_s32(v_in01, v_temp01);
                v_in02 = vqaddq_s32(v_in02, v_temp02);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
                v_in11 = vqaddq_s32(v_in11, v_temp11);
                v_in12 = vqaddq_s32(v_in12, v_temp12);
            }

            v_in00 = vrshlq_s32(v_in00, v_shf0);
            v_in01 = vrshlq_s32(v_in01, v_shf1);
            v_in02 = vrshlq_s32(v_in02, v_shf2);

            v_in10 = vrshlq_s32(v_in10, v_shf0);
            v_in11 = vrshlq_s32(v_in11, v_shf1);
            v_in12 = vrshlq_s32(v_in12, v_shf2);

            v_in00 = vaddq_s32(v_in00, v_c_offset);
            v_in01 = vaddq_s32(v_in01, v_c_offset);
            v_in02 = vaddq_s32(v_in02, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);
            v_in11 = vaddq_s32(v_in11, v_c_offset);
            v_in12 = vaddq_s32(v_in12, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);
            v_in01 = vmaxq_s32(v_in01, v_minval);
            v_in02 = vmaxq_s32(v_in02, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);
            v_in11 = vmaxq_s32(v_in11, v_minval);
            v_in12 = vmaxq_s32(v_in12, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);
            v_in01 = vminq_s32(v_in01, v_maxval);
            v_in02 = vminq_s32(v_in02, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);
            v_in11 = vminq_s32(v_in11, v_maxval);
            v_in12 = vminq_s32(v_in12, v_maxval);

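            // Only 12 of the 16 packed bytes are stored below, so the upper
            // byte positions are simply filled by duplicating the last
            // vector of each row; they never reach memory.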
            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in01));
            int16x8_t v_uz01 = vuzp1q_s16(vreinterpretq_s16_s32(v_in02), vreinterpretq_s16_s32(v_in02));

            int16x8_t v_uz10 = vuzp1q_s16(vreinterpretq_s16_s32(v_in10), vreinterpretq_s16_s32(v_in11));
            int16x8_t v_uz11 = vuzp1q_s16(vreinterpretq_s16_s32(v_in12), vreinterpretq_s16_s32(v_in12));

            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz01));
            int8x16_t v_uz1 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz10), vreinterpretq_s8_s16(v_uz11));

            vst1q_lane_s64(reinterpret_cast<int64_t *>(out_ptr), vreinterpretq_s64_s8(v_uz0), 0);
            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr + 8), vreinterpretq_s32_s8(v_uz0), 2);
            out_ptr += 12;
            vst1q_lane_s64(reinterpret_cast<int64_t *>(out_ptr1), vreinterpretq_s64_s8(v_uz1), 0);
            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr1 + 8), vreinterpretq_s32_s8(v_uz1), 2);
            out_ptr1 += 12;
        }

        while (regs--) {
            int32x4_t v_mul0;
            int32x4_t v_shf0;
            int32x4_t v_shf0l;

            if (per_channel) {
                v_mul0 = vld1q_s32(perch_mul_ptr);
                perch_mul_ptr += 4;

                v_shf0 = vld1q_s32(perch_shift_ptr);
                perch_shift_ptr += 4;

                if (do_left_shift) {
                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
                    perch_shiftl_ptr += 4;
                }
            } else {
                v_mul0=v_mul;
                v_shf0=v_right_shift;
                v_shf0l=v_left_shift;
            }
            // Load column pointers
            int32x4_t v_col0 = vld1q_s32(colptr);
            colptr += 4;

            // Load input data (row 0);
            int32x4_t v_in00 = vld1q_s32(in_ptr);
            in_ptr += 4;

            // Load input data (row 1);
            int32x4_t v_in10 = vld1q_s32(in_ptr1);
            in_ptr1 += 4;

            // Add on row sum and bias constant
            v_in00 = vaddq_s32(v_in00, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);

            // Subtract col sum * a_offset
            v_in00 = vaddq_s32(v_in00, v_col0);

            v_in10 = vaddq_s32(v_in10, v_col0);

            // Quantize - start with (optional) left shift
            if (do_left_shift) {
                v_in00 = vrshlq_s32(v_in00, v_shf0l);

                v_in10 = vrshlq_s32(v_in10, v_shf0l);
            }

            // Then multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);

            // Compute and add on corrective offset
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);

                v_temp00 = vshrq_n_s32(v_temp00, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
            }

            v_in00 = vrshlq_s32(v_in00, v_shf0);

            v_in10 = vrshlq_s32(v_in10, v_shf0);

            v_in00 = vaddq_s32(v_in00, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);

            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in10));

            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz00));

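            // After the unzips, s32 lane 0 of v_uz0 holds the four output
            // bytes for row 0 and lane 1 the four bytes for row 1, so each
            // row is written with a single lane store.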
            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr), vreinterpretq_s32_s8(v_uz0), 0);
            out_ptr += 4;
            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr1), vreinterpretq_s32_s8(v_uz0), 1);
            out_ptr1 += 4;
        }

        if (odds) {
            int32x4_t v_col0 = vdupq_n_s32(0);
            int32x4_t v_in00 = vdupq_n_s32(0);
            int32x4_t v_in10 = vdupq_n_s32(0);
            int32x4_t v_mul0 = vdupq_n_s32(0);
            int32x4_t v_shf0 = vdupq_n_s32(0);
            int32x4_t v_shf0l = vdupq_n_s32(0);

            if (!per_channel) {
                v_mul0 = v_mul;
                v_shf0 = v_right_shift;
                v_shf0l = v_left_shift;
            }

            do {
                v_col0 = vld1q_lane_s32(colptr, v_col0, 0);
                v_in00 = vld1q_lane_s32(in_ptr, v_in00, 0);
                v_in10 = vld1q_lane_s32(in_ptr1, v_in10, 0);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr, v_mul0, 0);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr, v_shf0, 0);
                    if (do_left_shift) {
                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr, v_shf0l, 0);
                    }
                }
                if (odds == 1) { break; }

                v_col0 = vld1q_lane_s32(colptr + 1, v_col0, 1);
                v_in00 = vld1q_lane_s32(in_ptr + 1, v_in00, 1);
                v_in10 = vld1q_lane_s32(in_ptr1 + 1, v_in10, 1);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr + 1, v_mul0, 1);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr + 1, v_shf0, 1);
                    if (do_left_shift) {
                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 1, v_shf0l, 1);
                    }
                }
                if (odds == 2) { break; }

                v_col0 = vld1q_lane_s32(colptr + 2, v_col0, 2);
                v_in00 = vld1q_lane_s32(in_ptr + 2, v_in00, 2);
                v_in10 = vld1q_lane_s32(in_ptr1 + 2, v_in10, 2);
                if (per_channel) {
                    v_mul0 = vld1q_lane_s32(perch_mul_ptr + 2, v_mul0, 2);
                    v_shf0 = vld1q_lane_s32(perch_shift_ptr + 2, v_shf0, 2);
                    if (do_left_shift) {
                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 2, v_shf0l, 2);
                    }
                }
            } while (0);

            // Add on row sum and bias constant
            v_in00 = vaddq_s32(v_in00, v_row_sum);

            v_in10 = vaddq_s32(v_in10, v_row_sum1);

            // Subtract col sum * a_offset
            v_in00 = vaddq_s32(v_in00, v_col0);

            v_in10 = vaddq_s32(v_in10, v_col0);

            // Quantize - start with (optional) left shift
            if (do_left_shift) {
                v_in00 = vrshlq_s32(v_in00, v_shf0l);

                v_in10 = vrshlq_s32(v_in10, v_shf0l);
            }

            // Then multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);

            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);

            // Compute and add on corrective offset
            if (do_shift_correction) {
                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);

                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);

                v_temp00 = vshrq_n_s32(v_temp00, 31);

                v_temp10 = vshrq_n_s32(v_temp10, 31);

                v_in00 = vqaddq_s32(v_in00, v_temp00);

                v_in10 = vqaddq_s32(v_in10, v_temp10);
            }

            v_in00 = vrshlq_s32(v_in00, v_shf0);

            v_in10 = vrshlq_s32(v_in10, v_shf0);

            v_in00 = vaddq_s32(v_in00, v_c_offset);

            v_in10 = vaddq_s32(v_in10, v_c_offset);

            v_in00 = vmaxq_s32(v_in00, v_minval);

            v_in10 = vmaxq_s32(v_in10, v_minval);

            v_in00 = vminq_s32(v_in00, v_maxval);

            v_in10 = vminq_s32(v_in10, v_maxval);

            do {
                vst1q_lane_s8(out_ptr, vreinterpretq_s8_s32(v_in00), 0);
                vst1q_lane_s8(out_ptr1, vreinterpretq_s8_s32(v_in10), 0);

                if (odds==1) { break; }

                vst1q_lane_s8(out_ptr + 1, vreinterpretq_s8_s32(v_in00), 4);
                vst1q_lane_s8(out_ptr1 + 1, vreinterpretq_s8_s32(v_in10), 4);

                if (odds==2) { break; }

                vst1q_lane_s8(out_ptr + 2, vreinterpretq_s8_s32(v_in00), 8);
                vst1q_lane_s8(out_ptr1 + 2, vreinterpretq_s8_s32(v_in10), 8);
            } while(0);
        }
    }
}

} // anonymous namespace

template<typename Tin, typename Tout>
void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                         const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride,
                         const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col) {
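    /* Select the shift-correction variant only when negative pre-offset
     * values could survive the final clamp: if qp.minval >= qp.c_offset,
     * any value that was negative before adding c_offset is clamped to
     * minval regardless of how it was rounded, so the cheaper variant
     * without the correction can be used. */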
    if (qp.per_channel_requant) {
        if (qp.minval >= qp.c_offset) {
            if (qp.per_channel_left_shifts) {
                requantize_block_32_int<false, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<false, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        } else {
            if (qp.per_channel_left_shifts) {
                requantize_block_32_int<true, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<true, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        }
    } else {
        if (qp.minval >= qp.c_offset) {
            if (qp.per_layer_left_shift > 0) {
                requantize_block_32_int<false, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<false, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        } else {
            if (qp.per_layer_left_shift > 0) {
                requantize_block_32_int<true, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            } else {
                requantize_block_32_int<true, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
            }
        }
    }
}

template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                         const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                         const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);

template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                         const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride,
                         const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);

/*
 * Routine (and helpers) to compute row sums needed for offset correction.
 *
 * This is often needed for a lot of short rows (e.g. Syrax 5 - 6400 rows
 * of length 27), therefore it's important not to sacrifice performance on
 * odd length rows.
 *
 * To minimize performance loss in these cases, this routine will overread
 * by up to 7 bytes.
 *
 * This is handled via "mask" and "mask mode" parameters to the inner
 * routines; mask mode == 1 indicates that there are between 1 and 8 bytes
 * (inclusive) needed at the end; in these cases we always read 8 bytes.
 * mask mode == 2 indicates that there are between 9 and 15 bytes needed at
 * the end, and in this case we always read 16 bytes.  In both cases the
 * 'mask' vector is set up so that the read value can be masked off to clear
 * the overread lanes.  This is handled by 'accumulate_masked_8' and
 * 'accumulate_masked_16' below.
 *
 * This routine is templated on the type to be accumulated, because the
 * innermost instruction used needs to be of the correct signedness.
 * However, beyond this point we always use signed values in both cases.
 * The instructions that need to be different are therefore wrapped in
 * helper functions below.
 *
 * The general strategy used is to load vectors of 16 bytes and accumulate
 * (using uadalp/sadalp or AArch32 equivalents) into 8x16-bit accumulators.
 * These are then reduced (using uadalp/sadalp again) into 4x32-bit
 * accumulators.  The 4 accumulators for up to 4 rows being processed are
 * then added together into a single output vector using pairwise adds.
 *
 * This reduction from the 8x16-bit into the 4x32-bit accumulators needs to
 * occur before the 16-bit accumulators can overflow - which is every 32
 * iterations (512 total bytes processed).  This is explained more below.
 */
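/* As a concrete illustration: with rows of length 27 (the example above),
 * the main loop reads one full 16-byte block and 11 bytes remain, so mask
 * mode 2 is used - a full 16-byte vector is read and the mask clears the
 * top 5 overread bytes before they are accumulated. */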
namespace {
    struct row_sum_helpers {
        const Requantize32 &qp;

        /* Load a full 16 byte vector, pairwise accumulate into 'sum' with uadalp or sadalp */
        template<typename T>
        inline int16x8_t accumulate_16(const T *ptr, int16x8_t sum);

        /* Load a full 16 byte vector, but mask before accumulation (see above). */
        template<typename T>
        inline int16x8_t accumulate_masked_16(const T *ptr, int16x8_t sum, uint64x2_t mask);

        /* Load 8 bytes and mask before accumulation. */
        template<typename T>
        inline int16x8_t accumulate_masked_8(const T *ptr, int16x8_t sum, uint64x2_t mask);

        /* This function does the actual work for up to 4 rows at a time.
         * It's pulled out so we can template on the row count to generate
         * the 4 different cases.  4 rows are computed at a time as this
         * reduces to a single vector write.  */
        template<unsigned int rows, typename T>
        void compute_some_rows(unsigned int blocks, const T *input, unsigned int in_stride, int32_t *row_bias, unsigned int mask_mode, uint64x2_t mask, int32x4_t offset_mul) {
            int16x8_t sums[rows];
            int32x4_t finalsums[rows];

            for (unsigned int i=0; i<rows; i++) {
                sums[i]      = vdupq_n_s16(0);
                finalsums[i] = vdupq_n_s32(0);
            }

            for (unsigned int i=0; i<blocks; i++) {
                for (unsigned int r=0; r<rows; r++) {
                    /* If we add too many blocks together, we run the risk
                     * of overflowing the intermediate 16-bit accumulators,
                     * especially in the unsigned case where we later treat
                     * the accumulator as signed.
                     *
                     * In that case, the maximum (signed) value is 16383,
                     * which is safe for 64 (unsigned) accumulations (255*64
                     * = 16,320).
                     *
                     * Each invocation of pairwise add adds 2 values to the
                     * accumulator - so in the unsigned case we can do 32
                     * adds before we need to reset the 16-bit accumulator
                     * by adding into the 32-bit 'finalsums'.
                     *
                     * We could do 64 adds in the signed case, but that
                     * optimization is not worth the complexity.
                     */
                    if (i > 0 && ((i & 31) == 0)) {
                        finalsums[r] = vpadalq_s16(finalsums[r], sums[r]);
                        sums[r] = vdupq_n_s16(0);
                    }
                    sums[r] = accumulate_16(input + (r * in_stride) + (i * 16), sums[r]);
                }
            }

            /* Handle the final masked read if needed. */
            if (mask_mode > 0) {
                for (unsigned int r=0; r<rows; r++) {
                    if (mask_mode == 1) {
                        sums[r] = accumulate_masked_8(input + (r * in_stride) + (blocks * 16), sums[r], mask);
                    } else {
                        sums[r] = accumulate_masked_16(input + (r * in_stride) + (blocks * 16), sums[r], mask);
                    }
                }
            }

            for (unsigned int i=0; i<rows; i++) {
                finalsums[i] = vpadalq_s16(finalsums[i], sums[i]);
            }

            int32x4_t t0, t1;
            int32x2_t t2;

            /* Result writeback - need to write back one value per row
             * processed.  Multiply all the final totals by -b_offset so
             * that the terms can simply be added in the requantize code.
             * */
            switch (rows) {
                case 1:
                    /* If we only have one output, just use ADDV.  Multiply
                     * the offset into all four components separately so it
                     * can stay in the SIMD register file.  */
                    t0 = vmulq_s32(finalsums[0], offset_mul);
                    *row_bias = vaddvq_s32(t0);
                    break;

                case 2:
                    /* For two outputs, two rounds of pairwise adds will
                     * generate the result in a 2-vector we can store in one
                     * go.  */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t0 = vpaddq_s32(t0, t0);
                    t2 = vmul_s32(vget_low_s32(t0), vget_low_s32(offset_mul));
                    vst1_s32(row_bias, t2);
                    break;

                case 3:
                    /* Three rows - need to store the low two words plus the odd value from lane 2 */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t1 = vpaddq_s32(finalsums[2], finalsums[2]);

                    t0 = vpaddq_s32(t0, t1);
                    t0 = vmulq_s32(t0, offset_mul);

                    vst1_s32(row_bias, vget_low_s32(t0));
                    row_bias[2] = vgetq_lane_s32(t0, 2);
                    break;

                case 4:
                    /* Four rows (most common case) - reduce to a single
                     * vector with pairwise adds.  */
                    t0 = vpaddq_s32(finalsums[0], finalsums[1]);
                    t1 = vpaddq_s32(finalsums[2], finalsums[3]);

                    t0 = vpaddq_s32(t0, t1);
                    t0 = vmulq_s32(t0, offset_mul);

                    vst1q_s32(row_bias, t0);
                    break;

                default:
                    UNREACHABLE("Impossible.");
            }
        }

        row_sum_helpers(const Requantize32 &qp) : qp(qp) { }
    };

    template<>
    int16x8_t row_sum_helpers::accumulate_16(const uint8_t *ptr, int16x8_t sum) {
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), vld1q_u8(ptr)));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_16(const int8_t *ptr, int16x8_t sum) {
        return vpadalq_s8(sum, vld1q_s8(ptr));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_16(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        int8x16_t v = vandq_s8(vld1q_s8(ptr), vreinterpretq_s8_u64(mask));
        return vpadalq_s8(sum, v);
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_16(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        uint8x16_t v = vandq_u8(vld1q_u8(ptr), vreinterpretq_u8_u64(mask));
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_8(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        int8x16_t v = vcombine_s8(vld1_s8(ptr), vdup_n_s8(0));
        v = vreinterpretq_s8_u64(vandq_u64(mask, vreinterpretq_u64_s8(v)));
        return vpadalq_s8(sum, v);
    }

    template<>
    int16x8_t row_sum_helpers::accumulate_masked_8(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) {
        uint8x16_t v = vcombine_u8(vld1_u8(ptr), vdup_n_u8(0));
        v = vreinterpretq_u8_u64(vandq_u64(mask, vreinterpretq_u64_u8(v)));
        return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v));
    }
}

template<typename T>
void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int height,
                      const T *input, unsigned int in_stride, int32_t *row_bias) {
    /* If the 'b' offset is zero, just skip this entirely. */
    if (qp.b_offset == 0) {
        memset(row_bias, 0, height * sizeof(int32_t));
        return;
    }

    row_sum_helpers thehelpers(qp);

    const int32x4_t offset_mul = vdupq_n_s32(-qp.b_offset);

    /* Work out how many full vectors of 16 bytes we will read, and how many
     * odd bytes at the end */
    unsigned int blocks = (width / 16);
    const unsigned int odds = width % 16;

    /* Generate a mask to use on the last iteration, if necessary. */
    uint64x2_t mask;
    unsigned int mask_mode = 0;

    if (odds > 0 && odds <= 8) {
        /* 1-8 odds: mask in the low lane, 0 in the top */
        uint64_t maskval = (~0ULL) >> (8 * (8-odds));
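        /* e.g. odds == 3 gives maskval == 0x0000000000ffffff, keeping only
         * the three valid low bytes. */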

        mask = vsetq_lane_u64(maskval, vdupq_n_u64(0), 0);

        mask_mode = 1;
    } else if (odds > 8) {
        /* 9-15 odds: mask in the top lane, all 1s in the bottom. */
        uint64_t maskval = (~0ULL) >> (8 * (16-odds));

        mask = vsetq_lane_u64(maskval, vdupq_n_u64(~0ULL), 1);

        mask_mode = 2;
    }

    for (unsigned int row=0; row<height; row+=4) {
        switch(height-row) {
            default:
            case 4:
                thehelpers.compute_some_rows<4>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 3:
                thehelpers.compute_some_rows<3>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 2:
                thehelpers.compute_some_rows<2>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
            case 1:
                thehelpers.compute_some_rows<1>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul);
                break;
        }
    }
}

/* Instantiate the two versions for uint8_t and int8_t. */
template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *);
template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *);

template<unsigned int active_rows, typename T>
inline void add_block(const T *input, unsigned int in_stride, int32_t *output);

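/* add_block: accumulate a 16-column block of up to four rows into the 16
 * int32 totals at 'output', widening 8 -> 16 -> 32 bits as the rows are
 * summed; rows beyond 'active_rows' contribute zero. */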
template<unsigned int active_rows>
inline void add_block(const uint8_t *input, unsigned int in_stride, int32_t *output) {
    uint8x16_t inputs[4];

    for (unsigned int i=0; i<4; i++) {
        if (i < active_rows) {
            inputs[i] = vld1q_u8(input + i * in_stride);
        } else {
            inputs[i] = vdupq_n_u8(0);
        }
    }

    int16x8_t sums_16b[4];

    // Two adds for the low pairs
    sums_16b[0]=vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[0]), vget_low_u8(inputs[1])));
    sums_16b[1]=vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[2]), vget_low_u8(inputs[3])));
    // Two adds for the high pairs
    sums_16b[2]=vreinterpretq_s16_u16(vaddl_high_u8(inputs[0], inputs[1]));
    sums_16b[3]=vreinterpretq_s16_u16(vaddl_high_u8(inputs[2], inputs[3]));

    int32x4_t sums_32b[4];

    sums_32b[0]=vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1]));
    sums_32b[1]=vaddl_high_s16(sums_16b[0], sums_16b[1]);
    sums_32b[2]=vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3]));
    sums_32b[3]=vaddl_high_s16(sums_16b[2], sums_16b[3]);

    for (unsigned int i=0; i<4; i++) {
        vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i)));
    }
}

template<unsigned int active_rows>
inline void add_block(const int8_t *input, unsigned int in_stride, int32_t *output) {
    int8x16_t inputs[4];

    for (unsigned int i=0; i<4; i++) {
        if (i < active_rows) {
            inputs[i] = vld1q_s8(input + i * in_stride);
        } else {
            inputs[i] = vdupq_n_s8(0);
        }
    }

    int16x8_t sums_16b[4];

    // Two adds for the low pairs
    sums_16b[0]=vaddl_s8(vget_low_s8(inputs[0]), vget_low_s8(inputs[1]));
    sums_16b[1]=vaddl_s8(vget_low_s8(inputs[2]), vget_low_s8(inputs[3]));
    // Two adds for the high pairs
    sums_16b[2]=vaddl_high_s8(inputs[0], inputs[1]);
    sums_16b[3]=vaddl_high_s8(inputs[2], inputs[3]);

    int32x4_t sums_32b[4];

    sums_32b[0]=vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1]));
    sums_32b[1]=vaddl_high_s16(sums_16b[0], sums_16b[1]);
    sums_32b[2]=vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3]));
    sums_32b[3]=vaddl_high_s16(sums_16b[2], sums_16b[3]);

    for (unsigned int i=0; i<4; i++) {
        vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i)));
    }
}

/* "first_col" parameter is used to offset the read into the qp.bias array,
 * in cases where we are not computing the first columns of the output (i.e.
 * in multithreaded cases where we divide columns across threads) */
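/* Purely as an illustration: a thread that has been assigned columns
 * 256..511 of the output would pass first_col == 256 so that the qp.bias
 * reads below line up with the columns it is actually processing. */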
template<typename T>
void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col) {
    /* Only actually add up the columns if a_offset is non-zero. */
    if (qp.a_offset != 0) {
        memset(reinterpret_cast<void *>(col_bias), 0, width * sizeof(int32_t));

        for (unsigned int row=0; row<height; row+=4) {
            unsigned int numrows=std::min(height-row, 4u);

            for (unsigned int col=0; col<width; col+=16) {
                unsigned int numcols=std::min(width-col, 16u);

                if (numcols==16) {
                    switch(numrows) {
                        case 1:
                            add_block<1>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        case 2:
                            add_block<2>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        case 3:
                            add_block<3>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        case 4:
                            add_block<4>(input + row * in_stride + col, in_stride, col_bias + col);
                            break;

                        default:
                            UNREACHABLE("Impossible.");
                    }
                } else {
                    for (; col<width; col++) {
                        int32_t sum=0;
                        for (unsigned int r=0; r<numrows; r++) {
                            sum += input[(row + r)*in_stride + col];
                        }
                        col_bias[col] += sum;
                    }
                }
            }
        }
    }

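    /* Fold the constant term (a_offset * b_offset * depth), the negated
     * column sums and any user-supplied bias into col_bias, so the
     * requantize kernel above only needs to add row_bias and col_bias. */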
    for (unsigned int col=0; col<width; col++) {
        int32_t result = col_bias[col];

        result = (qp.a_offset * qp.b_offset * depth) - (result * qp.a_offset);

        if (qp.bias != nullptr) {
            result += qp.bias[multi * qp.bias_multi_stride + col + first_col];
        }

        col_bias[col] = result;
    }
}

template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);

} // namespace arm_gemm

#endif // __aarch64__