xref: /aosp_15_r20/external/XNNPACK/src/qs8-gemm/c4-armsimd32.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1// Copyright 2022 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
// Template preconditions: only FP32 requantization is supported, DATATYPE is
// one of QC8/QS8/QU8, and the tile is MR x NR with MR and NR each 1 or 2.
6$assert REQUANTIZATION == "FP32"
7$assert DATATYPE in ["QC8", "QS8", "QU8"]
8$assert 1 <= MR <= 2
9$assert 1 <= NR <= 2
10#include <assert.h>
11
12#include <arm_acle.h>
13
14#include <xnnpack/intrinsics-polyfill.h>
15#include <xnnpack/math.h>
16#include <xnnpack/gemm.h>
17#include <xnnpack/unaligned.h>
18
19
// Names derived from the template parameters: the params member and union
// type for DATATYPE, and the intrinsic / element-type variants. QU8 selects
// the unsigned forms (__uxtb16, __usat, __usub8, uint8_t); QS8 and QC8 use
// the signed forms (__sxtb16, __ssat, __ssub8, int8_t).
20$PARAMS_STRUCT = REQUANTIZATION.lower() + "_armsimd32"
21$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
22$__XXTB16 = "__uxtb16" if DATATYPE == "QU8" else "__sxtb16"
23$__XSAT = "__usat" if DATATYPE == "QU8" else "__ssat"
24$__XSUB8 = "__usub8" if DATATYPE == "QU8" else "__ssub8"
25$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
// GEMM microkernel template producing up to an MR x NR tile of C per
// outer-loop iteration: int32 accumulation of 8-bit A times 8-bit packed
// weights, FP32 requantization, and clamping to [output_min, output_max].
// Uses ARM SIMD32 (ACLE) intrinsics from <arm_acle.h>: __sxtb16/__uxtb16
// byte widening and __smlad dual 16-bit multiply-accumulate.
//
// Arguments:
//   mr        - rows of A/C to process, 1..MR (asserted below). When mr < MR
//               the surplus row pointers alias the previous row.
//   nc        - output columns to produce; consumed NR per outer iteration,
//               with a single-column tail when NR == 2 and nc is odd.
//   kc        - reduction length in elements; rounded up to a multiple of 4
//               because A and the weights are consumed 4 elements at a time.
//   a         - first row of A; each following row starts a_stride bytes later.
//   w         - packed weights. For each group of NR columns the kernel reads
//               NR int32 bias values, then for every 4-element k-step one
//               4-byte chunk per column, then (QC8 only) NR float scales.
//   c         - first row of C; each following row starts cm_stride bytes later.
//   cn_stride - bytes each C row pointer advances after a full NR-column group.
//   params    - requantization parameters: magic_bias,
//               magic_bias_less_zero_point, output_min, output_max, plus scale
//               (all types except QC8) and minus_kernel_zero_point (QU8 only).
26void xnn_${DATATYPE.lower()}_gemm_minmax_${REQUANTIZATION.lower()}_ukernel_${MR}x${NR}c4__armsimd32(
27    size_t mr,
28    size_t nc,
29    size_t kc,
30    const ${XINT8_T}* restrict a,
31    size_t a_stride,
32    const void* restrict w,
33    ${XINT8_T}* restrict c,
34    size_t cm_stride,
35    size_t cn_stride,
36    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)])
37{
38  assert(mr != 0);
39  assert(mr <= ${MR});
40  assert(nc != 0);
41  assert(kc != 0);
42
  // A is read 4 elements per step, so the reduction length is rounded up to
  // a multiple of 4. NOTE(review): rows of A may therefore be read up to 3
  // bytes past the logical kc; presumably the caller guarantees this is safe.
43  kc = round_up_po2(kc, 4 * sizeof(int8_t));
44  const ${XINT8_T}* a0 = a;
45  ${XINT8_T}* c0 = c;
  // Rows beyond mr alias the previous row, so every row pointer stays valid;
  // the aliased rows are computed and stored redundantly to the same place.
46  $for M in range(1, MR):
47    const ${XINT8_T}* a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M-1} + a_stride);
48    ${XINT8_T}* c${M} = (${XINT8_T}*) ((uintptr_t) c${M-1} + cm_stride);
49    $if M % 2 == 0:
50      if XNN_UNPREDICTABLE(mr <= ${M}) {
51        a${M} = a${M-1};
52        c${M} = c${M-1};
53      }
54    $elif M + 1 == MR:
55      if XNN_UNPREDICTABLE(mr != ${M+1}) {
56        a${M} = a${M-1};
57        c${M} = c${M-1};
58      }
59    $else:
60      if XNN_UNPREDICTABLE(mr < ${M+1}) {
61        a${M} = a${M-1};
62        c${M} = c${M-1};
63      }
64
  // Loop-invariant quantization parameters.
