// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert NR % 8 == 0
$assert 8 <= NR <= 16
$assert REQUANTIZATION in ["FP32", "RNDNU"]
$assert not CHANNELWISE or REQUANTIZATION == "FP32"
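// This template generates QS8/QC8 IGEMM micro-kernels with an MRxNR output tile that
// consume 8 bytes of K per step (c8). MLA selects the 2x-unrolled VMLAL path,
// REQUANTIZATION selects the FP32 or RNDNU requantization scheme, CHANNELWISE switches
// to per-channel (QC8) scales, and ARMV8 enables ARMv8-only rounding instructions.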
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/igemm.h>
$if REQUANTIZATION == "FP32" and ARMV8:
  #include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/math.h>


$DATATYPE = "qc8" if CHANNELWISE else "qs8"
$PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("neonv8" if REQUANTIZATION == "FP32" and ARMV8 else "neon")
$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
$ISA = "neonv8" if ARMV8 else "neon"
void xnn_${DATATYPE}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_${MR}x${NR}c8__${ISA}_${"mlal" if MLA else "mull"}(
    size_t mr,
    size_t nc,
    size_t kc,
    size_t ks,
    const int8_t** restrict a,
    const void* restrict w,
    int8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    size_t a_offset,
    const int8_t* zero,
    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(ks != 0);
  assert(ks % (${MR} * sizeof(void*)) == 0);
  assert(a_offset % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

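  // The packed weights are consumed 8 bytes of K at a time, so round kc up to a multiple of 8.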
  kc = round_up_po2(kc, 8 * sizeof(int8_t));
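  // Set up the output row pointers; rows past mr alias the previous row, so the redundant stores stay in bounds.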
  int8_t* c0 = c;
  $for M in range(1, MR):
    int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        c${M} = c${M-1};
      }

  do {
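    // Initialize the accumulators with the per-channel bias values from the packed weights.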
    $for N in range(NR):
      int32x4_t vacc0x${N} = vld1q_lane_s32(w, vmovq_n_s32(0), 0); w = (const void*) ((uintptr_t) w + sizeof(int32_t));
    $for M in range(1, MR):
      $for N in range(NR):
        int32x4_t vacc${M}x${N} = vacc0x${N};

    size_t p = ks;
    do {
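      // Load the next group of MR input row pointers from the indirection buffer; rows that
      // point at the zero buffer are not adjusted by a_offset.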
      $for M in range(MR):
        const int8_t* restrict a${M} = a[${M}];
        if XNN_UNPREDICTABLE(a${M} != zero) {
          a${M} = (const int8_t*) ((uintptr_t) a${M} + a_offset);
        }
      a += ${MR};

      size_t k = kc;
      $if MLA:
        // 2x partial unrolled loop to load 16 bytes at a time using MLA.
        while (k >= 16 * sizeof(int8_t)) {
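          // Load 16 bytes per row and multiply-accumulate: vmull_s8 widens the first 8-byte
          // product to 16 bits, vmlal_s8 adds the second 8-byte slice, and vpadalq_s16
          // pairwise-accumulates the result into the 32-bit lanes.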
          $for M in range(MR):
            const int8x8_t va${M}x0 = vld1_s8(a${M}); a${M} += 8;
            const int8x8_t va${M}x1 = vld1_s8(a${M}); a${M} += 8;

          $for N in range(NR):
            const int8x8_t vb${N}x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

          $for N in range(NR):
            const int8x8_t vb${N}x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
            $for M in range(MR):
              int16x8_t vprod${M}x${N} = vmull_s8(vb${N}x0, va${M}x0);
            $for M in range(MR):
              vprod${M}x${N} = vmlal_s8(vprod${M}x${N}, vb${N}x1, va${M}x1);
            $for M in range(MR):
              vacc${M}x${N} = vpadalq_s16(vacc${M}x${N}, vprod${M}x${N});

          k -= 16 * sizeof(int8_t);
        }

      // Handle 8 bytes at a time using MUL.
      ${"if" if MLA else "while"} (k != 0) {
        $for M in range(MR):
          const int8x8_t va${M} = vld1_s8(a${M}); a${M} += 8;

        $for N in range(NR):
          const int8x8_t vb${N} = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
          $for M in range(MR):
            const int16x8_t vprod${M}x${N} = vmull_s8(vb${N}, va${M});
          $for M in range(MR):
            vacc${M}x${N} = vpadalq_s16(vacc${M}x${N}, vprod${M}x${N});

        k -= 8 * sizeof(int8_t);
      }

      p -= ${MR} * sizeof(void*);
    } while (p != 0);

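    // Reduce each per-channel int32x4_t accumulator to a single 32-bit sum and pack four
    // channels per vector (vpaddq_s32 on AArch64, vpadd_s32/vcombine_s32 otherwise).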
#if XNN_ARCH_ARM64
    $for M in range(MR):
      $for N in range(0, NR, 4):
        const int32x4_t vsum${M}x${ABC[N:N+2]} = vpaddq_s32(vacc${M}x${N}, vacc${M}x${N+1});
        const int32x4_t vsum${M}x${ABC[N+2:N+4]} = vpaddq_s32(vacc${M}x${N+2}, vacc${M}x${N+3});

    $for M in range(MR):
      $for N in range(0, NR, 4):
        int32x4_t vacc${M}x${ABC[N:N+4]} = vpaddq_s32(vsum${M}x${ABC[N:N+2]}, vsum${M}x${ABC[N+2:N+4]});
#else
    $for M in range(MR):
      $for N in range(0, NR, 4):
        const int32x2_t vpsum${M}x${ABC[N]} = vadd_s32(vget_low_s32(vacc${M}x${N}), vget_high_s32(vacc${M}x${N}));
        const int32x2_t vpsum${M}x${ABC[N+1]} = vadd_s32(vget_low_s32(vacc${M}x${N+1}), vget_high_s32(vacc${M}x${N+1}));
        const int32x2_t vpsum${M}x${ABC[N+2]} = vadd_s32(vget_low_s32(vacc${M}x${N+2}), vget_high_s32(vacc${M}x${N+2}));
        const int32x2_t vpsum${M}x${ABC[N+3]} = vadd_s32(vget_low_s32(vacc${M}x${N+3}), vget_high_s32(vacc${M}x${N+3}));
        const int32x2_t vsum${M}x${ABC[N:N+2]} = vpadd_s32(vpsum${M}x${ABC[N]}, vpsum${M}x${ABC[N+1]});
        const int32x2_t vsum${M}x${ABC[N+2:N+4]} = vpadd_s32(vpsum${M}x${ABC[N+2]}, vpsum${M}x${ABC[N+3]});
        int32x4_t vacc${M}x${ABC[N:N+4]} = vcombine_s32(vsum${M}x${ABC[N:N+2]}, vsum${M}x${ABC[N+2:N+4]});
#endif

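    // Requantize the 32-bit sums: RNDNU uses a saturating pre-shift, doubling-high multiply, and
    // rounding post-shift; FP32 converts to float and applies a per-tensor or per-channel scale.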
    $if REQUANTIZATION == "RNDNU":
      const int32x4_t vright_pre_shift = vld1q_dup_s32(&params->${PARAMS_STRUCT}.right_pre_shift);
      const int32x4_t vmultiplier = vld1q_dup_s32(&params->${PARAMS_STRUCT}.multiplier);
      const int32x4_t vright_post_shift = vld1q_dup_s32(&params->${PARAMS_STRUCT}.right_post_shift);

      $for M in range(MR):
        $for N in range(0, NR, 4):
          vacc${M}x${ABC[N:N+4]} = vqshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_pre_shift);

      $for M in range(MR):
        $for N in range(0, NR, 4):
          vacc${M}x${ABC[N:N+4]} = vqdmulhq_s32(vacc${M}x${ABC[N:N+4]}, vmultiplier);

      $for M in range(MR):
        $for N in range(0, NR, 4):
          vacc${M}x${ABC[N:N+4]} = vrshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_post_shift);
    $elif REQUANTIZATION == "FP32":
      $for M in range(MR):
        $for N in range(0, NR, 4):
          float32x4_t vfpacc${M}x${ABC[N:N+4]} = vcvtq_f32_s32(vacc${M}x${ABC[N:N+4]});

      $if CHANNELWISE:
        $for N in range(0, NR, 4):
          const float32x4_t vscale${ABC[N:N+4]} = vld1q_f32((const float*) w); w = (const void*) ((const float*) w + 4);
          $for M in range(MR):
            vfpacc${M}x${ABC[N:N+4]} = vmulq_f32(vfpacc${M}x${ABC[N:N+4]}, vscale${ABC[N:N+4]});
      $else:
        const float32x4_t vscale = vld1q_dup_f32(&params->${PARAMS_STRUCT}.scale);
        $for M in range(MR):
          $for N in range(0, NR, 4):
            vfpacc${M}x${ABC[N:N+4]} = vmulq_f32(vfpacc${M}x${ABC[N:N+4]}, vscale);

      $if ARMV8:
        $for M in range(MR):
          $for N in range(0, NR, 4):
            vacc${M}x${ABC[N:N+4]} = vcvtnq_s32_f32(vfpacc${M}x${ABC[N:N+4]});
      $else:
        const float32x4_t vmagic_bias = vld1q_dup_f32(&params->${PARAMS_STRUCT}.magic_bias);
        $for M in range(MR):
          $for N in range(0, NR, 4):
            vacc${M}x${ABC[N:N+4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${M}x${ABC[N:N+4]}, vmagic_bias));

        const int32x4_t vmagic_bias_less_output_zero_point = vld1q_dup_s32(&params->${PARAMS_STRUCT}.magic_bias_less_output_zero_point);
        $for M in range(MR):
          $for N in range(0, NR, 4):
            vacc${M}x${ABC[N:N+4]} = vqsubq_s32(vacc${M}x${ABC[N:N+4]}, vmagic_bias_less_output_zero_point);

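    // Narrow to 16 bits with saturation and add the output zero point (the non-ARMv8 FP32 path
    // already folded the zero point in via the magic bias), then narrow to 8 bits.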
    $if REQUANTIZATION != "FP32" or ARMV8:
      const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->${PARAMS_STRUCT}.output_zero_point);
#if XNN_ARCH_ARM64
    $for M in range(MR):
      $for N in range(0, NR, 8):
        int16x8_t vacc${M}x${ABC[N:N+8]} = vqmovn_high_s32(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vacc${M}x${ABC[N+4:N+8]});

    $if REQUANTIZATION != "FP32" or ARMV8:
      $for M in range(MR):
        $for N in range(0, NR, 8):
          vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vacc${M}x${ABC[N:N+8]}, voutput_zero_point);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          int8x16_t vout${M}x${ABC[N:N+16]} = vqmovn_high_s16(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vacc${M}x${ABC[N+8:N+16]});
        $elif M % 2 == 1:
          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vqmovn_high_s16(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vacc${M}x${ABC[N:N+8]});
        $elif M + 1 == MR:
          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
#else
    $for M in range(MR):
      $for N in range(0, NR, 8):
        int16x8_t vacc${M}x${ABC[N:N+8]} = vcombine_s16(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vqmovn_s32(vacc${M}x${ABC[N+4:N+8]}));

    $if REQUANTIZATION != "FP32" or ARMV8:
      $for M in range(MR):
        $for N in range(0, NR, 8):
          vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vacc${M}x${ABC[N:N+8]}, voutput_zero_point);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          int8x16_t vout${M}x${ABC[N:N+16]} = vcombine_s8(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N+8:N+16]}));
        $elif M % 2 == 1:
          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vcombine_s8(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N:N+8]}));
        $elif M + 1 == MR:
          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
#endif

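    // Clamp the 8-bit results to the [output_min, output_max] range.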
    $if NR == 8 and MR == 1:
      const int8x8_t voutput_min = vld1_dup_s8(&params->${PARAMS_STRUCT}.output_min);
    $else:
      const int8x16_t voutput_min = vld1q_dup_s8(&params->${PARAMS_STRUCT}.output_min);
    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          vout${M}x${ABC[N:N+16]} = vmaxq_s8(vout${M}x${ABC[N:N+16]}, voutput_min);
        $elif M % 2 == 1:
          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vmaxq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_min);
        $elif M + 1 == MR:
          $if NR == 8 and MR == 1:
            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, voutput_min);
          $else:
            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_min));

    $if NR == 8 and MR == 1:
      const int8x8_t voutput_max = vld1_dup_s8(&params->${PARAMS_STRUCT}.output_max);
    $else:
      const int8x16_t voutput_max = vld1q_dup_s8(&params->${PARAMS_STRUCT}.output_max);
    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          vout${M}x${ABC[N:N+16]} = vminq_s8(vout${M}x${ABC[N:N+16]}, voutput_max);
        $elif M % 2 == 1:
          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vminq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_max);
        $elif M + 1 == MR:
          $if NR == 8 and MR == 1:
            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, voutput_max);
          $else:
            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_max));

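    // Full tile: store all NR output channels for each row, advance the output pointers by
    // cn_stride, and rewind the indirection pointer by ks.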
    if (nc >= ${NR}) {
      $for M in reversed(range(MR)):
        $for N in range(0, NR, 16):
          $if N + 8 < NR:
            vst1q_s8(c${M} + ${N}, vout${M}x${ABC[N:N+16]});
          $elif M % 2 == 1:
            vst1_s8(c${M} + ${N}, vget_high_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
            vst1_s8(c${M-1} + ${N}, vget_low_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
          $elif M + 1 == MR:
            vst1_s8(c${M} + ${N}, vout${M}x${ABC[N:N+8]});

      $for M in reversed(range(MR)):
        c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);

      a = (const int8_t**restrict) ((uintptr_t) a - ks);

      nc -= ${NR};
    } else {
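      // Partial tile: store the remaining nc < NR channels in 8/4/2/1-element pieces.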
      $if NR == 16:
        $for M in reversed(range(MR)):
          $if M % 2 == 1:
            int8x16_t vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_low_s8(vout${M-1}x0123456789ABCDEF), vget_low_s8(vout${M}x0123456789ABCDEF));
          $elif M + 1 == MR:
            int8x8_t vout${M}x01234567 = vget_low_s8(vout${M}x0123456789ABCDEF);
        if (nc & 8) {
          $for M in reversed(range(MR)):
            $if M % 2 == 1:
              vst1_s8(c${M}, vget_high_s8(vout${M-1}x01234567_${M}x01234567)); c${M} += 8;
              vst1_s8(c${M-1}, vget_low_s8(vout${M-1}x01234567_${M}x01234567)); c${M-1} += 8;
            $elif M + 1 == MR:
              vst1_s8(c${M}, vout${M}x01234567); c${M} += 8;
          $for M in reversed(range(MR)):
            $if M % 2 == 1:
              vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_high_s8(vout${M-1}x0123456789ABCDEF), vget_high_s8(vout${M}x0123456789ABCDEF));
            $elif M + 1 == MR:
              vout${M}x01234567 = vget_high_s8(vout${M}x0123456789ABCDEF);
        }
      if (nc & 4) {
        $for M in reversed(range(MR)):
          $if M % 2 == 1:
            vst1q_lane_u32((void*) c${M}, vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 2); c${M} += 4;
            vst1q_lane_u32((void*) c${M-1}, vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 4;
          $elif M + 1 == MR:
            vst1_lane_u32((void*) c${M}, vreinterpret_u32_s8(vout${M}x01234567), 0); c${M} += 4;
        $for M in reversed(range(MR)):
          $if M % 2 == 1:
            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 4);
          $elif M + 1 == MR:
            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 4);
      }
      if (nc & 2) {
        $for M in reversed(range(MR)):
          $if M % 2 == 1:
            vst1q_lane_u16((void*) c${M}, vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 4); c${M} += 2;
            vst1q_lane_u16((void*) c${M-1}, vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 2;
          $elif M + 1 == MR:
            vst1_lane_u16((void*) c${M}, vreinterpret_u16_s8(vout${M}x01234567), 0); c${M} += 2;
        $for M in reversed(range(MR)):
          $if M % 2 == 1:
            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 2);
          $elif M + 1 == MR:
            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 2);
      }
      if (nc & 1) {
        $for M in reversed(range(MR)):
          $if M % 2 == 1:
            vst1q_lane_s8(c${M}, vout${M-1}x01234567_${M}x01234567, 8);
            vst1q_lane_s8(c${M-1}, vout${M-1}x01234567_${M}x01234567, 0);
          $elif M + 1 == MR:
            vst1_lane_s8(c${M}, vout${M}x01234567, 0);
      }

      nc = 0;
    }
  } while (nc != 0);
}