xref: /aosp_15_r20/external/XNNPACK/src/qs8-dwconv/unipass-sse-mul32.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
// Constraints on the template parameters: SSE must be 4; XOP requires AVX and
// AVX requires SSE == 4; only FP32 requantization and the QC8/QS8/QU8
// datatypes are accepted; CHANNEL_TILE is a multiple of 8 (at least 8) and
// KERNEL_TILE is at least 2. ABC is the alphabet whose slices form the
// per-lane suffixes of the vector variable names used further down.
6$assert SSE == 4
7$assert not XOP or AVX
8$assert not AVX or SSE == 4
9$assert REQUANTIZATION == "FP32"
10$assert DATATYPE in ["QC8", "QS8", "QU8"]
11$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
12$assert CHANNEL_TILE % 8 == 0
13$assert CHANNEL_TILE >= 8
14$assert KERNEL_TILE >= 2
15#include <assert.h>
16
17$if XOP:
18  #if defined(__GNUC__) || defined(__clang__)
19    #include <x86intrin.h>
20  #else
21    #include <immintrin.h>
22    #include <ammintrin.h>
23  #endif
24$else:
25  #include <immintrin.h>
26
27#include <xnnpack/dwconv.h>
28#include <xnnpack/intrinsics-polyfill.h>
29#include <xnnpack/unaligned.h>
30
31
32$PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("sse2" if DATATYPE == "QU8" else "sse4")
33$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
34$ISA = "xop" if XOP else "avx" if AVX else {4: "sse41"}[SSE]
35$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
// Depthwise-convolution microkernel template with FP32 requantization, using
// 32-bit integer multiplies (_mm_mullo_epi32, or _mm_macc_epi32 when XOP is set).
//
// For every output pixel the kernel reads KERNEL_TILE input row pointers and,
// per group of channels, accumulates input * weight products on top of the
// int32 values stored at the head of the weights block. The accumulators are
// then converted to float, scaled, clamped from above, converted back to
// int32, offset by the output zero point, narrowed to 8 bits and clamped from
// below before being stored.
//
// Parameters:
//   channels         - channels per output pixel; asserted non-zero.
//   output_width     - number of output pixels to produce; asserted non-zero.
//   input            - array of input row pointers; KERNEL_TILE entries are
//                      consumed per pixel, then the array pointer is advanced
//                      by input_stride bytes.
//   weights          - packed buffer. Per CHANNEL_TILE channels it holds
//                      CHANNEL_TILE int32 initial accumulators, then
//                      KERNEL_TILE * CHANNEL_TILE 8-bit kernel values, then
//                      (QC8 only) CHANNEL_TILE float scales.
//   output           - destination for the 8-bit results.
//   input_stride     - byte step applied to `input` after each pixel.
//   output_increment - byte step applied to `output` after each pixel, on top
//                      of the bytes written.
//   input_offset     - byte offset added to every row pointer that differs
//                      from `zero`.
//   zero             - row pointers equal to this are used as-is (presumably a
//                      shared padding row; confirm against the caller).
//   params           - requantization constants: scale (not read for QC8),
//                      output_max_less_zero_point, output_zero_point,
//                      output_min, and kernel_zero_point for QU8.
//
// The remainder path always loads 4 input bytes per row even when fewer than 4
// channels remain, so it can read up to 3 bytes past the last channel
// (presumably the reason for the XNN_OOB_READS annotation).
//
// NOTE(review): kernel bytes are fetched with a direct `*((const int*) ...)`
// dereference while input bytes go through unaligned_load_s32(); confirm the
// packed weights are always suitably aligned, or consider using the helper for
// both.
// NOTE(review): the else-branches of the template's `CHANNEL_TILE > 4`
// conditionals look unreachable, since the template asserts CHANNEL_TILE >= 8.
36void xnn_${DATATYPE.lower()}_dwconv_minmax_${REQUANTIZATION.lower()}_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__${ISA}_mul32(
37    size_t channels,
38    size_t output_width,
39    const ${XINT8_T}** input,
40    const void* weights,
41    ${XINT8_T}* output,
42    size_t input_stride,
43    size_t output_increment,
44    size_t input_offset,
45    const ${XINT8_T}* zero,
46    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
47{
48  assert(channels != 0);
49  assert(output_width != 0);
50
51  $if DATATYPE == "QU8":
    // Widen the kernel zero point to 32-bit lanes; it is subtracted from every kernel value below.
52    const __m128i vk_zero_point = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i*) params->${PARAMS_STRUCT}.kernel_zero_point));
  // One iteration per output pixel.