65  $if DATATYPE == "QU8":
66    const int16x2_t vb_minus_zero_point = (int16x2_t) params->${PARAMS_STRUCT}.minus_kernel_zero_point;
67  $if REQUANTIZATION == "FP32":
68    $if DATATYPE != "QC8":
69      const float vscale = params->${PARAMS_STRUCT}.scale;
70    const float vmagic_bias = params->${PARAMS_STRUCT}.magic_bias;
  // Outer loop: one group of NR output columns (or the final tail) per pass.
71  do {
    // Seed the accumulators with the group's NR int32 biases; every row
    // starts from the same bias values.
72    $for N in range(NR):
73      int32_t vacc0x${N} = ((const int32_t*) w)[${N}];
74    $for M in range(1, MR):
75      $for N in range(NR):
76        int32_t vacc${M}x${N} = vacc0x${N};
77    w = (const void*) ((const int32_t*) w + ${NR});
78
    // Reduction over kc, 4 elements per step.
79    size_t k = kc;
80    do {
      // Load 4 consecutive elements from each row of A and advance the row.
81      $for M in range(MR):
82        const int8x4_t va${M} = (int8x4_t) unaligned_load_s32(a${M}); a${M} += 4;
83
      // Widen the 4 bytes into two 16x2 vectors: c02 holds elements 0 and 2,
      // c13 holds elements 1 and 3 (the word is rotated right by 8 bits
      // first). Extension is signed, or unsigned for QU8.
84      $for M in range(MR):
85        const int16x2_t va${M}c02 = ${__XXTB16}(va${M});
86        const int16x2_t va${M}c13 = ${__XXTB16}(__ror(va${M}, 8));
87
      // For each column, load its 4 packed weight bytes (a direct int8x4_t
      // dereference, so w is presumably kept suitably aligned by the packing
      // code) and widen them the same way; for QU8, __uxtab16 also adds
      // minus_kernel_zero_point to each 16-bit lane. __smlad adds both 16x16
      // products of a lane pair to the accumulator, so the c02 and c13 steps
      // together accumulate all 4 elements.
88      $for N in range(NR):
89        const int8x4_t vb${N} = *((const int8x4_t*) w); w = (const int8_t*) w + 4;
90        $if DATATYPE == "QU8":
91          const int16x2_t vb${N}c02 = __uxtab16(vb_minus_zero_point, vb${N});
92        $else:
93          const int16x2_t vb${N}c02 = __sxtb16(vb${N});
94
95        $for M in range(MR):
96          vacc${M}x${N} = __smlad(va${M}c02, vb${N}c02, vacc${M}x${N});
97
98        $if DATATYPE == "QU8":
99          const int16x2_t vb${N}c13 = __uxtab16(vb_minus_zero_point, __ror(vb${N}, 8));
100        $else:
101          const int16x2_t vb${N}c13 = __sxtb16(__ror(vb${N}, 8));
102        $for M in range(MR):
103          vacc${M}x${N} = __smlad(va${M}c13, vb${N}c13, vacc${M}x${N});
104
105      k -= 4 * sizeof(${XINT8_T});
106    } while (k != 0);
107
    // FP32 requantization: convert the accumulators to float and scale them,
    // using per-column scales read from w for QC8 or the single params scale
    // otherwise.
108    $for M in range(MR):
109      $for N in range(NR):
110        float vfpacc${M}x${N} = (float) vacc${M}x${N};
111
112    $if DATATYPE == "QC8":
113      $for N in range(NR):
114        const float vscale${N} = ((const float*) w)[${N}];
115        $for M in range(MR):
116          vfpacc${M}x${N} *= vscale${N};
117      w = (const void*) ((const float*) w + ${NR});
118    $else:
119      $for M in range(MR):
120        $for N in range(NR):
121          vfpacc${M}x${N} *= vscale;
122
    // Round with the magic-bias technique: after adding magic_bias, the
    // float's bit pattern, reinterpreted as int32, carries the rounded value
    // in its low bits. Subtracting magic_bias_less_zero_point (presumably the
    // bit pattern of magic_bias minus the output zero point; set up outside
    // this file) with the saturating __qsub leaves value plus zero point.
123    $for M in range(MR):
124      $for N in range(NR):
125        vfpacc${M}x${N} += vmagic_bias;
126
127    $for M in range(MR):
128      $for N in range(NR):
129        int32_t vout${M}x${N} = (int32_t) float_as_uint32(vfpacc${M}x${N});
130
    // NOTE(review): loaded here rather than hoisted next to magic_bias, like
    // output_min/output_max below; possibly deliberate to limit register
    // pressure. Confirm before moving.
131    const int32_t vmagic_bias_less_zero_point = params->${PARAMS_STRUCT}.magic_bias_less_zero_point;
132    $for M in range(MR):
133      $for N in range(NR):
134        vout${M}x${N} = __qsub(vout${M}x${N}, vmagic_bias_less_zero_point);
135
    // Saturate to 8 bits: [-128, 127] via __ssat, or [0, 255] via __usat for QU8.
136    $for M in range(MR):
137      $for N in range(NR):
138        vout${M}x${N} = ${__XSAT}(vout${M}x${N}, 8);
139
    // Pack the tile into one 32-bit word. Each row uses a 16-bit half (row 0
    // low, row 1 high) whose low byte is column 0 and, for NR == 2, whose
    // high byte is column 1. Other bytes are never stored.
140    $for M in range(MR):
141      $if NR == 1:
142        const uint32_t vout${M} = (uint32_t) vout${M}x0;
143      $else:
144        const uint32_t vout${M} = (uint32_t) (uint8_t) vout${M}x0 | ((uint32_t) vout${M}x1 << 8);
145
146    $if MR == 1:
147      uint32_t vout = vout0;
148    $else:
149      uint32_t vout = (uint32_t) (uint16_t) vout0 | (vout1 << 16);
150
    // Clamp every byte to [output_min, output_max]. The result of each
    // __ssub8/__usub8 is intentionally discarded: the call only sets the
    // per-byte GE flags that the following __sel reads, so no other
    // GE-setting operation may come between a subtract and its __sel.
    // __sel(x, y) picks x's byte where GE is set and y's elsewhere, so the
    // first pair yields max(vout, output_min) and the second
    // min(vout, output_max). output_min/output_max presumably hold the bound
    // replicated in all four bytes.
151    const int8x4_t voutput_min = (int8x4_t) params->${PARAMS_STRUCT}.output_min;
152    ${__XSUB8}((int8x4_t) vout, voutput_min);
153    vout = (uint32_t) __sel((uint8x4_t) vout, (uint8x4_t) voutput_min);
154
155    const int8x4_t voutput_max = (int8x4_t) params->${PARAMS_STRUCT}.output_max;
156    ${__XSUB8}((int8x4_t) vout, voutput_max);
157    vout = (uint32_t) __sel((uint8x4_t) voutput_max, (uint8x4_t) vout);
158
    // Store the tile, then rewind each A row by kc (the bytes consumed by the
    // reduction loop) so the same rows feed the next column group, and
    // advance each C row by cn_stride.
159    $if NR == 2:
      // Full 2-column group: one 16-bit store per row via unaligned_store_u16.
160      if XNN_LIKELY(nc >= ${NR}) {
161        $for M in range(MR):
162          unaligned_store_u16(c${M}, (uint16_t) vout);
163          $if M + 1 != MR:
164            vout >>= 16;
165
166        $for M in range(MR):
167          a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc);
168
169        $for M in range(MR):
170          c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);
171
172        nc -= ${NR};
173      } else {
        // Tail (nc == 1): store only column 0 of each row and finish.
174        $for M in range(MR):
175          *c${M} = (${XINT8_T}) vout;
176          $if M + 1 != MR:
177            vout >>= 16;
178
179        nc = 0;
180      }
181    $else:
      // NR == 1: one byte per row per iteration.
182      $for M in range(MR):
183        *c${M} = (${XINT8_T}) vout;
184        $if M + 1 != MR:
185          vout >>= 16;
186
187      $for M in range(MR):
188        a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc);
189
190      $for M in range(MR):
191        c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);
192
193      nc -= 1;
194  } while (nc != 0);
195}
196