// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert NR % 8 == 0
$assert 8 <= NR <= 16
$assert REQUANTIZATION == "RNDNU"
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/common.h>
#include <xnnpack/gemm.h>

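// GEMM microkernel template for QS8 (signed 8-bit) inputs with min/max
// activation clamping and RNDNU requantization. It computes an ${MR}x${NR}
// tile of the output using NEON VMULL.S8 multiplies against duplicated A
// lanes (vdup_lane_s8) and widening VADDW.S16 accumulation into int32.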
void xnn_qs8_gemm_minmax_rndnu_ukernel_${MR}x${NR}__neon_mull_addw_dup(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    int8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

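  // Set up per-row pointers into A and C. When mr < ${MR}, the pointers for
  // the unused rows alias the previous row, so those rows compute and store
  // duplicate results instead of reading or writing out of bounds.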
  const int8_t* a0 = a;
  int8_t* c0 = c;
  $for M in range(1, MR):
    const int8_t* a${M} = (const int8_t*) ((uintptr_t) a${M-1} + a_stride);
    int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

  do {
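    // Initialize the accumulators with the bias values stored at the start of
    // the packed weights block; rows 1 and up start from the same biases as
    // row 0.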
    $for N in range(0, NR, 4):
      int32x4_t vacc0x${ABC[N:N+4]} = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 4 * sizeof(int32_t));
    $for M in range(1, MR):
      $for N in range(0, NR, 4):
        int32x4_t vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};

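    // Main loop: consume 8 K values per iteration. Each lane of the A vector
    // is broadcast with vdup_lane_s8, multiplied by a row of B with vmull_s8
    // (int8x8 -> int16x8), and widened into the int32 accumulators with
    // vaddw_s16.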
    size_t k = kc;
    while (k >= 8 * sizeof(int8_t)) {
      $for M in range(MR):
        const int8x8_t va${M} = vld1_s8(a${M}); a${M} += 8;

      $for K in range(8):
        $for N in range(0, NR, 8):
          const int8x8_t vb${ABC[N:N+8]}c${K} = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

          $for M in range(MR):
            const int16x8_t vprod${M}x${ABC[N:N+8]}c${K} = vmull_s8(vb${ABC[N:N+8]}c${K}, vdup_lane_s8(va${M}, ${K}));
            vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c${K}));
            vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c${K}));

      k -= 8 * sizeof(int8_t);
    }
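    // Remainder: process the last 1-7 K values with an unrolled ladder of
    // conditional blocks. The 8-byte A loads may read past the end of the
    // remainder; the kernel is annotated XNN_OOB_READS to permit this.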
    if XNN_UNLIKELY(k != 0) {
      $for M in range(MR):
        const int8x8_t va${M} = vld1_s8(a${M}); a${M} = (const int8_t*) ((uintptr_t) a${M} + k);

      $for N in range(0, NR, 8):
        const int8x8_t vb${ABC[N:N+8]}c0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

      $for M in range(MR):
        $for N in range(0, NR, 8):
          const int16x8_t vprod${M}x${ABC[N:N+8]}c0 = vmull_s8(vb${ABC[N:N+8]}c0, vdup_lane_s8(va${M}, 0));
          vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c0));
          vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c0));

      if (k >= 2 * sizeof(int8_t)) {
        $for N in range(0, NR, 8):
          const int8x8_t vb${ABC[N:N+8]}c1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

        $for M in range(MR):
          $for N in range(0, NR, 8):
            const int16x8_t vprod${M}x${ABC[N:N+8]}c1 = vmull_s8(vb${ABC[N:N+8]}c1, vdup_lane_s8(va${M}, 1));
            vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c1));
            vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c1));

        if (k > 2 * sizeof(int8_t)) {
          $for N in range(0, NR, 8):
            const int8x8_t vb${ABC[N:N+8]}c2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

          $for M in range(MR):
            $for N in range(0, NR, 8):
              const int16x8_t vprod${M}x${ABC[N:N+8]}c2 = vmull_s8(vb${ABC[N:N+8]}c2, vdup_lane_s8(va${M}, 2));
              vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c2));
              vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c2));

          if (k >= 4 * sizeof(int8_t)) {
            $for N in range(0, NR, 8):
              const int8x8_t vb${ABC[N:N+8]}c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

            $for M in range(MR):
              $for N in range(0, NR, 8):
                const int16x8_t vprod${M}x${ABC[N:N+8]}c3 = vmull_s8(vb${ABC[N:N+8]}c3, vdup_lane_s8(va${M}, 3));
                vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c3));
                vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c3));

            if (k > 4 * sizeof(int8_t)) {
              $for N in range(0, NR, 8):
                const int8x8_t vb${ABC[N:N+8]}c4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

              $for M in range(MR):
                $for N in range(0, NR, 8):
                  const int16x8_t vprod${M}x${ABC[N:N+8]}c4 = vmull_s8(vb${ABC[N:N+8]}c4, vdup_lane_s8(va${M}, 4));
                  vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c4));
                  vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c4));

              if (k >= 6 * sizeof(int8_t)) {
                $for N in range(0, NR, 8):
                  const int8x8_t vb${ABC[N:N+8]}c5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

                $for M in range(MR):
                  $for N in range(0, NR, 8):
                    const int16x8_t vprod${M}x${ABC[N:N+8]}c5 = vmull_s8(vb${ABC[N:N+8]}c5, vdup_lane_s8(va${M}, 5));
                    vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c5));
                    vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c5));

                if (k > 6 * sizeof(int8_t)) {
                  $for N in range(0, NR, 8):
                    const int8x8_t vb${ABC[N:N+8]}c6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

                  $for M in range(MR):
                    $for N in range(0, NR, 8):
                      const int16x8_t vprod${M}x${ABC[N:N+8]}c6 = vmull_s8(vb${ABC[N:N+8]}c6, vdup_lane_s8(va${M}, 6));
                      vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c6));
                      vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c6));
                }
              }
            }
          }
        }
      }
    }

    // Requantize the int32 accumulators down to int8 outputs.
    const int32x4_t vright_pre_shift = vld1q_dup_s32(&params->rndnu_neon.right_pre_shift);
    const int32x4_t vmultiplier = vld1q_dup_s32(&params->rndnu_neon.multiplier);
    const int32x4_t vright_post_shift = vld1q_dup_s32(&params->rndnu_neon.right_post_shift);

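    // RNDNU requantization in three steps: vqshlq_s32 applies the saturating
    // pre-shift (a negative shift count shifts right), vqdmulhq_s32 is a
    // saturating doubling multiply returning the high 32 bits, and vrshlq_s32
    // applies the rounding post-shift.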
    $for M in range(MR):
      $for N in range(0, NR, 4):
        vacc${M}x${ABC[N:N+4]} = vqshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_pre_shift);

    $for M in range(MR):
      $for N in range(0, NR, 4):
        vacc${M}x${ABC[N:N+4]} = vqdmulhq_s32(vacc${M}x${ABC[N:N+4]}, vmultiplier);

    $for M in range(MR):
      $for N in range(0, NR, 4):
        vacc${M}x${ABC[N:N+4]} = vrshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_post_shift);

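    // Narrow int32 -> int16 with saturation and add the output zero point,
    // then narrow int16 -> int8 with saturation. The AArch64 path uses the
    // *_high narrowing forms to pack two vectors per instruction; the AArch32
    // path combines the halves explicitly with vcombine.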
    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->rndnu_neon.output_zero_point);
#if XNN_ARCH_ARM64
    $for M in range(MR):
      $for N in range(0, NR, 8):
        const int16x8_t vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vacc${M}x${ABC[N+4:N+8]}), voutput_zero_point);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          int8x16_t vout${M}x${ABC[N:N+16]} = vqmovn_high_s16(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vacc${M}x${ABC[N+8:N+16]});
        $elif M % 2 == 1:
          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vqmovn_high_s16(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vacc${M}x${ABC[N:N+8]});
        $elif M + 1 == MR:
          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
#else
    $for M in range(MR):
      $for N in range(0, NR, 8):
        const int16x8_t vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vqmovn_s32(vacc${M}x${ABC[N+4:N+8]})), voutput_zero_point);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          int8x16_t vout${M}x${ABC[N:N+16]} = vcombine_s8(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N+8:N+16]}));
        $elif M % 2 == 1:
          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vcombine_s8(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N:N+8]}));
        $elif M + 1 == MR:
          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
#endif
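    // Clamp the int8 results to the [output_min, output_max] activation range.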
    $if NR == 8 and MR == 1:
      const int8x8_t voutput_min = vld1_dup_s8(&params->rndnu_neon.output_min);
      const int8x8_t voutput_max = vld1_dup_s8(&params->rndnu_neon.output_max);
    $else:
      const int8x16_t voutput_min = vld1q_dup_s8(&params->rndnu_neon.output_min);
      const int8x16_t voutput_max = vld1q_dup_s8(&params->rndnu_neon.output_max);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          vout${M}x${ABC[N:N+16]} = vmaxq_s8(vout${M}x${ABC[N:N+16]}, voutput_min);
        $elif M % 2 == 1:
          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vmaxq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_min);
        $elif M + 1 == MR:
          $if NR == 8 and MR == 1:
            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, voutput_min);
          $else:
            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_min));

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          vout${M}x${ABC[N:N+16]} = vminq_s8(vout${M}x${ABC[N:N+16]}, voutput_max);
        $elif M % 2 == 1:
          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vminq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_max);
        $elif M + 1 == MR:
          $if NR == 8 and MR == 1:
            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, voutput_max);
          $else:
            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_max));

    if (nc >= ${NR}) {
      // Main case where the ${NR} columns fit in the destination.
      $for M in range(MR):
        $for N in range(0, NR, 16):
          $if N + 8 < NR:
            vst1q_s8(c${M} + ${N}, vout${M}x${ABC[N:N+16]});
          $elif M % 2 == 1:
            vst1_s8(c${M-1} + ${N}, vget_low_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
            vst1_s8(c${M} + ${N}, vget_high_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
          $elif M + 1 == MR:
            vst1_s8(c${M} + ${N}, vout${M}x${ABC[N:N+8]});

      // Advance to the next ${NR} columns.
      $for M in range(MR):
        c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);

      $for M in range(MR):
        a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);

      nc -= ${NR};
    } else {
      // Final case where not all of the ${NR} columns fit in the destination.
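      // Store the remaining columns by decomposing nc into its binary
      // components (8, 4, 2, 1 elements); after each partial store, the
      // surviving lanes are shifted down with vext so lane 0 holds the next
      // element to write.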
      $if NR == 16:
        $for M in range(MR):
          $if M % 2 == 1:
            int8x16_t vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_low_s8(vout${M-1}x0123456789ABCDEF), vget_low_s8(vout${M}x0123456789ABCDEF));
          $elif M + 1 == MR:
            int8x8_t vout${M}x01234567 = vget_low_s8(vout${M}x0123456789ABCDEF);
        if (nc & 8) {
          $for M in range(MR):
            $if M % 2 == 1:
              vst1_s8(c${M-1}, vget_low_s8(vout${M-1}x01234567_${M}x01234567)); c${M-1} += 8;
              vst1_s8(c${M}, vget_high_s8(vout${M-1}x01234567_${M}x01234567)); c${M} += 8;
            $elif M + 1 == MR:
              vst1_s8(c${M}, vout${M}x01234567); c${M} += 8;
          $for M in range(MR):
            $if M % 2 == 1:
              vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_high_s8(vout${M-1}x0123456789ABCDEF), vget_high_s8(vout${M}x0123456789ABCDEF));
            $elif M + 1 == MR:
              vout${M}x01234567 = vget_high_s8(vout${M}x0123456789ABCDEF);
        }
      if (nc & 4) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_u32((void*) c${M-1}, vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 4;
            vst1q_lane_u32((void*) c${M}, vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 2); c${M} += 4;
          $elif M + 1 == MR:
            vst1_lane_u32((void*) c${M}, vreinterpret_u32_s8(vout${M}x01234567), 0); c${M} += 4;
        $for M in range(MR):
          $if M % 2 == 1:
            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 4);
          $elif M + 1 == MR:
            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 4);
      }
      if (nc & 2) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_u16((void*) c${M-1}, vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 2;
            vst1q_lane_u16((void*) c${M}, vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 4); c${M} += 2;
          $elif M + 1 == MR:
            vst1_lane_u16((void*) c${M}, vreinterpret_u16_s8(vout${M}x01234567), 0); c${M} += 2;
        $for M in range(MR):
          $if M % 2 == 1:
            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 2);
          $elif M + 1 == MR:
            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 2);
      }
      if (nc & 1) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_s8(c${M-1}, vout${M-1}x01234567_${M}x01234567, 0);
            vst1q_lane_s8(c${M}, vout${M-1}x01234567_${M}x01234567, 8);
          $elif M + 1 == MR:
            vst1_lane_s8(c${M}, vout${M}x01234567, 0);
      }

      nc = 0;
    }
  } while (nc != 0);
}