53  do {
    // Fetch the input row pointers for this pixel. Rows equal to `zero` are
    // left untouched; all others are shifted by input_offset bytes.
54    $for K in range(KERNEL_TILE):
55      const ${XINT8_T}* i${K} = input[${K}];
56      assert(i${K} != NULL);
57      if XNN_UNPREDICTABLE(i${K} != zero) {
58        i${K} = (const ${XINT8_T}*) ((uintptr_t) i${K} + input_offset);
59      }
60    input = (const ${XINT8_T}**) ((uintptr_t) input + input_stride);
61
62    size_t c = channels;
63    const void* w = weights;
    // Main loop: CHANNEL_TILE channels per iteration.
64    for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
      // Start the accumulators from the int32 values at the head of the weights block.
65      __m128i vacc${ABC[0:4]} = _mm_loadu_si128((const __m128i*) w);
66      $for C in range(4, CHANNEL_TILE, 4):
67        __m128i vacc${ABC[C:C+4]} = _mm_loadu_si128((const __m128i*) ((const int32_t*) w + ${C}));
68
      // For each kernel tap: widen 4 input bytes and 4 kernel bytes at a time
      // to 32-bit lanes, then multiply and accumulate.
69      $for K in range(KERNEL_TILE):
70
71        $for C in range(0, CHANNEL_TILE, 4):
72          $if DATATYPE == "QU8":
73            $if C == 0:
74              const __m128i vi${K}x${ABC[0:4]} = _mm_cvtepu8_epi32(_mm_cvtsi32_si128((int) unaligned_load_s32(i${K})));
75            $else:
76              const __m128i vi${K}x${ABC[C:C+4]} = _mm_cvtepu8_epi32(_mm_cvtsi32_si128((int) unaligned_load_s32(i${K} + ${C})));
77            const __m128i vk${K}x${ABC[C:C+4]} = _mm_sub_epi32(_mm_cvtepu8_epi32(_mm_cvtsi32_si128(*((const int*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE + C} * sizeof(${XINT8_T}))))), vk_zero_point);
78          $else:
79            $if C == 0:
80              const __m128i vi${K}x${ABC[0:4]} = _mm_cvtepi8_epi32(_mm_cvtsi32_si128((int) unaligned_load_s32(i${K})));
81            $else:
82              const __m128i vi${K}x${ABC[C:C+4]} = _mm_cvtepi8_epi32(_mm_cvtsi32_si128((int) unaligned_load_s32(i${K} + ${C})));
83            const __m128i vk${K}x${ABC[C:C+4]} = _mm_cvtepi8_epi32(_mm_cvtsi32_si128(*((const int*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE + C} * sizeof(${XINT8_T})))));
84        i${K} += ${CHANNEL_TILE};
85
86        $for C in range(0, CHANNEL_TILE, 4):
87          $if XOP:
88            vacc${ABC[C:C+4]} = _mm_macc_epi32(vi${K}x${ABC[C:C+4]}, vk${K}x${ABC[C:C+4]}, vacc${ABC[C:C+4]});
89          $else:
90            vacc${ABC[C:C+4]} = _mm_add_epi32(vacc${ABC[C:C+4]}, _mm_mullo_epi32(vi${K}x${ABC[C:C+4]}, vk${K}x${ABC[C:C+4]}));
91
      // Skip past the initial accumulators and all kernel values of this channel tile.
92      w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T}));
93
      // Requantize in floating point: int32 -> float, multiply by the scale,
      // clamp from above, then convert back to int32.
94      $for C in range(0, CHANNEL_TILE, 4):
95        __m128 vscaled${ABC[C:C+4]} = _mm_cvtepi32_ps(vacc${ABC[C:C+4]});
96
97      $if DATATYPE == "QC8":
        // Scales come from the weights buffer, one float per channel; w is then advanced past them.
98        const __m128 vscale${ABC[0:4]} = _mm_loadu_ps((const float*) w);
99        $for C in range(4, CHANNEL_TILE, 4):
100          const __m128 vscale${ABC[C:C+4]} = _mm_loadu_ps((const float*) w + ${C});
101        w = (const void*) ((const float*) w + ${CHANNEL_TILE});
102        $for C in range(0, CHANNEL_TILE, 4):
103          vscaled${ABC[C:C+4]} = _mm_mul_ps(vscaled${ABC[C:C+4]}, vscale${ABC[C:C+4]});
104      $else:
105        const __m128 vscale = _mm_load_ps(params->${PARAMS_STRUCT}.scale);
106        $for C in range(0, CHANNEL_TILE, 4):
107          vscaled${ABC[C:C+4]} = _mm_mul_ps(vscaled${ABC[C:C+4]}, vscale);
108
109      const __m128 voutput_max_less_zero_point = _mm_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point);
110      $for C in range(0, CHANNEL_TILE, 4):
111        vscaled${ABC[C:C+4]} = _mm_min_ps(vscaled${ABC[C:C+4]}, voutput_max_less_zero_point);
112
113      $for C in range(0, CHANNEL_TILE, 4):
114        vacc${ABC[C:C+4]} = _mm_cvtps_epi32(vscaled${ABC[C:C+4]});
115
      // Narrow to 16 bits with saturation and add the output zero point (saturating add).
