xref: /aosp_15_r20/external/XNNPACK/src/qs8-dwconv/unipass-avx512skx-mul32.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
7$assert REQUANTIZATION == "FP32"
8$assert DATATYPE in ["QC8", "QS8", "QU8"]
9$assert CHANNEL_TILE % 16 == 0
10$assert CHANNEL_TILE >= 16
11$assert KERNEL_TILE >= 2
12#include <assert.h>
13
14#include <immintrin.h>
15
16#include <xnnpack/dwconv.h>
17#include <xnnpack/intrinsics-polyfill.h>
18
19
20$PARAMS_STRUCT = REQUANTIZATION.lower() + "_avx512"
21$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
22$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
23$_MM512_CVTEPX8_EPI32 = "_mm512_cvtepu8_epi32" if DATATYPE == "QU8" else "_mm512_cvtepi8_epi32"
24$_MM256_PACKXS_EPI16 = "_mm256_packus_epi16" if DATATYPE == "QU8" else "_mm256_packs_epi16"
25$_MM_PACKXS_EPI16 = "_mm_packus_epi16" if DATATYPE == "QU8" else "_mm_packs_epi16"
26$_MM256_MIN_EPX8 = "_mm256_min_epu8" if DATATYPE == "QU8" else "_mm256_min_epi8"
27$_MM256_MAX_EPX8 = "_mm256_max_epu8" if DATATYPE == "QU8" else "_mm256_max_epi8"
28$_MM_MIN_EPX8 = "_mm_min_epu8" if DATATYPE == "QU8" else "_mm_min_epi8"
29$_MM_MAX_EPX8 = "_mm_max_epu8" if DATATYPE == "QU8" else "_mm_max_epi8"
30void xnn_${DATATYPE.lower()}_dwconv_minmax_${REQUANTIZATION.lower()}_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__avx512skx_mul32(
31    size_t channels,
32    size_t output_width,
33    const ${XINT8_T}** input,
34    const void* weights,
35    ${XINT8_T}* output,
36    size_t input_stride,
37    size_t output_increment,
38    size_t input_offset,
39    const ${XINT8_T}* zero,
40    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_MSAN
41{
42  assert(channels != 0);
43  assert(output_width != 0);
44
45  $if DATATYPE != "QC8":
46    const __m512 vscale = _mm512_load_ps(params->${PARAMS_STRUCT}.scale);
47  const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point);
48  $if CHANNEL_TILE > 16:
49    const __m512i voutput_zero_point = _mm512_load_si512(params->${PARAMS_STRUCT}.output_zero_point);
50    const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->${PARAMS_STRUCT}.output_min);
51    const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 5, 1, 6, 2, 4, 0);
52  $else:
53    const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->${PARAMS_STRUCT}.output_zero_point);
54    const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min);
55
56  $if DATATYPE == "QU8":
57    const __m512i vk_zero_point = _mm512_cvtepu16_epi32(_mm256_load_si256((const __m256i*) params->${PARAMS_STRUCT}.kernel_zero_point));
58  do {
59    $for K in range(KERNEL_TILE):
60      const ${XINT8_T}* i${K} = input[${K}];
61      assert(i${K} != NULL);
62      if XNN_UNPREDICTABLE(i${K} != zero) {
63        i${K} = (const ${XINT8_T}*) ((uintptr_t) i${K} + input_offset);
64      }
65    input = (const ${XINT8_T}**) ((uintptr_t) input + input_stride);
66
67    size_t c = channels;
68    const void* w = weights;
69    for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
70      __m512i vacc${ABC[0:16]} = _mm512_loadu_si512(w);
71      $for C in range(16, CHANNEL_TILE, 16):
72        __m512i vacc${ABC[C:C+16]} = _mm512_loadu_si512((const void*) ((uintptr_t) w + ${C} * sizeof(int32_t)));
73
74      $for K in range(KERNEL_TILE):
75
76        $for C in range(0, CHANNEL_TILE, 16):
77          $if C == 0:
78            const __m512i vi${K}x${ABC[0:16]} = ${_MM512_CVTEPX8_EPI32}(_mm_loadu_si128((const __m128i*) i${K}));
79          $else:
80            const __m512i vi${K}x${ABC[C:C+16]} = ${_MM512_CVTEPX8_EPI32}(_mm_loadu_si128((const __m128i*) (i${K} + ${C})));
81          $if DATATYPE == "QU8":
82            const __m512i vk${K}x${ABC[C:C+16]} = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE + C} * sizeof(${XINT8_T})))), vk_zero_point);
83          $else:
84            const __m512i vk${K}x${ABC[C:C+16]} = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE + C} * sizeof(${XINT8_T}))));
85        i${K} += ${CHANNEL_TILE};
86
87        $for C in range(0, CHANNEL_TILE, 16):
88          vacc${ABC[C:C+16]} = _mm512_add_epi32(vacc${ABC[C:C+16]}, _mm512_mullo_epi32(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]}));
89
90      w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T}));
91
92      $for C in range(0, CHANNEL_TILE, 16):
93        __m512 vscaled${ABC[C:C+16]} = _mm512_cvtepi32_ps(vacc${ABC[C:C+16]});
94
95      $if DATATYPE == "QC8":
96        const __m512 vscale${ABC[0:16]} = _mm512_loadu_ps(w);
97        $for C in range(16, CHANNEL_TILE, 16):
98          const __m512 vscale${ABC[C:C+16]} = _mm512_loadu_ps((const void*) ((uintptr_t) w + ${C} * sizeof(float)));
99        w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(float));
100        $for C in range(0, CHANNEL_TILE, 16):
101          vscaled${ABC[C:C+16]} = _mm512_mul_ps(vscaled${ABC[C:C+16]}, vscale${ABC[C:C+16]});
102      $else:
103        $for C in range(0, CHANNEL_TILE, 16):
104          vscaled${ABC[C:C+16]} = _mm512_mul_ps(vscaled${ABC[C:C+16]}, vscale);
105
106      $for C in range(0, CHANNEL_TILE, 16):
107        vscaled${ABC[C:C+16]} = _mm512_min_ps(vscaled${ABC[C:C+16]}, voutput_max_less_zero_point);
108
109      $for C in range(0, CHANNEL_TILE, 16):
110        vacc${ABC[C:C+16]} = _mm512_cvtps_epi32(vscaled${ABC[C:C+16]});
111
112      $for C in range(0, CHANNEL_TILE, 16):
113        $if C + 16 < CHANNEL_TILE:
114          __m512i vout${ABC[C:C+4]}${ABC[C+16:C+20]}${ABC[C+4:C+8]}${ABC[C+20:C+24]}${ABC[C+8:C+12]}${ABC[C+24:C+28]}${ABC[C+12:C+16]}${ABC[C+28:C+32]} = _mm512_adds_epi16(_mm512_packs_epi32(vacc${ABC[C:C+16]}, vacc${ABC[C+16:C+32]}), voutput_zero_point);
115        $elif CHANNEL_TILE > 16:
116          __m256i vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc${ABC[C:C+16]}), _mm512_extracti32x8_epi32(vacc${ABC[C:C+16]}, 1)), _mm512_castsi512_si256(voutput_zero_point));
117        $else:
118          __m256i vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc${ABC[C:C+16]}), _mm512_extracti32x8_epi32(vacc${ABC[C:C+16]}, 1)), voutput_zero_point);
119
120      $for C in range(0, CHANNEL_TILE, 16):
121        $if C + 16 < CHANNEL_TILE:
122          const __m256i vout${ABC[C:C+4]}${ABC[C+16:C+20]}${ABC[C+4:C+8]}${ABC[C+20:C+24]} = _mm512_castsi512_si256(vout${ABC[C:C+4]}${ABC[C+16:C+20]}${ABC[C+4:C+8]}${ABC[C+20:C+24]}${ABC[C+8:C+12]}${ABC[C+24:C+28]}${ABC[C+12:C+16]}${ABC[C+28:C+32]});
123          const __m256i vout${ABC[C+8:C+12]}${ABC[C+24:C+28]}${ABC[C+12:C+16]}${ABC[C+28:C+32]} = _mm512_extracti32x8_epi32(vout${ABC[C:C+4]}${ABC[C+16:C+20]}${ABC[C+4:C+8]}${ABC[C+20:C+24]}${ABC[C+8:C+12]}${ABC[C+24:C+28]}${ABC[C+12:C+16]}${ABC[C+28:C+32]}, 1);
124          const __m256i vout${ABC[C:C+4]}${ABC[C+16:C+20]}${ABC[C+8:C+12]}${ABC[C+24:C+28]}${ABC[C+4:C+8]}${ABC[C+20:C+24]}${ABC[C+12:C+16]}${ABC[C+28:C+32]} = ${_MM256_PACKXS_EPI16}(vout${ABC[C:C+4]}${ABC[C+16:C+20]}${ABC[C+4:C+8]}${ABC[C+20:C+24]}, vout${ABC[C+8:C+12]}${ABC[C+24:C+28]}${ABC[C+12:C+16]}${ABC[C+28:C+32]});
125          __m256i vout${ABC[C:C+32]} = _mm256_permutevar8x32_epi32(vout${ABC[C:C+4]}${ABC[C+16:C+20]}${ABC[C+8:C+12]}${ABC[C+24:C+28]}${ABC[C+4:C+8]}${ABC[C+20:C+24]}${ABC[C+12:C+16]}${ABC[C+28:C+32]}, vpermute_mask);
126        $else:
127          const __m128i vout${ABC[C:C+4]}${ABC[C+8:C+12]} = _mm256_castsi256_si128(vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]});
128          const __m128i vout${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_extracti128_si256(vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]}, 1);
129          __m128i vout${ABC[C:C+16]} = _mm_shuffle_epi32(${_MM_PACKXS_EPI16}(vout${ABC[C:C+4]}${ABC[C+8:C+12]}, vout${ABC[C+4:C+8]}${ABC[C+12:C+16]}), _MM_SHUFFLE(3, 1, 2, 0));
130
131      $for C in range(0, CHANNEL_TILE, 16):
132        $if C + 16 < CHANNEL_TILE:
133          vout${ABC[C:C+32]} = ${_MM256_MAX_EPX8}(vout${ABC[C:C+32]}, voutput_min);
134        $elif CHANNEL_TILE > 16:
135          vout${ABC[C:C+16]} = ${_MM_MAX_EPX8}(vout${ABC[C:C+16]}, _mm256_castsi256_si128(voutput_min));
136        $else:
137          vout${ABC[C:C+16]} = ${_MM_MAX_EPX8}(vout${ABC[C:C+16]}, voutput_min);
138
139      $if CHANNEL_TILE > 16:
140        _mm256_storeu_si256((__m256i*) output, vout${ABC[0:32]});
141      $else:
142        _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]});
143      $for C in range(16, CHANNEL_TILE, 16):
144        $if C + 16 < CHANNEL_TILE:
145          _mm256_storeu_si256((__m256i*) (output + ${C}), vout${ABC[C:C+32]});
146        $else:
147          _mm_storeu_si128((__m128i*) (output + ${C}), vout${ABC[C:C+16]});
148      output += ${CHANNEL_TILE};
149    }
150    if XNN_UNLIKELY(c != 0) {
151      // Prepare mask for valid 8-bit elements (depends on nc).
152      const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << (c & 15)) - UINT32_C(1)));
153      $if CHANNEL_TILE > 16:
154        const ${XINT8_T}* k = (const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t));
155      ${"do " if CHANNEL_TILE > 16 else ""}{
156        __m512i vacc${ABC[0:16]} = _mm512_loadu_si512(w);
157
158        $for K in range(KERNEL_TILE):
159
160          const __m512i vi${K}x${ABC[0:16]} = ${_MM512_CVTEPX8_EPI32}(_mm_loadu_si128((const __m128i*) i${K}));
161          $if DATATYPE == "QU8":
162            $if CHANNEL_TILE > 16:
163              $if K == 0:
164                const __m512i vk${K}x${ABC[0:16]} = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) k)), vk_zero_point);
165              $else:
166                const __m512i vk${K}x${ABC[0:16]} = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + ${K * CHANNEL_TILE}))), vk_zero_point);
167            $else:
168              const __m512i vk${K}x${ABC[0:16]} = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE} * sizeof(${XINT8_T})))), vk_zero_point);
169          $else:
170            $if CHANNEL_TILE > 16:
171              $if K == 0:
172                const __m512i vk${K}x${ABC[0:16]} = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) k));
173              $else:
174                const __m512i vk${K}x${ABC[0:16]} = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + ${K * CHANNEL_TILE})));
175            $else:
176              const __m512i vk${K}x${ABC[0:16]} = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE} * sizeof(${XINT8_T}))));
177          $if CHANNEL_TILE > 16:
178            i${K} += 16;
179
180          vacc${ABC[0:16]} = _mm512_add_epi32(vacc${ABC[0:16]}, _mm512_mullo_epi32(vi${K}x${ABC[0:16]}, vk${K}x${ABC[0:16]}));
181
182        $if CHANNEL_TILE > 16:
183          k += 16;
184
185        __m512 vscaled${ABC[0:16]} = _mm512_cvtepi32_ps(vacc${ABC[0:16]});
186        $if DATATYPE == "QC8":
187          const __m512 vscale${ABC[0:16]} = _mm512_loadu_ps((const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(${XINT8_T})));
188          vscaled${ABC[0:16]} = _mm512_mul_ps(vscaled${ABC[0:16]}, vscale${ABC[0:16]});
189        $else:
190          vscaled${ABC[0:16]} = _mm512_mul_ps(vscaled${ABC[0:16]}, vscale);
191        vscaled${ABC[0:16]} = _mm512_min_ps(vscaled${ABC[0:16]}, voutput_max_less_zero_point);
192        vacc${ABC[0:16]} = _mm512_cvtps_epi32(vscaled${ABC[0:16]});
193
194        $if CHANNEL_TILE > 16:
195          w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
196
197        $if CHANNEL_TILE > 16:
198          __m256i vout${ABC[0:4]}${ABC[8:12]}${ABC[4:8]}${ABC[12:16]} = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc${ABC[0:16]}), _mm512_extracti32x8_epi32(vacc${ABC[0:16]}, 1)), _mm512_castsi512_si256(voutput_zero_point));
199        $else:
200          __m256i vout${ABC[0:4]}${ABC[8:12]}${ABC[4:8]}${ABC[12:16]} = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc${ABC[0:16]}), _mm512_extracti32x8_epi32(vacc${ABC[0:16]}, 1)), voutput_zero_point);
201
202        const __m128i vout${ABC[0:4]}${ABC[8:12]} = _mm256_castsi256_si128(vout${ABC[0:4]}${ABC[8:12]}${ABC[4:8]}${ABC[12:16]});
203        const __m128i vout${ABC[4:8]}${ABC[12:16]} = _mm256_extracti128_si256(vout${ABC[0:4]}${ABC[8:12]}${ABC[4:8]}${ABC[12:16]}, 1);
204        __m128i vout${ABC[0:16]} = _mm_shuffle_epi32(${_MM_PACKXS_EPI16}(vout${ABC[0:4]}${ABC[8:12]}, vout${ABC[4:8]}${ABC[12:16]}), _MM_SHUFFLE(3, 1, 2, 0));
205        $if CHANNEL_TILE > 16:
206          vout${ABC[0:16]} = ${_MM_MAX_EPX8}(vout${ABC[0:16]}, _mm256_castsi256_si128(voutput_min));
207        $else:
208          vout${ABC[0:16]} = ${_MM_MAX_EPX8}(vout${ABC[0:16]}, voutput_min);
209
210        $if CHANNEL_TILE > 16:
211          if XNN_LIKELY(c >= 16) {
212            _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]});
213            output += 16;
214            c -= 16;
215          } else {
216            _mm_mask_storeu_epi8(output, vmask, vout${ABC[0:16]});
217            output = (${XINT8_T}*) ((uintptr_t) output + c);
218            c = 0;
219          }
220        $else:
221          _mm_mask_storeu_epi8(output, vmask, vout${ABC[0:16]});
222          output = (${XINT8_T}*) ((uintptr_t) output + c);
223      }${" while (c != 0);" if CHANNEL_TILE > 16 else ""}
224    }
225
226    output = (${XINT8_T}*) ((uintptr_t) output + output_increment);
227  } while (--output_width != 0);
228}
229