xref: /aosp_15_r20/external/XNNPACK/src/qs8-gemm/MRx4c2s4-sse.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1// Copyright 2022 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert SSE in [2, 4]
7$assert not XOP or AVX
8$assert not AVX or SSE == 4
9$assert REQUANTIZATION == "FP32"
10$assert DATATYPE in ["QC8", "QS8", "QU8"]
11$assert VARIANT in ["LD64", "LD128", "EXTENDED"]
12$assert MR <= 4
13#include <assert.h>
14
15$if XOP:
16  #if defined(__GNUC__) || defined(__clang__)
17    #include <x86intrin.h>
18  #else
19    #include <immintrin.h>
20    #include <ammintrin.h>
21  #endif
22$else:
23  $SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE]
24  #include <${SSE_HEADER}>
25
26#include <xnnpack/gemm.h>
27#include <xnnpack/math.h>
28#include <xnnpack/unaligned.h>
29
30
31
32$LOAD_SUFFIX = {"LD128": "_ld128", "LD64": "_ld64", "EXTENDED": ""}[VARIANT]
33$GEMM_SUFFIX = "_xw" if VARIANT == "EXTENDED" else ""
34$PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("sse4" if SSE == 4 and DATATYPE != "QU8" else "sse2")
35$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
36$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
37$ISA = "xop" if XOP else "avx" if AVX else {2: "sse2", 4: "sse41"}[SSE]
// Quantized 8-bit GEMM micro-kernel with FP32 requantization and min/max
// clamping. Each pass of the outer loop produces a tile of up to MR rows by
// 4 columns. The kernel name, element type and params struct are derived from
// the DATATYPE, VARIANT, SSE, AVX and XOP template parameters.
//
// Parameters:
//   mr        - number of valid rows, 1 <= mr <= MR (asserted).
//   nc        - number of output columns still to produce; consumed 4 at a
//               time, with a 1..3 column remainder handled at the end.
//   kc        - reduction length in elements (one byte each); rounded up to a
//               multiple of 8 inside the kernel.
//   a         - first row of the input matrix; row M is a_stride bytes after
//               row M-1.
//   w         - packed weights. For every group of 4 columns the kernel reads
//               4 int32 accumulator seeds, then 32 weight elements per
//               reduction step (8-bit, or 16-bit for the EXTENDED variant),
//               then, for QC8 only, 4 float scales.
//               NOTE(review): the seeds are presumably the biases; confirm
//               against the weight-packing routine.
//   c         - first row of the output; row M is cm_stride bytes after row
//               M-1, and all rows advance by cn_stride bytes per column group.
//   params    - requantization constants read through the selected params
//               struct (scale, zero points and clamping bounds).
38void xnn_${DATATYPE.lower()}_gemm${GEMM_SUFFIX}_minmax_fp32_ukernel_${MR}x4c2s4__${ISA}${LOAD_SUFFIX}(
39    size_t mr,
40    size_t nc,
41    size_t kc,
42    const ${XINT8_T}* restrict a,
43    size_t a_stride,
44    const void* restrict w,
45    ${XINT8_T}* restrict c,
46    size_t cm_stride,
47    size_t cn_stride,
48    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
49{
  // Sanity-check the arguments.
50  assert(mr != 0);
51  assert(mr <= ${MR});
52  assert(nc != 0);
53  assert(kc != 0);
54  assert(kc % sizeof(${XINT8_T}) == 0);
55  assert(a != NULL);
56  assert(w != NULL);
57  assert(c != NULL);
58
  // The reduction loop below consumes 8 elements of every A row per iteration,
  // so kc is rounded up to a multiple of 8 elements.
  // NOTE(review): reading past the caller's kc is presumably what
  // XNN_OOB_READS declares; confirm that the packed weights are zero-padded
  // to match.
59  kc = round_up_po2(kc, 8 * sizeof(${XINT8_T}));
  // Per-row input and output pointers. A row index at or beyond mr aliases the
  // previous row for both input and output, so such a row recomputes and
  // rewrites the same data as the row it aliases.
60  const ${XINT8_T}* a0 = a;
61  ${XINT8_T}* c0 = c;
62  $for M in range(1, MR):
63    const ${XINT8_T}* a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M-1} + a_stride);
64    ${XINT8_T}* c${M} = (${XINT8_T}*) ((uintptr_t) c${M-1} + cm_stride);
65    $if M % 2 == 0:
66      if XNN_UNPREDICTABLE(mr <= ${M}) {
67        a${M} = a${M-1};
68        c${M} = c${M-1};
69      }
70    $elif M + 1 == MR:
71      if XNN_UNPREDICTABLE(mr != ${M+1}) {
72        a${M} = a${M-1};
73        c${M} = c${M-1};
74      }
75    $else:
76      if XNN_UNPREDICTABLE(mr < ${M+1}) {
77        a${M} = a${M-1};
78        c${M} = c${M-1};
79      }
80
  // Outer loop: one iteration per group of 4 output columns.