116      const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_zero_point);
117      $for C in range(0, CHANNEL_TILE, 8):
118        __m128i vout${ABC[C:C+8]} = _mm_adds_epi16(_mm_packs_epi32(vacc${ABC[C:C+4]}, vacc${ABC[C+4:C+8]}), voutput_zero_point);
119
      // Narrow to 8 bits (unsigned saturation for QU8, signed otherwise) and apply the lower clamp.
120      const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min);
121      $if DATATYPE == "QU8":
122        $for C in range(0, CHANNEL_TILE, 16):
123          $if C + 8 < CHANNEL_TILE:
124            __m128i vout${ABC[C:C+16]} = _mm_packus_epi16(vout${ABC[C:C+8]}, vout${ABC[C+8:C+16]});
125            vout${ABC[C:C+16]} = _mm_max_epu8(vout${ABC[C:C+16]}, voutput_min);
126          $else:
127            __m128i vout${ABC[C:C+8]}${ABC[C:C+8]} = _mm_packus_epi16(vout${ABC[C:C+8]}, vout${ABC[C:C+8]});
128            vout${ABC[C:C+8]}${ABC[C:C+8]} = _mm_max_epu8(vout${ABC[C:C+8]}${ABC[C:C+8]}, voutput_min);
129      $else:
130        $for C in range(0, CHANNEL_TILE, 16):
131          $if C + 8 < CHANNEL_TILE:
132            __m128i vout${ABC[C:C+16]} = _mm_packs_epi16(vout${ABC[C:C+8]}, vout${ABC[C+8:C+16]});
133            vout${ABC[C:C+16]} = _mm_max_epi8(vout${ABC[C:C+16]}, voutput_min);
134          $else:
135            __m128i vout${ABC[C:C+8]}${ABC[C:C+8]} = _mm_packs_epi16(vout${ABC[C:C+8]}, vout${ABC[C:C+8]});
136            vout${ABC[C:C+8]}${ABC[C:C+8]} = _mm_max_epi8(vout${ABC[C:C+8]}${ABC[C:C+8]}, voutput_min);
137
      // Store CHANNEL_TILE output bytes (16 at a time, with an 8-byte store for a trailing half) and advance.
138      $if CHANNEL_TILE > 8:
139        _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]});
140      $else:
141        _mm_storel_epi64((__m128i*) output, vout${ABC[0:8]}${ABC[0:8]});
142      $for C in range(16, CHANNEL_TILE, 16):
143        $if C + 8 < CHANNEL_TILE:
144          _mm_storeu_si128((__m128i*) (output + ${C}), vout${ABC[C:C+16]});
145        $else:
146          _mm_storel_epi64((__m128i*) (output + ${C}), vout${ABC[C:C+8]}${ABC[C:C+8]});
147      output += ${CHANNEL_TILE};
148    }
    // Remainder: fewer than CHANNEL_TILE channels are left; they are handled 4 at a time.
