// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert REQUANTIZATION == "FP32"
$assert DATATYPE in ["QC8", "QS8", "QU8"]
$assert not ADD16 or DATATYPE != "QU8"
$assert CHANNEL_TILE % 8 == 0
$assert CHANNEL_TILE >= 8
$assert KERNEL_TILE >= 2
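// Template parameters (as enforced by the asserts above):
//   DATATYPE       - quantized datatype: QC8, QS8, or QU8.
//   REQUANTIZATION - requantization scheme; only FP32 is supported by this template.
//   CHANNEL_TILE   - channels processed per loop iteration (multiple of 8, at least 8).
//   KERNEL_TILE    - kernel taps accumulated per output pixel (at least 2).
//   ADD16          - if set, adjacent pairs of 16-bit products are summed in 16 bits
//                    before widening to 32 bits (not used together with QU8).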
#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/dwconv.h>


$PARAMS_STRUCT = REQUANTIZATION.lower() + "_wasmsimd"
$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
$WASM_X16X8_LOAD8X8 = "wasm_u16x8_load8x8" if DATATYPE == "QU8" else "wasm_i16x8_load8x8"
$WASM_X32X4_EXTEND_LOW_X16X8 = "wasm_u32x4_extend_low_u16x8" if DATATYPE == "QU8" else "wasm_i32x4_extend_low_i16x8"
$WASM_X32X4_EXTEND_HIGH_X16X8 = "wasm_u32x4_extend_high_u16x8" if DATATYPE == "QU8" else "wasm_i32x4_extend_high_i16x8"
$WASM_X8X16_NARROW_I16X8 = "wasm_u8x16_narrow_i16x8" if DATATYPE == "QU8" else "wasm_i8x16_narrow_i16x8"
$WASM_X8X16_MIN = "wasm_u8x16_min" if DATATYPE == "QU8" else "wasm_i8x16_min"
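// Depthwise convolution microkernel, unipass ("up") variant: each outer-loop iteration
// produces one output pixel, reading KERNEL_TILE input pointers and processing
// CHANNEL_TILE channels per inner-loop iteration.  Products are formed with 8-lane
// 16-bit multiplies (wasm_i16x8_mul), accumulated in 32 bits, and requantized through
// float arithmetic using the magic-bias rounding trick.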
void xnn_${DATATYPE.lower()}_dwconv_minmax_${REQUANTIZATION.lower()}_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__wasmsimd_mul16${"_add16" if ADD16 else ""}(
    size_t channels,
    size_t output_width,
    const ${XINT8_T}** input,
    const void* weights,
    ${XINT8_T}* output,
    size_t input_stride,
    size_t output_increment,
    size_t input_offset,
    const ${XINT8_T}* zero,
    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(channels != 0);
  assert(output_width != 0);

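  // For QU8, the kernel zero point is loaded once here; it is applied after the
  // multiply-accumulate loop by subtracting zero_point * (sum of the inputs) from
  // each 32-bit accumulator.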
  $if DATATYPE == "QU8":
    const v128_t vkernel_zero_point = wasm_u32x4_load16x4(params->${PARAMS_STRUCT}.kernel_zero_point);
  do {
    $for K in range(KERNEL_TILE):
      const ${XINT8_T}* i${K} = input[${K}];
      assert(i${K} != NULL);
      if XNN_UNPREDICTABLE(i${K} != zero) {
        i${K} = (const ${XINT8_T}*) ((uintptr_t) i${K} + input_offset);
      }
    input = (const ${XINT8_T}**) ((uintptr_t) input + input_stride);

    size_t c = channels;
    const void* w = weights;
    for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
      v128_t vacc${ABC[0:4]} = wasm_v128_load(w);
      $for C in range(4, CHANNEL_TILE, 4):
        v128_t vacc${ABC[C:C+4]} = wasm_v128_load((const void*) ((uintptr_t) w + ${C} * sizeof(int32_t)));

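      // Multiply-accumulate over the KERNEL_TILE taps: inputs and weights are loaded
      // as 8 lanes of 16 bits, multiplied with wasm_i16x8_mul, and widened into the
      // 32-bit accumulators.  With ADD16, adjacent pairs of 16-bit products are summed
      // in 16 bits before widening.  For QU8, a running 16-bit sum of the inputs is
      // kept for the kernel-zero-point correction below.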
      $for K in range(KERNEL_TILE):

        $for C in range(0, CHANNEL_TILE, 8):
          $if C == 0:
            const v128_t vi${K}x${ABC[0:8]} = ${WASM_X16X8_LOAD8X8}(i${K});
          $else:
            const v128_t vi${K}x${ABC[C:C+8]} = ${WASM_X16X8_LOAD8X8}(i${K} + ${C});
          const v128_t vk${K}x${ABC[C:C+8]} = ${WASM_X16X8_LOAD8X8}((const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE + C} * sizeof(${XINT8_T})));
        $if DATATYPE == "QU8":
          $for C in range(0, CHANNEL_TILE, 8):
            $if K == 1:
              v128_t vsumx${ABC[C:C+8]} = wasm_i16x8_add(vi0x${ABC[C:C+8]}, vi1x${ABC[C:C+8]});
            $elif K > 1:
              vsumx${ABC[C:C+8]} = wasm_i16x8_add(vsumx${ABC[C:C+8]}, vi${K}x${ABC[C:C+8]});
        i${K} += ${CHANNEL_TILE};

        $for C in range(0, CHANNEL_TILE, 8):
          $if K == 0:
            v128_t vprod${ABC[C:C+8]} = wasm_i16x8_mul(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]});
          $elif K % 2 == 0 or K + 1 == KERNEL_TILE or not ADD16:
            vprod${ABC[C:C+8]} = wasm_i16x8_mul(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]});
          $else:
            vprod${ABC[C:C+8]} = wasm_i16x8_add(vprod${ABC[C:C+8]}, wasm_i16x8_mul(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]}));

        $if not ADD16 or K % 2 == 1 or K + 1 == KERNEL_TILE:
          $for C in range(0, CHANNEL_TILE, 8):
            vacc${ABC[C:C+4]} = wasm_i32x4_add(vacc${ABC[C:C+4]}, ${WASM_X32X4_EXTEND_LOW_X16X8}(vprod${ABC[C:C+8]}));
            vacc${ABC[C+4:C+8]} = wasm_i32x4_add(vacc${ABC[C+4:C+8]}, ${WASM_X32X4_EXTEND_HIGH_X16X8}(vprod${ABC[C:C+8]}));

      $if DATATYPE == "QU8":
        $for C in range(0, CHANNEL_TILE, 8):
          vacc${ABC[C:C+4]} = wasm_i32x4_sub(vacc${ABC[C:C+4]}, wasm_i32x4_mul(wasm_u32x4_extend_low_u16x8(vsumx${ABC[C:C+8]}), vkernel_zero_point));
          vacc${ABC[C+4:C+8]} = wasm_i32x4_sub(vacc${ABC[C+4:C+8]}, wasm_i32x4_mul(wasm_u32x4_extend_high_u16x8(vsumx${ABC[C:C+8]}), vkernel_zero_point));

      w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T}));

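      // FP32 requantization: convert the accumulators to float and scale them
      // (per-channel scales for QC8, a single replicated scale otherwise).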
      $for C in range(0, CHANNEL_TILE, 4):
        vacc${ABC[C:C+4]} = wasm_f32x4_convert_i32x4(vacc${ABC[C:C+4]});

      $if DATATYPE == "QC8":
        const v128_t vscale${ABC[0:4]} = wasm_v128_load(w);
        $for C in range(4, CHANNEL_TILE, 4):
          const v128_t vscale${ABC[C:C+4]} = wasm_v128_load((const float*) w + ${C});
        w = (const void*) ((const float*) w + ${CHANNEL_TILE});

        $for C in range(0, CHANNEL_TILE, 4):
          vacc${ABC[C:C+4]} = wasm_f32x4_mul(vacc${ABC[C:C+4]}, vscale${ABC[C:C+4]});
      $else:
        const v128_t vscale = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.scale);
        $for C in range(0, CHANNEL_TILE, 4):
          vacc${ABC[C:C+4]} = wasm_f32x4_mul(vacc${ABC[C:C+4]}, vscale);

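      // Magic-bias rounding: adding the magic bias fixes the float exponent so the
      // rounded result lands in the low mantissa bits and the vector can be treated
      // as int32.  wasm_i32x4_max against magic_min applies the lower output bound,
      // and subtracting magic_bias_less_output_zero_point removes the bias while
      // adding the output zero point.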
      const v128_t vmagic_bias = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.magic_bias);
      $for C in range(0, CHANNEL_TILE, 4):
        vacc${ABC[C:C+4]} = wasm_f32x4_add(vacc${ABC[C:C+4]}, vmagic_bias);

      const v128_t vmagic_min = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.magic_min);
      $for C in range(0, CHANNEL_TILE, 4):
        vacc${ABC[C:C+4]} = wasm_i32x4_max(vacc${ABC[C:C+4]}, vmagic_min);

      const v128_t vmagic_bias_less_output_zero_point = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.magic_bias_less_output_zero_point);
      $for C in range(0, CHANNEL_TILE, 4):
        vacc${ABC[C:C+4]} = wasm_i32x4_sub(vacc${ABC[C:C+4]}, vmagic_bias_less_output_zero_point);

      $for C in range(0, CHANNEL_TILE, 8):
        v128_t vout${ABC[C:C+8]} = wasm_i16x8_narrow_i32x4(vacc${ABC[C:C+4]}, vacc${ABC[C+4:C+8]});

      $for C in range(0, CHANNEL_TILE, 16):
        $if C + 8 < CHANNEL_TILE:
          v128_t vout${ABC[C:C+16]} = ${WASM_X8X16_NARROW_I16X8}(vout${ABC[C:C+8]}, vout${ABC[C+8:C+16]});
        $else:
          v128_t vout${ABC[C:C+8]}${ABC[C:C+8]} = ${WASM_X8X16_NARROW_I16X8}(vout${ABC[C:C+8]}, vout${ABC[C:C+8]});

      const v128_t voutput_max = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.output_max);
      $for C in range(0, CHANNEL_TILE, 16):
        $if C + 8 < CHANNEL_TILE:
          vout${ABC[C:C+16]} = ${WASM_X8X16_MIN}(vout${ABC[C:C+16]}, voutput_max);
        $else:
          vout${ABC[C:C+8]}${ABC[C:C+8]} = ${WASM_X8X16_MIN}(vout${ABC[C:C+8]}${ABC[C:C+8]}, voutput_max);

      $if CHANNEL_TILE > 8:
        wasm_v128_store(output, vout${ABC[0:16]});
      $else:
        *((double*) output) = wasm_f64x2_extract_lane(vout${ABC[0:8]}${ABC[0:8]}, 0);
      $for C in range(16, CHANNEL_TILE, 16):
        $if C + 8 < CHANNEL_TILE:
          wasm_v128_store(output + ${C}, vout${ABC[C:C+16]});
        $else:
          *((double*) (output + ${C})) = wasm_f64x2_extract_lane(vout${ABC[C:C+8]}${ABC[C:C+8]}, 0);
      output += ${CHANNEL_TILE};
    }
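    // Remainder: handle the leftover channels that do not fill a complete tile;
    // the final partial group is stored with 4-, 2-, and 1-element tail stores.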
    if XNN_UNLIKELY(c != 0) {
      $if CHANNEL_TILE > 8:
        const ${XINT8_T}* k = (const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t));
      ${"do " if CHANNEL_TILE > 8 else ""}{
        v128_t vacc${ABC[0:4]} = wasm_v128_load(w);
        v128_t vacc${ABC[4:8]} = wasm_v128_load((const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));

        $for K in range(KERNEL_TILE):

          const v128_t vi${K}x${ABC[0:8]} = ${WASM_X16X8_LOAD8X8}(i${K});
          $if CHANNEL_TILE > 8:
            $if K == 0:
              const v128_t vk${K}x${ABC[0:8]} = ${WASM_X16X8_LOAD8X8}(k);
            $else:
              const v128_t vk${K}x${ABC[0:8]} = ${WASM_X16X8_LOAD8X8}((const void*) (k + ${K * CHANNEL_TILE}));
          $else:
            const v128_t vk${K}x${ABC[0:8]} = ${WASM_X16X8_LOAD8X8}((const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE} * sizeof(${XINT8_T})));
          $if DATATYPE == "QU8":
            $if K == 1:
              v128_t vsumx${ABC[0:8]} = wasm_i16x8_add(vi0x${ABC[0:8]}, vi1x${ABC[0:8]});
            $elif K > 1:
              vsumx${ABC[0:8]} = wasm_i16x8_add(vsumx${ABC[0:8]}, vi${K}x${ABC[0:8]});
          $if CHANNEL_TILE > 8:
            i${K} += 8;

          $if K == 0:
            v128_t vprod${ABC[0:8]} = wasm_i16x8_mul(vi${K}x${ABC[0:8]}, vk${K}x${ABC[0:8]});
          $elif K % 2 == 0 or K + 1 == KERNEL_TILE or not ADD16:
            vprod${ABC[0:8]} = wasm_i16x8_mul(vi${K}x${ABC[0:8]}, vk${K}x${ABC[0:8]});
          $else:
            vprod${ABC[0:8]} = wasm_i16x8_add(vprod${ABC[0:8]}, wasm_i16x8_mul(vi${K}x${ABC[0:8]}, vk${K}x${ABC[0:8]}));

          $if not ADD16 or K % 2 == 1 or K + 1 == KERNEL_TILE:
            vacc${ABC[0:4]} = wasm_i32x4_add(vacc${ABC[0:4]}, ${WASM_X32X4_EXTEND_LOW_X16X8}(vprod${ABC[0:8]}));
            vacc${ABC[4:8]} = wasm_i32x4_add(vacc${ABC[4:8]}, ${WASM_X32X4_EXTEND_HIGH_X16X8}(vprod${ABC[0:8]}));

        $if CHANNEL_TILE > 8:
          k += 8;

      $if DATATYPE == "QU8":
        vacc${ABC[0:4]} = wasm_i32x4_sub(vacc${ABC[0:4]}, wasm_i32x4_mul(wasm_u32x4_extend_low_u16x8(vsumx${ABC[0:8]}), vkernel_zero_point));
        vacc${ABC[4:8]} = wasm_i32x4_sub(vacc${ABC[4:8]}, wasm_i32x4_mul(wasm_u32x4_extend_high_u16x8(vsumx${ABC[0:8]}), vkernel_zero_point));

      vacc${ABC[0:4]} = wasm_f32x4_convert_i32x4(vacc${ABC[0:4]});
      vacc${ABC[4:8]} = wasm_f32x4_convert_i32x4(vacc${ABC[4:8]});

      $if DATATYPE == "QC8":
        const v128_t vscale${ABC[0:4]} = wasm_v128_load((const float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(${XINT8_T})));
        const v128_t vscale${ABC[4:8]} = wasm_v128_load((const float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(${XINT8_T}) + 4 * sizeof(float)));

        vacc${ABC[0:4]} = wasm_f32x4_mul(vacc${ABC[0:4]}, vscale${ABC[0:4]});
        vacc${ABC[4:8]} = wasm_f32x4_mul(vacc${ABC[4:8]}, vscale${ABC[4:8]});
      $else:
        const v128_t vscale = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.scale);
        vacc${ABC[0:4]} = wasm_f32x4_mul(vacc${ABC[0:4]}, vscale);
        vacc${ABC[4:8]} = wasm_f32x4_mul(vacc${ABC[4:8]}, vscale);

      const v128_t vmagic_bias = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.magic_bias);
      vacc${ABC[0:4]} = wasm_f32x4_add(vacc${ABC[0:4]}, vmagic_bias);
      vacc${ABC[4:8]} = wasm_f32x4_add(vacc${ABC[4:8]}, vmagic_bias);

      const v128_t vmagic_min = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.magic_min);
      vacc${ABC[0:4]} = wasm_i32x4_max(vacc${ABC[0:4]}, vmagic_min);
      vacc${ABC[4:8]} = wasm_i32x4_max(vacc${ABC[4:8]}, vmagic_min);

      const v128_t vmagic_bias_less_output_zero_point = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.magic_bias_less_output_zero_point);
      vacc${ABC[0:4]} = wasm_i32x4_sub(vacc${ABC[0:4]}, vmagic_bias_less_output_zero_point);
      vacc${ABC[4:8]} = wasm_i32x4_sub(vacc${ABC[4:8]}, vmagic_bias_less_output_zero_point);

      v128_t vout${ABC[0:8]} = wasm_i16x8_narrow_i32x4(vacc${ABC[0:4]}, vacc${ABC[4:8]});
      v128_t vout${ABC[0:8]}${ABC[0:8]} = ${WASM_X8X16_NARROW_I16X8}(vout${ABC[0:8]}, vout${ABC[0:8]});

      const v128_t voutput_max = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.output_max);
      vout${ABC[0:8]}${ABC[0:8]} = ${WASM_X8X16_MIN}(vout${ABC[0:8]}${ABC[0:8]}, voutput_max);

      $if CHANNEL_TILE > 8:
        w = (const void*) ((uintptr_t) w + 8 * sizeof(int32_t));

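      // Store the remaining channels: an 8-byte group while a full group is left,
      // then 4-, 2-, and 1-byte tails extracted from the packed output vector.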
      $if CHANNEL_TILE > 8:
        if XNN_LIKELY(c >= 8) {
          *((double*) output) = wasm_f64x2_extract_lane(vout${ABC[0:8]}${ABC[0:8]}, 0);
          output += 8;
          c -= 8;
        } else {
          if (c & 4) {
            *((float*) output) = wasm_f32x4_extract_lane(vout${ABC[0:8]}${ABC[0:8]}, 0);
            vout${ABC[0:8]}${ABC[0:8]} = wasm_u64x2_shr(vout${ABC[0:8]}${ABC[0:8]}, 32);
            output += 4;
          }
          uint32_t vout${ABC[0:4]} = wasm_i32x4_extract_lane(vout${ABC[0:8]}${ABC[0:8]}, 0);
          if (c & 2) {
            *((uint16_t*) output) = (uint16_t) vout${ABC[0:4]};
            vout${ABC[0:4]} >>= 16;
            output += 2;
          }
          if (c & 1) {
            *output = (${XINT8_T}) vout${ABC[0:4]};
            output += 1;
          }
          c = 0;
        }
      $else:
        if (c & 4) {
          *((float*) output) = wasm_f32x4_extract_lane(vout${ABC[0:8]}${ABC[0:8]}, 0);
          vout${ABC[0:8]}${ABC[0:8]} = wasm_u64x2_shr(vout${ABC[0:8]}${ABC[0:8]}, 32);
          output += 4;
        }
        uint32_t vout${ABC[0:4]} = wasm_i32x4_extract_lane(vout${ABC[0:8]}${ABC[0:8]}, 0);
        if (c & 2) {
          *((uint16_t*) output) = (uint16_t) vout${ABC[0:4]};
          vout${ABC[0:4]} >>= 16;
          output += 2;
        }
        if (c & 1) {
          *output = (${XINT8_T}) vout${ABC[0:4]};
          output += 1;
        }
      }${" while (c != 0);" if CHANNEL_TILE > 8 else ""}
    }

    output = (${XINT8_T}*) ((uintptr_t) output + output_increment);
  } while (--output_width != 0);
}