81  do {
    // Seed the accumulators of every row with the same 4 int32 values taken
    // from the head of the packed weights, then step w past them.
82    __m128i vacc0x0123 = _mm_loadu_si128((const __m128i*) w);
83    $for M in range(1, MR):
84      __m128i vacc${M}x0123 = vacc0x0123;
85    w = (const void*) ((const int32_t*) w + 4);
86
87    size_t k = kc;
88    $if DATATYPE == "QU8":
89      const __m128i vb_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.kernel_zero_point);
90      $if SSE < 4 or VARIANT == "LD128":
91        const __m128i vzero = _mm_setzero_si128();
    // Reduction loop: each iteration consumes 8 elements from every A row and
    // 32 weight elements.
92    do {
      // Load 8 bytes from each A row and widen them to 16-bit lanes
      // (zero-extended for QU8, sign-extended otherwise).
93      $for M in range(MR):
94        const __m128i va${M} = _mm_loadl_epi64((const __m128i*) a${M});
95        $if DATATYPE == "QU8":
96          $if SSE == 4:
97            __m128i vxa${M} = _mm_cvtepu8_epi16(va${M});
98          $else:
99            __m128i vxa${M} = _mm_unpacklo_epi8(va${M}, vzero);
100        $else:
101          $if SSE == 4:
102            __m128i vxa${M} = _mm_cvtepi8_epi16(va${M});
103          $else:
104            __m128i vxa${M} = _mm_srai_epi16(_mm_unpacklo_epi8(va${M}, va${M}), 8);
105        a${M} += 8;
106
      // Multiply-accumulate against four weight groups vxb0..vxb3, each widened
      // to 16-bit lanes (QU8 weights also have the kernel zero point
      // subtracted). Every group is combined with the A vector by a multiply-add
      // of adjacent 16-bit pairs into the 32-bit accumulators. Between groups
      // the A vector's four 32-bit lanes are rotated by one position
      // (_MM_SHUFFLE(0, 3, 2, 1)); the rotation is skipped after the last group.
      // LD128 fetches two groups per 128-bit load, LD64 one group per 64-bit
      // load, and EXTENDED reads weights that are already 16-bit using aligned
      // 128-bit loads.
107      $if VARIANT == "LD128":
108        $for K in range(0, 4, 2):
109          $if K == 0:
110            const __m128i vb${K}${K+1} = _mm_loadu_si128((const __m128i*) w);
111          $else:
112            const __m128i vb${K}${K+1} = _mm_loadu_si128((const __m128i*) ((const ${XINT8_T}*) w + ${K * 8}));
113          $if DATATYPE == "QU8":
114            const __m128i vxb${K} = _mm_sub_epi16(_mm_unpacklo_epi8(vb${K}${K+1}, vzero), vb_zero_point);
115            const __m128i vxb${K+1} = _mm_sub_epi16(_mm_unpackhi_epi8(vb${K}${K+1}, vzero), vb_zero_point);
116          $elif SSE == 4:
117            const __m128i vxb${K} = _mm_cvtepi8_epi16(vb${K}${K+1});
118            const __m128i vxb${K+1} = _mm_srai_epi16(_mm_unpackhi_epi8(vb${K}${K+1}, vb${K}${K+1}), 8);
119          $else:
120            const __m128i vsb${K}${K+1} = _mm_cmpgt_epi8(_mm_setzero_si128(), vb${K}${K+1});
121            const __m128i vxb${K} = _mm_unpacklo_epi8(vb${K}${K+1}, vsb${K}${K+1});
122            const __m128i vxb${K+1} = _mm_unpackhi_epi8(vb${K}${K+1}, vsb${K}${K+1});
123
124          $for M in range(MR):
125            $if XOP:
126              vacc${M}x0123 = _mm_maddd_epi16(vxa${M}, vxb${K}, vacc${M}x0123);
127            $else:
128              vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123, _mm_madd_epi16(vxa${M}, vxb${K}));
129            vxa${M} = _mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(0, 3, 2, 1));
130
131          $for M in range(MR):
132            $if XOP:
133              vacc${M}x0123 = _mm_maddd_epi16(vxa${M}, vxb${K+1}, vacc${M}x0123);
134            $else:
135              vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123, _mm_madd_epi16(vxa${M}, vxb${K+1}));
136            $if K + 2 != 4:
137              vxa${M} = _mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(0, 3, 2, 1));
138      $else:
139        $for K in range(4):
140          $if VARIANT == "LD64":
141            $if K == 0:
142              const __m128i vb${K} = _mm_loadl_epi64((const __m128i*) w);
143            $else:
144              const __m128i vb${K} = _mm_loadl_epi64((const __m128i*) ((const ${XINT8_T}*) w + ${K * 8}));
145            $if DATATYPE == "QU8":
146              $if SSE == 4:
147                const __m128i vxb${K} = _mm_sub_epi16(_mm_cvtepu8_epi16(vb${K}), vb_zero_point);
148              $else:
149                const __m128i vxb${K} = _mm_sub_epi16(_mm_unpacklo_epi8(vb${K}, vzero), vb_zero_point);
150            $else:
151              $if SSE == 4:
152                const __m128i vxb${K} = _mm_cvtepi8_epi16(vb${K});
153              $else:
154                const __m128i vxb${K} = _mm_srai_epi16(_mm_unpacklo_epi8(vb${K}, vb${K}), 8);
155          $elif VARIANT == "EXTENDED":
156            $if K == 0:
157              const __m128i vxb${K} = _mm_load_si128((const __m128i*) w);
158            $else:
159              const __m128i vxb${K} = _mm_load_si128((const __m128i*) ((const int16_t*) w + ${K * 8}));
160
161          $for M in range(MR):
162            $if XOP:
163              vacc${M}x0123 = _mm_maddd_epi16(vxa${M}, vxb${K}, vacc${M}x0123);
164            $else:
165              vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123, _mm_madd_epi16(vxa${M}, vxb${K}));
166            $if K + 1 != 4:
167              vxa${M} = _mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(0, 3, 2, 1));
168
      // Step w past the 32 weight elements consumed by this iteration
      // (16-bit elements for EXTENDED, 8-bit elements otherwise).