149    if XNN_UNLIKELY(c != 0) {
150      $if CHANNEL_TILE > 4:
        // k walks the 8-bit kernel values that follow the CHANNEL_TILE initial accumulators.
151        const ${XINT8_T}* k = (const ${XINT8_T}*) ((const int32_t*) w + ${CHANNEL_TILE});
152      ${"do " if CHANNEL_TILE > 4 else ""}{
        // Accumulate 4 channels, starting from their int32 initial values.
153        __m128i vacc${ABC[0:4]} = _mm_loadu_si128((const __m128i*) w);
154
155        $for K in range(KERNEL_TILE):
156          $if DATATYPE == "QU8":
157            const __m128i vi${K}x${ABC[0:4]} = _mm_cvtepu8_epi32(_mm_cvtsi32_si128((int) unaligned_load_s32(i${K})));
158            $if CHANNEL_TILE > 4:
159              $if K == 0:
160                const __m128i vk${K}x${ABC[0:4]} = _mm_sub_epi32(_mm_cvtepu8_epi32(_mm_cvtsi32_si128(*((const int*) k))), vk_zero_point);
161              $else:
162                const __m128i vk${K}x${ABC[0:4]} = _mm_sub_epi32(_mm_cvtepu8_epi32(_mm_cvtsi32_si128(*((const int*) (k + ${K * CHANNEL_TILE})))), vk_zero_point);
163            $else:
164              const __m128i vk${K}x${ABC[0:4]} = _mm_sub_epi32(_mm_cvtepu8_epi32(_mm_cvtsi32_si128(*((const int*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE} * sizeof(${XINT8_T}))))), vk_zero_point);
165          $else:
166            const __m128i vi${K}x${ABC[0:4]} = _mm_cvtepi8_epi32(_mm_cvtsi32_si128((int) unaligned_load_s32(i${K})));
167            $if CHANNEL_TILE > 4:
168              $if K == 0:
169                const __m128i vk${K}x${ABC[0:4]} = _mm_cvtepi8_epi32(_mm_cvtsi32_si128(*((const int*) k)));
170              $else:
171                const __m128i vk${K}x${ABC[0:4]} = _mm_cvtepi8_epi32(_mm_cvtsi32_si128(*((const int*) (k + ${K * CHANNEL_TILE}))));
172            $else:
173              const __m128i vk${K}x${ABC[0:4]} = _mm_cvtepi8_epi32(_mm_cvtsi32_si128(*((const int*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE} * sizeof(${XINT8_T})))));
174          $if CHANNEL_TILE > 4:
175            i${K} += 4;
176
177          $if XOP:
178            vacc${ABC[0:4]} = _mm_macc_epi32(vi${K}x${ABC[0:4]}, vk${K}x${ABC[0:4]}, vacc${ABC[0:4]});
179          $else:
180            vacc${ABC[0:4]} = _mm_add_epi32(vacc${ABC[0:4]}, _mm_mullo_epi32(vi${K}x${ABC[0:4]}, vk${K}x${ABC[0:4]}));
181
182        $if CHANNEL_TILE > 4:
183          k += 4;
184
185        __m128 vscaled${ABC[0:4]} = _mm_cvtepi32_ps(vacc${ABC[0:4]});
186        $if DATATYPE == "QC8":
          // Per-channel scales sit after the accumulators and all kernel values.
          // w moves forward by 4 int32 (16 bytes) per pass, which steps this
          // address by 4 floats as well.
187          const __m128 vscale${ABC[0:4]} = _mm_loadu_ps((const float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(${XINT8_T})));
188          vscaled${ABC[0:4]} = _mm_mul_ps(vscaled${ABC[0:4]}, vscale${ABC[0:4]});
189        $else:
190          vscaled${ABC[0:4]} = _mm_mul_ps(vscaled${ABC[0:4]}, _mm_load_ps(params->${PARAMS_STRUCT}.scale));
191        vscaled${ABC[0:4]} = _mm_min_ps(vscaled${ABC[0:4]}, _mm_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point));
192        vacc${ABC[0:4]} = _mm_cvtps_epi32(vscaled${ABC[0:4]});
193
194        $if CHANNEL_TILE > 4:
          // Step to the next 4 initial accumulators.
195          w = (const void*) ((const int32_t*) w + 4);
196
197        const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_zero_point);
198        __m128i vout${ABC[0:4]} = _mm_adds_epi16(_mm_packs_epi32(vacc${ABC[0:4]}, vacc${ABC[0:4]}), voutput_zero_point);
199
200        $if DATATYPE == "QU8":
201          vout${ABC[0:4]} = _mm_packus_epi16(vout${ABC[0:4]}, vout${ABC[0:4]});
202          vout${ABC[0:4]} = _mm_max_epu8(vout${ABC[0:4]}, _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min));
203        $else:
204          vout${ABC[0:4]} = _mm_packs_epi16(vout${ABC[0:4]}, vout${ABC[0:4]});
205          vout${ABC[0:4]} = _mm_max_epi8(vout${ABC[0:4]}, _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min));
206
207        $if CHANNEL_TILE > 4:
          // A full group of 4 is stored in one go; otherwise the last 1-3 bytes
          // are written piecewise (2 bytes, then 1 byte) and the loop ends.
208          if XNN_LIKELY(c >= 4) {
209            _mm_storeu_si32(output, vout${ABC[0:4]});
210            output += 4;
211            c -= 4;
212          } else {
213            if (c & 2) {
214              unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vout${ABC[0:4]}, 0));
215              vout${ABC[0:4]} = _mm_srli_epi32(vout${ABC[0:4]}, 16);
216              output += 2;
217            }
218            if (c & 1) {
219              *output = (${XINT8_T}) _mm_extract_epi8(vout${ABC[0:4]}, 0);
220              output += 1;
221            }
222            c = 0;
223          }
224        $else:
225          if (c & 2) {
226            unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vout${ABC[0:4]}, 0));
227            vout${ABC[0:4]} = _mm_srli_epi32(vout${ABC[0:4]}, 16);
228            output += 2;
229          }
230          if (c & 1) {
231            *output = (${XINT8_T}) _mm_extract_epi8(vout${ABC[0:4]}, 0);
232            output += 1;
233          }
234      }${" while (c != 0);" if CHANNEL_TILE > 4 else ""}
235    }
236
237    output = (${XINT8_T}*) ((uintptr_t) output + output_increment);
238  } while (--output_width != 0);
239}
240