// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert NR % 8 == 0
$assert 8 <= NR <= 16
$assert REQUANTIZATION in ["FP32", "RNDNU"]
$assert not CHANNELWISE or REQUANTIZATION == "FP32"
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/gemm.h>
$if REQUANTIZATION == "FP32" and ARMV8:
  #include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/math.h>


$DATATYPE = "qc8" if CHANNELWISE else "qs8"
$PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("neonv8" if REQUANTIZATION == "FP32" and ARMV8 else "neon")
$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
$ISA = "neonv8" if ARMV8 else "neon"
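// GEMM microkernel template: an ${MR}x${NR} tile of int8 inputs/outputs with
// int32 accumulation. The "c2s4" in the generated name reflects the scheme
// below: weights are packed 2 k-values per column ("c2"), and the A vectors
// are cycled through 4 two-byte rotations ("s4") so that each 4-column
// weight block is multiplied against all 8 k-values of an A load.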
void xnn_${DATATYPE}_gemm_minmax_${REQUANTIZATION.lower()}_ukernel_${MR}x${NR}c2s4__${ISA}_${"mlal" if MLA else "mull"}(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    int8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

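  // Set up one A (input) and one C (output) pointer per row of the tile.
  // Rows past mr are clamped to the previous row, so those lanes compute
  // duplicate results instead of touching memory out of bounds.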
  const int8_t* a0 = a;
  int8_t* c0 = c;
  $for M in range(1, MR):
    const int8_t* a${M} = (const int8_t*) ((uintptr_t) a${M-1} + a_stride);
    int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

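  // The inner loops consume A in 8-byte groups; round kc up to match. The
  // XNN_OOB_READS annotation on the function documents that loading a few
  // bytes past the end of A is tolerated, and the packed weights are assumed
  // to be zero-padded to the same 8-byte granularity.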
  kc = round_up_po2(kc, 8 * sizeof(int8_t));
  do {
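    // Initialize all accumulators from the ${NR} bias values that the
    // packing routine placed at the start of w; every row shares the bias.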
    $for N in range(0, NR, 4):
      int32x4_t vacc0x${ABC[N:N+4]} = vld1q_s32(w); w = (const int32_t*) w + 4;
    $for M in range(1, MR):
      $for N in range(0, NR, 4):
        int32x4_t vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};

    size_t k = kc;
    $if MLA:
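      // Main loop: process 16 k-bytes per iteration. vmull_s8 starts each
      // 16-bit product from the first 8 k-bytes and vmlal_s8 folds in the
      // second 8; vpadalq_s16 then pairwise-accumulates the products into
      // the int32 lanes, pairing the 2 k-values that belong to one column.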
      while (k >= 16 * sizeof(int8_t)) {
        $for M in range(MR):
          int8x8_t va${M}x0 = vld1_s8(a${M}); a${M} += 8;
          int8x8_t va${M}x1 = vld1_s8(a${M}); a${M} += 8;

        $for K in range(4):
          $for N in range(0, NR, 4):
            const int8x8_t vb${ABC[N:N+4]}c${K}x0 = vld1_s8(w); w = (const int8_t*) w + 8;

        $for K in range(4):
          $for N in range(0, NR, 4):
            $for M in range(MR):
              int16x8_t vprod${M}x${ABC[N:N+4]}c${K} = vmull_s8(vb${ABC[N:N+4]}c${K}x0, va${M}x0);
            const int8x8_t vb${ABC[N:N+4]}c${K}x1 = vld1_s8(w); w = (const int8_t*) w + 8;
            $for M in range(MR):
              vprod${M}x${ABC[N:N+4]}c${K} = vmlal_s8(vprod${M}x${ABC[N:N+4]}c${K}, vb${ABC[N:N+4]}c${K}x1, va${M}x1);
            $for M in range(MR):
              vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c${K});
          $if K + 1 != 4:
            $for M in range(MR):
              va${M}x0 = vext_s8(va${M}x0, va${M}x0, 2);
              va${M}x1 = vext_s8(va${M}x1, va${M}x1, 2);

        k -= 16 * sizeof(int8_t);
      }
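    // Tail loop for the MLAL variant, main loop for the MULL-only variant:
    // process 8 k-bytes with vmull_s8 alone, using the same 4-step vext_s8
    // rotation of the A bytes across the 2-column weight blocks.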
    ${"if (k != 0)" if MLA else "do"} {
      $for M in range(MR):
        int8x8_t va${M}x0 = vld1_s8(a${M}); a${M} += 8;

      $for K in range(4):
        $for N in range(0, NR, 4):
          const int8x8_t vb${ABC[N:N+4]}c${K}x0 = vld1_s8(w); w = (const int8_t*) w + 8;

      $for K in range(4):
        $for N in range(0, NR, 4):
          $for M in range(MR):
            int16x8_t vprod${M}x${ABC[N:N+4]}c${K} = vmull_s8(vb${ABC[N:N+4]}c${K}x0, va${M}x0);
          $for M in range(MR):
            vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c${K});
        $if K + 1 != 4:
          $for M in range(MR):
            va${M}x0 = vext_s8(va${M}x0, va${M}x0, 2);

      $if not MLA:
        k -= 8 * sizeof(int8_t);
    }${"" if MLA else " while (k != 0);"}

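    // Requantize the int32 accumulators toward the int8 output domain.
    // RNDNU: saturating pre-shift, doubling-high multiply, rounding
    // post-shift. FP32: scale in single precision, then convert back to
    // integer (vcvtnq_s32_f32 on ARMv8, the magic-bias trick elsewhere).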
    $if REQUANTIZATION == "RNDNU":
      const int32x4_t vright_pre_shift = vld1q_dup_s32(&params->${PARAMS_STRUCT}.right_pre_shift);
      const int32x4_t vmultiplier = vld1q_dup_s32(&params->${PARAMS_STRUCT}.multiplier);
      const int32x4_t vright_post_shift = vld1q_dup_s32(&params->${PARAMS_STRUCT}.right_post_shift);

      $for M in range(MR):
        $for N in range(0, NR, 4):
          vacc${M}x${ABC[N:N+4]} = vqshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_pre_shift);

      $for M in range(MR):
        $for N in range(0, NR, 4):
          vacc${M}x${ABC[N:N+4]} = vqdmulhq_s32(vacc${M}x${ABC[N:N+4]}, vmultiplier);

      $for M in range(MR):
        $for N in range(0, NR, 4):
          vacc${M}x${ABC[N:N+4]} = vrshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_post_shift);
    $elif REQUANTIZATION == "FP32":
      $for M in range(MR):
        $for N in range(0, NR, 4):
          float32x4_t vfpacc${M}x${ABC[N:N+4]} = vcvtq_f32_s32(vacc${M}x${ABC[N:N+4]});

      $if CHANNELWISE:
        $for N in range(0, NR, 4):
          const float32x4_t vscale${ABC[N:N+4]} = vld1q_f32(w); w = (const float*) w + 4;
          $for M in range(MR):
            vfpacc${M}x${ABC[N:N+4]} = vmulq_f32(vfpacc${M}x${ABC[N:N+4]}, vscale${ABC[N:N+4]});
      $else:
        const float32x4_t vscale = vld1q_dup_f32(&params->${PARAMS_STRUCT}.scale);
        $for M in range(MR):
          $for N in range(0, NR, 4):
            vfpacc${M}x${ABC[N:N+4]} = vmulq_f32(vfpacc${M}x${ABC[N:N+4]}, vscale);

      $if ARMV8:
        $for M in range(MR):
          $for N in range(0, NR, 4):
            vacc${M}x${ABC[N:N+4]} = vcvtnq_s32_f32(vfpacc${M}x${ABC[N:N+4]});
      $else:
        const float32x4_t vmagic_bias = vld1q_dup_f32(&params->${PARAMS_STRUCT}.magic_bias);
        $for M in range(MR):
          $for N in range(0, NR, 4):
            vacc${M}x${ABC[N:N+4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${M}x${ABC[N:N+4]}, vmagic_bias));

        const int32x4_t vmagic_bias_less_output_zero_point = vld1q_dup_s32(&params->${PARAMS_STRUCT}.magic_bias_less_output_zero_point);
        $for M in range(MR):
          $for N in range(0, NR, 4):
            vacc${M}x${ABC[N:N+4]} = vqsubq_s32(vacc${M}x${ABC[N:N+4]}, vmagic_bias_less_output_zero_point);

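    // Add the output zero point (on the magic-bias path it is already folded
    // into magic_bias_less_output_zero_point) and narrow to int16, then to
    // int8, with saturation at every step.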
    $if REQUANTIZATION != "FP32" or ARMV8:
      const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->${PARAMS_STRUCT}.output_zero_point);
#if XNN_ARCH_ARM64
    $for M in range(MR):
      $for N in range(0, NR, 8):
        int16x8_t vacc${M}x${ABC[N:N+8]} = vqmovn_high_s32(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vacc${M}x${ABC[N+4:N+8]});

    $if REQUANTIZATION != "FP32" or ARMV8:
      $for M in range(MR):
        $for N in range(0, NR, 8):
          vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vacc${M}x${ABC[N:N+8]}, voutput_zero_point);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          int8x16_t vout${M}x${ABC[N:N+16]} = vqmovn_high_s16(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vacc${M}x${ABC[N+8:N+16]});
        $elif M % 2 == 1:
          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vqmovn_high_s16(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vacc${M}x${ABC[N:N+8]});
        $elif M + 1 == MR:
          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
#else
    $for M in range(MR):
      $for N in range(0, NR, 8):
        int16x8_t vacc${M}x${ABC[N:N+8]} = vcombine_s16(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vqmovn_s32(vacc${M}x${ABC[N+4:N+8]}));

    $if REQUANTIZATION != "FP32" or ARMV8:
      $for M in range(MR):
        $for N in range(0, NR, 8):
          vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vacc${M}x${ABC[N:N+8]}, voutput_zero_point);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          int8x16_t vout${M}x${ABC[N:N+16]} = vcombine_s8(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N+8:N+16]}));
        $elif M % 2 == 1:
          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vcombine_s8(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N:N+8]}));
        $elif M + 1 == MR:
          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
#endif

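    // Clamp the narrowed results to the caller's [output_min, output_max];
    // a 1x8 tile only needs 64-bit clamp vectors.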
    $if NR == 8 and MR == 1:
      const int8x8_t voutput_min = vld1_dup_s8(&params->${PARAMS_STRUCT}.output_min);
    $else:
      const int8x16_t voutput_min = vld1q_dup_s8(&params->${PARAMS_STRUCT}.output_min);
    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          vout${M}x${ABC[N:N+16]} = vmaxq_s8(vout${M}x${ABC[N:N+16]}, voutput_min);
        $elif M % 2 == 1:
          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vmaxq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_min);
        $elif M + 1 == MR:
          $if NR == 8 and MR == 1:
            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, voutput_min);
          $else:
            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_min));

    $if NR == 8 and MR == 1:
      const int8x8_t voutput_max = vld1_dup_s8(&params->${PARAMS_STRUCT}.output_max);
    $else:
      const int8x16_t voutput_max = vld1q_dup_s8(&params->${PARAMS_STRUCT}.output_max);
    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          vout${M}x${ABC[N:N+16]} = vminq_s8(vout${M}x${ABC[N:N+16]}, voutput_max);
        $elif M % 2 == 1:
          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vminq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_max);
        $elif M + 1 == MR:
          $if NR == 8 and MR == 1:
            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, voutput_max);
          $else:
            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_max));

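    // Full tile: store all ${NR} columns of every row, advance the C
    // pointers by cn_stride, and rewind the A pointers by kc for the next
    // column tile.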
    if (nc >= ${NR}) {
      $for M in range(MR):
        $for N in range(0, NR, 16):
          $if N + 8 < NR:
            vst1q_s8(c${M} + ${N}, vout${M}x${ABC[N:N+16]});
          $elif M % 2 == 1:
            vst1_s8(c${M-1} + ${N}, vget_low_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
            vst1_s8(c${M} + ${N}, vget_high_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
          $elif M + 1 == MR:
            vst1_s8(c${M} + ${N}, vout${M}x${ABC[N:N+8]});

      $for M in range(MR):
        c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);

      $for M in range(MR):
        a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);

      nc -= ${NR};
    } else {
      // Final case where not all of the ${NR} columns fit in the destination.
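      // Store progressively smaller power-of-two chunks (8, 4, 2, then 1
      // column), shifting the vout vectors down with vext after each store.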
      $if NR == 16:
        $for M in range(MR):
          $if M % 2 == 1:
            int8x16_t vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_low_s8(vout${M-1}x0123456789ABCDEF), vget_low_s8(vout${M}x0123456789ABCDEF));
          $elif M + 1 == MR:
            int8x8_t vout${M}x01234567 = vget_low_s8(vout${M}x0123456789ABCDEF);
        if (nc & 8) {
          $for M in range(MR):
            $if M % 2 == 1:
              vst1_s8(c${M-1}, vget_low_s8(vout${M-1}x01234567_${M}x01234567)); c${M-1} += 8;
              vst1_s8(c${M}, vget_high_s8(vout${M-1}x01234567_${M}x01234567)); c${M} += 8;
            $elif M + 1 == MR:
              vst1_s8(c${M}, vout${M}x01234567); c${M} += 8;
          $for M in range(MR):
            $if M % 2 == 1:
              vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_high_s8(vout${M-1}x0123456789ABCDEF), vget_high_s8(vout${M}x0123456789ABCDEF));
            $elif M + 1 == MR:
              vout${M}x01234567 = vget_high_s8(vout${M}x0123456789ABCDEF);
        }
      if (nc & 4) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_u32((void*) c${M-1}, vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 4;
            vst1q_lane_u32((void*) c${M}, vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 2); c${M} += 4;
          $elif M + 1 == MR:
            vst1_lane_u32((void*) c${M}, vreinterpret_u32_s8(vout${M}x01234567), 0); c${M} += 4;
        $for M in range(MR):
          $if M % 2 == 1:
            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 4);
          $elif M + 1 == MR:
            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 4);
      }
      if (nc & 2) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_u16((void*) c${M-1}, vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 2;
            vst1q_lane_u16((void*) c${M}, vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 4); c${M} += 2;
          $elif M + 1 == MR:
            vst1_lane_u16((void*) c${M}, vreinterpret_u16_s8(vout${M}x01234567), 0); c${M} += 2;
        $for M in range(MR):
          $if M % 2 == 1:
            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 2);
          $elif M + 1 == MR:
            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 2);
      }
      if (nc & 1) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_s8(c${M-1}, vout${M-1}x01234567_${M}x01234567, 0);
            vst1q_lane_s8(c${M}, vout${M-1}x01234567_${M}x01234567, 8);
          $elif M + 1 == MR:
            vst1_lane_s8(c${M}, vout${M}x01234567, 0);
      }

      nc = 0;
    }
  } while (nc != 0);
}