169      $if VARIANT == "EXTENDED":
170        w = (const void*) ((const int16_t*) w + 32);
171      $else:
172        w = (const void*) ((const ${XINT8_T}*) w + 32);
173      k -= 8 * sizeof(${XINT8_T});
174    } while (k != 0);
175
    // Requantization: convert the int32 accumulators to float and multiply by
    // the scale. QC8 reads 4 float scales from w right after the weights of this
    // column group; the other data types read a scale vector from params.
176    $for M in range(MR):
177      __m128 vscaled${M}x0123 = _mm_cvtepi32_ps(vacc${M}x0123);
178
179    $if DATATYPE == "QC8":
180      const __m128 vscale0123 = _mm_loadu_ps((const float*) w);
181      w = (const void*) ((const float*) w + 4);
182      $for M in range(MR):
183        vscaled${M}x0123 = _mm_mul_ps(vscaled${M}x0123, vscale0123);
184    $else:
185      const __m128 vscale = _mm_load_ps(params->${PARAMS_STRUCT}.scale);
186      $for M in range(MR):
187        vscaled${M}x0123 = _mm_mul_ps(vscaled${M}x0123, vscale);
188
    // Apply the upper clamp while still in the float domain (the bound is
    // stored relative to the output zero point, per the field name), then
    // convert back to int32.
189    const __m128 voutput_max_less_zero_point = _mm_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point);
190    $for M in range(MR):
191      vscaled${M}x0123 = _mm_min_ps(vscaled${M}x0123, voutput_max_less_zero_point);
192
193    $for M in range(MR):
194      vacc${M}x0123 = _mm_cvtps_epi32(vscaled${M}x0123);
195
    // Narrow pairs of rows to 16 bits with signed saturation and add the output
    // zero point with saturation. When MR is odd the last row is paired with
    // itself.
