// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert SSE in [2, 3, 4]
$assert not XOP or AVX
$assert not AVX or SSE == 4
$assert REQUANTIZATION == "FP32"
$assert DATATYPE in ["QC8", "QS8", "QU8"]
$assert VARIANT in ["LD64", "LD128", "EXTENDED"]
$assert MR <= 4
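// Template parameters:
//   DATATYPE      -- QC8 (per-channel scales), QS8 (signed), or QU8 (unsigned) quantization.
//   SSE/AVX/XOP   -- minimum x86 ISA level targeted by the generated kernel.
//   VARIANT       -- LD64/LD128 load 8/16 bytes of packed weights per load; EXTENDED reads
//                    weights pre-extended to 16 bits (the "_xw" kernels).
//   MR            -- number of output rows (at most 4) computed per call.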
#include <assert.h>

$if XOP:
  #if defined(__GNUC__) || defined(__clang__)
    #include <x86intrin.h>
  #else
    #include <immintrin.h>
    #include <ammintrin.h>
  #endif
$else:
  $SSE_HEADER = {2: "emmintrin.h", 3: "tmmintrin.h", 4: "smmintrin.h"}[SSE]
  #include <${SSE_HEADER}>

#include <xnnpack/gemm.h>
#include <xnnpack/math.h>
#include <xnnpack/unaligned.h>


$LOAD_SUFFIX = {"LD128": "_ld128", "LD64": "_ld64", "EXTENDED": ""}[VARIANT]
$GEMM_SUFFIX = "_xw" if VARIANT == "EXTENDED" else ""
$PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("sse4" if SSE == 4 and DATATYPE != "QU8" else "sse2")
$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
$ISA = "xop" if XOP else "avx" if AVX else {2: "sse2", 3: "ssse3", 4: "sse41"}[SSE]
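// ${MR}x4c8 GEMM micro-kernel: computes an up-to-${MR}-row by 4-column tile of the output.
// Inputs are 8-bit quantized, products are accumulated in 32-bit integers, and results are
// requantized to 8 bits with the FP32 (float-scaling) requantization scheme.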
void xnn_${DATATYPE.lower()}_gemm${GEMM_SUFFIX}_minmax_fp32_ukernel_${MR}x4c8__${ISA}${LOAD_SUFFIX}(
    size_t mr,
    size_t nc,
    size_t kc,
    const ${XINT8_T}* restrict a,
    size_t a_stride,
    const void* restrict w,
    ${XINT8_T}* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(${XINT8_T}) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

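  // The c8 layout packs weights in groups of 8 along K, so round kc up to a multiple of 8.
  // Set up per-row input (a) and output (c) pointers; rows beyond mr alias the previous row's
  // pointers, so no out-of-bounds rows are read or written.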
  kc = round_up_po2(kc, 8);
  const ${XINT8_T}* a0 = a;
  ${XINT8_T}* c0 = c;
  $for M in range(1, MR):
    const ${XINT8_T}* a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M-1} + a_stride);
    ${XINT8_T}* c${M} = (${XINT8_T}*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

  do {
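    // Initialize the accumulators from the 4 per-channel bias values stored at the start of
    // the packed weights; every output row starts from the same biases.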
    $for N in range(4):
      __m128i vacc0x${N} = _mm_cvtsi32_si128(((const int*) w)[${N}]);
    $for M in range(1, MR):
      $for N in range(4):
        __m128i vacc${M}x${N} = vacc0x${N};
    w = (const int32_t*) w + 4;

    size_t k = 0;
    $if DATATYPE == "QU8":
      const __m128i vb_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.kernel_zero_point);
      $if SSE < 4 or VARIANT == "LD128":
        const __m128i vzero = _mm_setzero_si128();
    while (k < kc) {
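      // Each iteration consumes 8 K values: load 8 ${XINT8_T} activations per row and extend
      // them to 16 bits (zero-extend for QU8, sign-extend otherwise).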
      $for M in range(MR):
        const __m128i va${M} = _mm_loadl_epi64((const __m128i*) a${M});
        $if DATATYPE == "QU8":
          $if SSE == 4:
            const __m128i vxa${M} = _mm_cvtepu8_epi16(va${M});
          $else:
            const __m128i vxa${M} = _mm_unpacklo_epi8(va${M}, vzero);
        $else:
          $if SSE == 4:
            const __m128i vxa${M} = _mm_cvtepi8_epi16(va${M});
          $else:
            const __m128i vxa${M} = _mm_srai_epi16(_mm_unpacklo_epi8(va${M}, va${M}), 8);
        a${M} += 8;

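      // Load and extend one 8(K)x4(N) block of packed weights and multiply-accumulate it.
      // LD128 loads 16 bytes (two columns) per load, LD64 loads 8 bytes (one column) per load,
      // and EXTENDED reads weights that were pre-extended to 16 bits at packing time.
      // For QU8 the kernel zero point is subtracted from the extended weights.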
      $if VARIANT == "LD128":
        $for N in range(0, 4, 2):
          $if N == 0:
            const __m128i vb${N}${N+1} = _mm_load_si128((const __m128i*) w);
          $else:
            const __m128i vb${N}${N+1} = _mm_load_si128((const __m128i*) ((const ${XINT8_T}*) w + ${N * 8}));
          $if DATATYPE == "QU8":
            const __m128i vxb${N} = _mm_sub_epi16(_mm_unpacklo_epi8(vb${N}${N+1}, vzero), vb_zero_point);
            const __m128i vxb${N+1} = _mm_sub_epi16(_mm_unpackhi_epi8(vb${N}${N+1}, vzero), vb_zero_point);
          $elif SSE == 4:
            const __m128i vxb${N} = _mm_cvtepi8_epi16(vb${N}${N+1});
            const __m128i vxb${N+1} = _mm_srai_epi16(_mm_unpackhi_epi8(vb${N}${N+1}, vb${N}${N+1}), 8);
          $else:
            const __m128i vsb${N}${N+1} = _mm_cmpgt_epi8(_mm_setzero_si128(), vb${N}${N+1});
            const __m128i vxb${N} = _mm_unpacklo_epi8(vb${N}${N+1}, vsb${N}${N+1});
            const __m128i vxb${N+1} = _mm_unpackhi_epi8(vb${N}${N+1}, vsb${N}${N+1});

          $for M in range(MR):
            $if XOP:
              vacc${M}x${N} = _mm_maddd_epi16(vxa${M}, vxb${N}, vacc${M}x${N});
              vacc${M}x${N+1} = _mm_maddd_epi16(vxa${M}, vxb${N+1}, vacc${M}x${N+1});
            $else:
              vacc${M}x${N} = _mm_add_epi32(vacc${M}x${N}, _mm_madd_epi16(vxa${M}, vxb${N}));
              vacc${M}x${N+1} = _mm_add_epi32(vacc${M}x${N+1}, _mm_madd_epi16(vxa${M}, vxb${N+1}));
      $else:
        $for N in range(4):
          $if VARIANT == "LD64":
            $if N == 0:
              const __m128i vb${N} = _mm_loadl_epi64((const __m128i*) w);
            $else:
              const __m128i vb${N} = _mm_loadl_epi64((const __m128i*) ((const ${XINT8_T}*) w + ${N * 8}));
            $if DATATYPE == "QU8":
              $if SSE == 4:
                const __m128i vxb${N} = _mm_sub_epi16(_mm_cvtepu8_epi16(vb${N}), vb_zero_point);
              $else:
                const __m128i vxb${N} = _mm_sub_epi16(_mm_unpacklo_epi8(vb${N}, vzero), vb_zero_point);
            $else:
              $if SSE == 4:
                const __m128i vxb${N} = _mm_cvtepi8_epi16(vb${N});
              $else:
                const __m128i vxb${N} = _mm_srai_epi16(_mm_unpacklo_epi8(vb${N}, vb${N}), 8);
          $elif VARIANT == "EXTENDED":
            $if N == 0:
              const __m128i vxb${N} = _mm_load_si128((const __m128i*) w);
            $else:
              const __m128i vxb${N} = _mm_load_si128((const __m128i*) ((const int16_t*) w + ${N * 8}));

          $for M in range(MR):
            $if XOP:
              vacc${M}x${N} = _mm_maddd_epi16(vxa${M}, vxb${N}, vacc${M}x${N});
            $else:
              vacc${M}x${N} = _mm_add_epi32(vacc${M}x${N}, _mm_madd_epi16(vxa${M}, vxb${N}));

      $if VARIANT == "EXTENDED":
        w = (const void*) ((const int16_t*) w + 32);
      $else:
        w = (const void*) ((const ${XINT8_T}*) w + 32);
      k += 8 * sizeof(${XINT8_T});
    }

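    // Each vaccMxN now holds 4 partial sums for output element (row M, column N).  Reduce
    // them to one 32-bit sum per element, giving one vector of 4 column sums per row
    // (vaccMx0123): SSSE3 and above use horizontal adds, SSE2 uses unpack/add shuffles.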
    $if SSE >= 3:
      $for M in range(MR):
        const __m128i vacc${M}x01 = _mm_hadd_epi32(vacc${M}x0, vacc${M}x1);
        const __m128i vacc${M}x23 = _mm_hadd_epi32(vacc${M}x2, vacc${M}x3);

      $for M in range(MR):
        __m128i vacc${M}x0123 = _mm_hadd_epi32(vacc${M}x01, vacc${M}x23);
    $else:
      $for M in range(MR):
        const __m128i vacc${M}x02 = _mm_add_epi32(_mm_unpacklo_epi32(vacc${M}x0, vacc${M}x2), _mm_unpackhi_epi32(vacc${M}x0, vacc${M}x2));
        const __m128i vacc${M}x13 = _mm_add_epi32(_mm_unpacklo_epi32(vacc${M}x1, vacc${M}x3), _mm_unpackhi_epi32(vacc${M}x1, vacc${M}x3));

      $for M in range(MR):
        __m128i vacc${M}x0123 = _mm_add_epi32(_mm_unpacklo_epi32(vacc${M}x02, vacc${M}x13), _mm_unpackhi_epi32(vacc${M}x02, vacc${M}x13));

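    // FP32 requantization: convert the int32 sums to float and scale them (per-channel scales
    // stored with the packed weights for QC8, a single scale from params otherwise).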
    $for M in range(MR):
      __m128 vscaled${M}x0123 = _mm_cvtepi32_ps(vacc${M}x0123);

    $if DATATYPE == "QC8":
      const __m128 vscale0123 = _mm_load_ps((const float*) w);
      w = (const void*) ((const float*) w + 4);
      $for M in range(MR):
        vscaled${M}x0123 = _mm_mul_ps(vscaled${M}x0123, vscale0123);
    $else:
      const __m128 vscale = _mm_load_ps(params->${PARAMS_STRUCT}.scale);
      $for M in range(MR):
        vscaled${M}x0123 = _mm_mul_ps(vscaled${M}x0123, vscale);

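    // Clamp against output_max (expressed relative to the output zero point), convert back to
    // int32 with rounding, then pack to 16 bits with signed saturation and add the zero point.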
    const __m128 voutput_max_less_zero_point = _mm_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point);
    $for M in range(MR):
      vscaled${M}x0123 = _mm_min_ps(vscaled${M}x0123, voutput_max_less_zero_point);

    $for M in range(MR):
      vacc${M}x0123 = _mm_cvtps_epi32(vscaled${M}x0123);

    const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_zero_point);
    $for M in range(0, MR, 2):
      __m128i vacc${M}${min(M+1, MR-1)}x0123 = _mm_adds_epi16(_mm_packs_epi32(vacc${M}x0123, vacc${min(M+1, MR-1)}x0123), voutput_zero_point);

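    // Pack the 16-bit results down to 8 bits (unsigned saturation for QU8, signed otherwise)
    // and apply the lower output bound.  Without SSE4.1 there is no 8-bit signed maximum, so
    // for QS8/QC8 the bound is applied to the 16-bit values before packing.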
    $if DATATYPE == "QU8":
      $if MR > 2:
        __m128i vout = _mm_packus_epi16(vacc0${min(1, MR-1)}x0123, vacc${min(2, MR-1)}${min(3, MR-1)}x0123);
      $else:
        __m128i vout = _mm_packus_epi16(vacc0${min(1, MR-1)}x0123, vacc0${min(1, MR-1)}x0123);

      vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min));
    $else:
      $if SSE < 4:
        const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min);
        $for M in range(0, MR, 2):
          vacc${M}${min(M+1, MR-1)}x0123 = _mm_max_epi16(vacc${M}${min(M+1, MR-1)}x0123, voutput_min);

      $if MR > 2:
        __m128i vout = _mm_packs_epi16(vacc0${min(1, MR-1)}x0123, vacc${min(2, MR-1)}${min(3, MR-1)}x0123);
      $else:
        __m128i vout = _mm_packs_epi16(vacc0${min(1, MR-1)}x0123, vacc0${min(1, MR-1)}x0123);

      $if SSE == 4:
        vout = _mm_max_epi8(vout, _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min));

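    // Store the tile: vout holds 4 output bytes per row.  With at least 4 columns remaining,
    // store full rows, advance the output pointers by cn_stride, and rewind the input pointers
    // by kc for the next column block; otherwise write the 2- and/or 1-column remainder.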
    if (nc >= 4) {
      unaligned_store_u32(c0, (uint32_t) _mm_cvtsi128_si32(vout));
      $for M in range(1, MR):
        $if SSE == 4:
          unaligned_store_u32(c${M}, (uint32_t) _mm_extract_epi32(vout, ${M}));
        $else:
          vout = _mm_srli_si128(vout, 4);
          unaligned_store_u32(c${M}, (uint32_t) _mm_cvtsi128_si32(vout));

      $for M in range(MR):
        c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);

      $for M in range(MR):
        a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc);

      nc -= 4;
    } else {
      if (nc & 2) {
        $for M in range(MR):
          unaligned_store_u16(c${M}, (uint16_t) _mm_extract_epi16(vout, ${M * 2}));
          c${M} += 2;
        vout = _mm_srli_epi32(vout, 16);
      }
      if (nc & 1) {
        $if SSE == 4:
          $for M in range(MR):
            *c${M} = (${XINT8_T}) _mm_extract_epi8(vout, ${M * 4});
        $else:
          *c0 = (${XINT8_T}) _mm_cvtsi128_si32(vout);
          $for M in range(1, MR):
            *c${M} = (${XINT8_T}) _mm_extract_epi16(vout, ${M * 2});
      }

      nc = 0;
    }
  } while (nc != 0);
}