// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

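// Template for the MRx4c2 GEMM microkernels: SSE selects the baseline ISA
// (2 = SSE2, 4 = SSE4.1), with the XOP/AVX flags layered on top of SSE4.1;
// DATATYPE selects the QC8/QS8/QU8 quantization scheme; VARIANT selects how
// the packed weights are consumed (LD64/LD128 load 8-bit weights, EXTENDED
// reads pre-widened 16-bit weights and is emitted as the "_xw" kernels).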
$assert SSE in [2, 4]
$assert not XOP or AVX
$assert not AVX or SSE == 4
$assert REQUANTIZATION == "FP32"
$assert DATATYPE in ["QC8", "QS8", "QU8"]
$assert VARIANT in ["LD64", "LD128", "EXTENDED"]
$assert MR <= 4
#include <assert.h>

$if XOP:
  #if defined(__GNUC__) || defined(__clang__)
    #include <x86intrin.h>
  #else
    #include <immintrin.h>
    #include <ammintrin.h>
  #endif
$else:
  $SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE]
  #include <${SSE_HEADER}>

#include <xnnpack/gemm.h>
#include <xnnpack/math.h>
#include <xnnpack/unaligned.h>



$LOAD_SUFFIX = {"LD128": "_ld128", "LD64": "_ld64", "EXTENDED": ""}[VARIANT]
$GEMM_SUFFIX = "_xw" if VARIANT == "EXTENDED" else ""
$PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("sse4" if SSE == 4 and DATATYPE != "QU8" else "sse2")
$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
$ISA = "xop" if XOP else "avx" if AVX else {2: "sse2", 4: "sse41"}[SSE]
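// Computes an mr x nc (at most ${MR}x4) tile of the output: 8-bit inputs are
// widened to 16 bits, multiply-accumulated into 32-bit lanes two k values at
// a time, and requantized to 8 bits using fp32 arithmetic.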
void xnn_${DATATYPE.lower()}_gemm${GEMM_SUFFIX}_minmax_fp32_ukernel_${MR}x4c2__${ISA}${LOAD_SUFFIX}(
    size_t mr,
    size_t nc,
    size_t kc,
    const ${XINT8_T}* restrict a,
    size_t a_stride,
    const void* restrict w,
    ${XINT8_T}* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(${XINT8_T}) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

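  // Inputs and packed weights are consumed two k values at a time ("c2"
  // layout), so round kc up to a multiple of 2 elements.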
  kc = round_up_po2(kc, 2 * sizeof(${XINT8_T}));
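  // Set up per-row input/output pointers; rows beyond mr alias the previous
  // row so accesses for unused rows stay within valid memory.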
  const ${XINT8_T}* a0 = a;
  ${XINT8_T}* c0 = c;
  $for M in range(1, MR):
    const ${XINT8_T}* a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M-1} + a_stride);
    ${XINT8_T}* c${M} = (${XINT8_T}*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

  do {
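    // Each row's accumulators start from the same 4 packed int32 values at
    // the head of w (the per-channel bias in XNNPACK's weight packing).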
    __m128i vacc0x0123 = _mm_loadu_si128((const __m128i*) w);
    $for M in range(1, MR):
      __m128i vacc${M}x0123 = vacc0x0123;
    w = (const void*) ((const int32_t*) w + 4);

    size_t k = kc;
    $if DATATYPE == "QU8":
      const __m128i vb_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.kernel_zero_point);
      $if SSE < 4 or VARIANT == "LD128":
        const __m128i vzero = _mm_setzero_si128();
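    // Main loop: process 8 k values per row per iteration (4 k-pairs against
    // the 4 output columns), widening each row of a to 16 bits first.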
    while (k >= 8 * sizeof(${XINT8_T})) {
      $for M in range(MR):
        const __m128i va${M} = _mm_loadl_epi64((const __m128i*) a${M});
        $if DATATYPE == "QU8":
          $if SSE == 4:
            const __m128i vxa${M} = _mm_cvtepu8_epi16(va${M});
          $else:
            const __m128i vxa${M} = _mm_unpacklo_epi8(va${M}, vzero);
        $else:
          $if SSE == 4:
            const __m128i vxa${M} = _mm_cvtepi8_epi16(va${M});
          $else:
            const __m128i vxa${M} = _mm_srai_epi16(_mm_unpacklo_epi8(va${M}, va${M}), 8);
        a${M} += 8;

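      // Load the next 4 groups of packed weights (8-bit for LD64/LD128,
      // pre-widened 16-bit for EXTENDED), widen and zero-point-adjust them as
      // needed, then accumulate each k-pair with _mm_madd_epi16 (or
      // _mm_maddd_epi16 on XOP) against the broadcast k-pair of each row.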
      $if VARIANT == "LD128":
        $for K in range(0, 4, 2):
          $if K == 0:
            const __m128i vb${K}${K+1} = _mm_loadu_si128((const __m128i*) w);
          $else:
            const __m128i vb${K}${K+1} = _mm_loadu_si128((const __m128i*) ((const ${XINT8_T}*) w + ${K * 8}));
          $if DATATYPE == "QU8":
            const __m128i vxb${K} = _mm_sub_epi16(_mm_unpacklo_epi8(vb${K}${K+1}, vzero), vb_zero_point);
            const __m128i vxb${K+1} = _mm_sub_epi16(_mm_unpackhi_epi8(vb${K}${K+1}, vzero), vb_zero_point);
          $elif SSE == 4:
            const __m128i vxb${K} = _mm_cvtepi8_epi16(vb${K}${K+1});
            const __m128i vxb${K+1} = _mm_srai_epi16(_mm_unpackhi_epi8(vb${K}${K+1}, vb${K}${K+1}), 8);
          $else:
            const __m128i vsb${K}${K+1} = _mm_cmpgt_epi8(_mm_setzero_si128(), vb${K}${K+1});
            const __m128i vxb${K} = _mm_unpacklo_epi8(vb${K}${K+1}, vsb${K}${K+1});
            const __m128i vxb${K+1} = _mm_unpackhi_epi8(vb${K}${K+1}, vsb${K}${K+1});

          $for M in range(MR):
            $if XOP:
              vacc${M}x0123 = _mm_maddd_epi16(
                _mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(${K}, ${K}, ${K}, ${K})), vxb${K}, vacc${M}x0123);
            $else:
              vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123,
                _mm_madd_epi16(_mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(${K}, ${K}, ${K}, ${K})), vxb${K}));

          $for M in range(MR):
            $if XOP:
              vacc${M}x0123 = _mm_maddd_epi16(
                _mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(${K+1}, ${K+1}, ${K+1}, ${K+1})), vxb${K+1}, vacc${M}x0123);
            $else:
              vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123,
                _mm_madd_epi16(_mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(${K+1}, ${K+1}, ${K+1}, ${K+1})), vxb${K+1}));
      $else:
        $for K in range(4):
          $if VARIANT == "LD64":
            $if K == 0:
              const __m128i vb${K} = _mm_loadl_epi64((const __m128i*) w);
            $else:
              const __m128i vb${K} = _mm_loadl_epi64((const __m128i*) ((const ${XINT8_T}*) w + ${K * 8}));
            $if DATATYPE == "QU8":
              $if SSE == 4:
                const __m128i vxb${K} = _mm_sub_epi16(_mm_cvtepu8_epi16(vb${K}), vb_zero_point);
              $else:
                const __m128i vxb${K} = _mm_sub_epi16(_mm_unpacklo_epi8(vb${K}, vzero), vb_zero_point);
            $else:
              $if SSE == 4:
                const __m128i vxb${K} = _mm_cvtepi8_epi16(vb${K});
              $else:
                const __m128i vxb${K} = _mm_srai_epi16(_mm_unpacklo_epi8(vb${K}, vb${K}), 8);
          $elif VARIANT == "EXTENDED":
            $if K == 0:
              const __m128i vxb${K} = _mm_load_si128((const __m128i*) w);
            $else:
              const __m128i vxb${K} = _mm_load_si128((const __m128i*) ((const int16_t*) w + ${K * 8}));

          $for M in range(MR):
            $if XOP:
              vacc${M}x0123 = _mm_maddd_epi16(
                _mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(${K}, ${K}, ${K}, ${K})), vxb${K}, vacc${M}x0123);
            $else:
              vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123,
                _mm_madd_epi16(_mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(${K}, ${K}, ${K}, ${K})), vxb${K}));

      $if VARIANT == "EXTENDED":
        w = (const void*) ((const int16_t*) w + 32);
      $else:
        w = (const void*) ((const ${XINT8_T}*) w + 32);
      k -= 8 * sizeof(${XINT8_T});
    }
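    // Remainder: up to 6 k values are left (kc was rounded up to a multiple
    // of 2), handled as 1 to 3 additional k-pairs.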
    if (k != 0) {
      $for M in range(MR):
        const __m128i va${M} = _mm_loadl_epi64((const __m128i*) a${M});
        $if DATATYPE == "QU8":
          $if SSE == 4:
            const __m128i vxa${M} = _mm_cvtepu8_epi16(va${M});
          $else:
            const __m128i vxa${M} = _mm_unpacklo_epi8(va${M}, vzero);
        $else:
          $if SSE == 4:
            const __m128i vxa${M} = _mm_cvtepi8_epi16(va${M});
          $else:
            const __m128i vxa${M} = _mm_srai_epi16(_mm_unpacklo_epi8(va${M}, va${M}), 8);
        a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} + k);

      $if VARIANT == "EXTENDED":
        const __m128i vxb0 = _mm_load_si128((const __m128i*) w);
        w = (const void*) ((const int16_t*) w + 8);
      $else:
        const __m128i vb0 = _mm_loadl_epi64((const __m128i*) w);
        $if DATATYPE == "QU8":
          $if SSE == 4:
            const __m128i vxb0 = _mm_sub_epi16(_mm_cvtepu8_epi16(vb0), vb_zero_point);
          $else:
            const __m128i vxb0 = _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point);
        $else:
          $if SSE == 4:
            const __m128i vxb0 = _mm_cvtepi8_epi16(vb0);
          $else:
            const __m128i vxb0 = _mm_srai_epi16(_mm_unpacklo_epi8(vb0, vb0), 8);
        w = (const void*) ((const ${XINT8_T}*) w + 8);

      $for M in range(MR):
        $if XOP:
          vacc${M}x0123 = _mm_maddd_epi16(
            _mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(0, 0, 0, 0)), vxb0, vacc${M}x0123);
        $else:
          vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123,
            _mm_madd_epi16(_mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(0, 0, 0, 0)), vxb0));

      if (k > 2 * sizeof(${XINT8_T})) {
        $if VARIANT == "EXTENDED":
          const __m128i vxb1 = _mm_load_si128((const __m128i*) w);
          w = (const void*) ((const int16_t*) w + 8);
        $else:
          const __m128i vb1 = _mm_loadl_epi64((const __m128i*) w);
          $if DATATYPE == "QU8":
            $if SSE == 4:
              const __m128i vxb1 = _mm_sub_epi16(_mm_cvtepu8_epi16(vb1), vb_zero_point);
            $else:
              const __m128i vxb1 = _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point);
          $else:
            $if SSE == 4:
              const __m128i vxb1 = _mm_cvtepi8_epi16(vb1);
            $else:
              const __m128i vxb1 = _mm_srai_epi16(_mm_unpacklo_epi8(vb1, vb1), 8);
          w = (const void*) ((const ${XINT8_T}*) w + 8);

        $for M in range(MR):
          $if XOP:
            vacc${M}x0123 = _mm_maddd_epi16(
              _mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(1, 1, 1, 1)), vxb1, vacc${M}x0123);
          $else:
            vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123,
              _mm_madd_epi16(_mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(1, 1, 1, 1)), vxb1));

        if (k > 4 * sizeof(${XINT8_T})) {
          $if VARIANT == "EXTENDED":
            const __m128i vxb2 = _mm_load_si128((const __m128i*) w);
            w = (const void*) ((const int16_t*) w + 8);
          $else:
            const __m128i vb2 = _mm_loadl_epi64((const __m128i*) w);
            $if DATATYPE == "QU8":
              $if SSE == 4:
                const __m128i vxb2 = _mm_sub_epi16(_mm_cvtepu8_epi16(vb2), vb_zero_point);
              $else:
                const __m128i vxb2 = _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point);
            $else:
              $if SSE == 4:
                const __m128i vxb2 = _mm_cvtepi8_epi16(vb2);
              $else:
                const __m128i vxb2 = _mm_srai_epi16(_mm_unpacklo_epi8(vb2, vb2), 8);
            w = (const void*) ((const ${XINT8_T}*) w + 8);

          $for M in range(MR):
            $if XOP:
              vacc${M}x0123 = _mm_maddd_epi16(
                _mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(2, 2, 2, 2)), vxb2, vacc${M}x0123);
            $else:
              vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123,
                _mm_madd_epi16(_mm_shuffle_epi32(vxa${M}, _MM_SHUFFLE(2, 2, 2, 2)), vxb2));
        }
      }
    }

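    // Requantize: convert the int32 accumulators to float, scale (per-channel
    // for QC8, per-tensor otherwise), clamp against the output max, and
    // convert back to int32 with rounding.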
    $for M in range(MR):
      __m128 vscaled${M}x0123 = _mm_cvtepi32_ps(vacc${M}x0123);

    $if DATATYPE == "QC8":
      const __m128 vscale0123 = _mm_loadu_ps((const float*) w);
      w = (const void*) ((const float*) w + 4);
      $for M in range(MR):
        vscaled${M}x0123 = _mm_mul_ps(vscaled${M}x0123, vscale0123);
    $else:
      const __m128 vscale = _mm_load_ps(params->${PARAMS_STRUCT}.scale);
      $for M in range(MR):
        vscaled${M}x0123 = _mm_mul_ps(vscaled${M}x0123, vscale);

    const __m128 voutput_max_less_zero_point = _mm_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point);
    $for M in range(MR):
      vscaled${M}x0123 = _mm_min_ps(vscaled${M}x0123, voutput_max_less_zero_point);

    $for M in range(MR):
      vacc${M}x0123 = _mm_cvtps_epi32(vscaled${M}x0123);

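    // Add the output zero point, pack the rows down to 8 bits with
    // saturation, and apply the output min clamp.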
    const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_zero_point);
    $for M in range(0, MR, 2):
      __m128i vacc${M}${min(M+1, MR-1)}x0123 = _mm_adds_epi16(_mm_packs_epi32(vacc${M}x0123, vacc${min(M+1, MR-1)}x0123), voutput_zero_point);

    $if DATATYPE == "QU8":
      $if MR > 2:
        __m128i vout = _mm_packus_epi16(vacc0${min(1, MR-1)}x0123, vacc${min(2, MR-1)}${min(3, MR-1)}x0123);
      $else:
        __m128i vout = _mm_packus_epi16(vacc0${min(1, MR-1)}x0123, vacc0${min(1, MR-1)}x0123);

      vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min));
    $else:
      $if SSE < 4:
        const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min);
        $for M in range(0, MR, 2):
          vacc${M}${min(M+1, MR-1)}x0123 = _mm_max_epi16(vacc${M}${min(M+1, MR-1)}x0123, voutput_min);

      $if MR > 2:
        __m128i vout = _mm_packs_epi16(vacc0${min(1, MR-1)}x0123, vacc${min(2, MR-1)}${min(3, MR-1)}x0123);
      $else:
        __m128i vout = _mm_packs_epi16(vacc0${min(1, MR-1)}x0123, vacc0${min(1, MR-1)}x0123);

      $if SSE == 4:
        vout = _mm_max_epi8(vout, _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min));

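    // Store: write full 4-byte rows while nc >= 4, otherwise write the 2-
    // and/or 1-byte tail for the last columns.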
    if (nc >= 4) {
      unaligned_store_u32(c0, (uint32_t) _mm_cvtsi128_si32(vout));
      $for M in range(1, MR):
        $if SSE == 4:
          unaligned_store_u32(c${M}, (uint32_t) _mm_extract_epi32(vout, ${M}));
        $else:
          vout = _mm_shuffle_epi32(vout, _MM_SHUFFLE(0, 3, 2, 1));
          unaligned_store_u32(c${M}, (uint32_t) _mm_cvtsi128_si32(vout));

      $for M in range(MR):
        c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);

      $for M in range(MR):
        a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc);

      nc -= 4;
    } else {
      if (nc & 2) {
        $for M in range(MR):
          unaligned_store_u16(c${M}, (uint16_t) _mm_extract_epi16(vout, ${M * 2}));
          c${M} += 2;
        vout = _mm_srli_epi32(vout, 16);
      }
      if (nc & 1) {
        $if SSE == 4:
          $for M in range(MR):
            *c${M} = (${XINT8_T}) _mm_extract_epi8(vout, ${M * 4});
        $else:
          *c0 = (${XINT8_T}) _mm_cvtsi128_si32(vout);
          $for M in range(1, MR):
            *c${M} = (${XINT8_T}) _mm_extract_epi16(vout, ${M * 2});
      }

      nc = 0;
    }
  } while (nc != 0);
}