196    const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_zero_point);
197    $for M in range(0, MR, 2):
198      __m128i vacc${M}${min(M+1, MR-1)}x0123 = _mm_adds_epi16(_mm_packs_epi32(vacc${M}x0123, vacc${min(M+1, MR-1)}x0123), voutput_zero_point);
199
    // Narrow to 8 bits (unsigned saturation for QU8, signed otherwise) and apply
    // the lower clamp. vout ends up with 4 bytes per row in consecutive 32-bit
    // lanes; for MR <= 2 the low 64 bits are duplicated into the high 64 bits.
    // For signed outputs the lower clamp runs on 16-bit lanes before packing
    // when SSE < 4 and on 8-bit lanes after packing when SSE == 4
    // (_mm_max_epi8 is an SSE4.1 intrinsic).
200    $if DATATYPE == "QU8":
201      $if MR > 2:
202        __m128i vout = _mm_packus_epi16(vacc0${min(1, MR-1)}x0123, vacc${min(2, MR-1)}${min(3, MR-1)}x0123);
203      $else:
204        __m128i vout = _mm_packus_epi16(vacc0${min(1, MR-1)}x0123, vacc0${min(1, MR-1)}x0123);
205
206      vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min));
207    $else:
208      $if SSE < 4:
209        const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min);
210        $for M in range(0, MR, 2):
211          vacc${M}${min(M+1, MR-1)}x0123 = _mm_max_epi16(vacc${M}${min(M+1, MR-1)}x0123, voutput_min);
212
213      $if MR > 2:
214        __m128i vout = _mm_packs_epi16(vacc0${min(1, MR-1)}x0123, vacc${min(2, MR-1)}${min(3, MR-1)}x0123);
215      $else:
216        __m128i vout = _mm_packs_epi16(vacc0${min(1, MR-1)}x0123, vacc0${min(1, MR-1)}x0123);
217
218      $if SSE == 4:
219        vout = _mm_max_epi8(vout, _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min));
220
    // Full tile: store 4 bytes per row, advance every output pointer by
    // cn_stride, and rewind every A pointer by the (rounded-up) kc so the next
    // column group re-reads the same rows. When SSE < 4, vout is rotated by one
    // 32-bit lane before each subsequent row's store so that row's bytes sit in
    // lane 0.
221    if (nc >= 4) {
222      unaligned_store_u32(c0, (uint32_t) _mm_cvtsi128_si32(vout));
223      $for M in range(1, MR):
224        $if SSE == 4:
225          unaligned_store_u32(c${M}, (uint32_t) _mm_extract_epi32(vout, ${M}));
226        $else:
227          vout = _mm_shuffle_epi32(vout, _MM_SHUFFLE(0, 3, 2, 1));
228          unaligned_store_u32(c${M}, (uint32_t) _mm_cvtsi128_si32(vout));
229
230      $for M in range(MR):
231        c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);
232
233      $for M in range(MR):
234        a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc);
235
236      nc -= 4;
237    } else {
      // Remainder of 1..3 columns: store 2 columns and/or 1 column per row,
      // then terminate the outer loop by zeroing nc.
238      if (nc & 2) {
239        $for M in range(MR):
240          unaligned_store_u16(c${M}, (uint16_t) _mm_extract_epi16(vout, ${M * 2}));
241          c${M} += 2;
        // Shift each 32-bit lane right by 16 bits so the third column of every
        // row moves into the lowest byte of that row's lane.
242        vout = _mm_srli_epi32(vout, 16);
243      }
244      if (nc & 1) {
        // Store the lowest byte of each row's lane. The SSE < 4 path extracts a
        // 16-bit word (or the low 32 bits for row 0) and relies on the cast to
        // the 8-bit element type to keep only the low byte.
245        $if SSE == 4:
246          $for M in range(MR):
247            *c${M} = (${XINT8_T}) _mm_extract_epi8(vout, ${M * 4});
248        $else:
249          *c0 = (${XINT8_T}) _mm_cvtsi128_si32(vout);
250          $for M in range(1, MR):
251            *c${M} = (${XINT8_T}) _mm_extract_epi16(vout, ${M * 2});
252      }
253
254      nc = 0;
255    }
256  } while (nc != 0);
257}
258