xref: /aosp_15_r20/external/XNNPACK/src/qs8-dwconv/unipass-neon-mul8.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
// Template parameters (validated by the $assert lines below; MLA and ARMV8
// are also expected, and are referenced later in the kernel body):
//   REQUANTIZATION - "FP32" or "RNDNU" requantization scheme.
//   DATATYPE       - "QS8" (per-tensor scale) or "QC8" (per-channel scales
//                    packed after the weights; only valid with FP32).
//   LOAD_VARIANT   - "LD64" (8-byte vld1) or "LD128" (16-byte vld1q) loads.
//   CHANNEL_TILE   - channels processed per main-loop iteration; >= 8 and a
//                    multiple of the load width.
//   KERNEL_TILE    - number of kernel taps unrolled per output pixel; >= 2.
//   MLA            - when true, pairs of taps are fused with VMLAL before the
//                    widening add into the int32 accumulators.
//   ARMV8          - when true, uses the ARMv8 VCVTN rounding conversion
//                    instead of the magic-bias trick for FP32 requantization.
6$assert REQUANTIZATION in ["FP32", "RNDNU"]
7$assert DATATYPE in ["QC8", "QS8"]
8$assert DATATYPE != "QC8" or REQUANTIZATION == "FP32"
// Alphabet used to build per-lane variable-name suffixes (e.g. vacc0123).
9$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
10$assert LOAD_VARIANT in ["LD64", "LD128"]
11$assert CHANNEL_TILE % {"LD64": 8, "LD128": 16}[LOAD_VARIANT] == 0
12$assert CHANNEL_TILE >= 8
13$assert KERNEL_TILE >= 2
14#include <assert.h>
15
16#include <arm_neon.h>
17
18#include <xnnpack/dwconv.h>
19$if REQUANTIZATION == "FP32" and ARMV8:
20  #include <xnnpack/intrinsics-polyfill.h>
21
22
// Derived names: the params sub-struct selected by requantization + ISA, the
// params union type for this datatype, and the ISA tag in the kernel name.
23$PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("neonv8" if ARMV8 else "neon")
24$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
25$ISA = "neonv8" if ARMV8 else "neon"
// Single-pass ("unipass") depthwise-convolution microkernel with fused
// min/max clamping, using NEON 8-bit VMULL/VMLAL multiplications.
//
// Arguments:
//   channels         - number of channels; processed CHANNEL_TILE at a time
//                      with an 8-wide remainder path.
//   output_width     - number of output pixels; one per outer-loop iteration.
//   input            - array of KERNEL_TILE row pointers per output pixel;
//                      advanced by input_stride bytes after each pixel.
//   weights          - packed weight blob. Per CHANNEL_TILE group the layout
//                      read here is: CHANNEL_TILE int32 biases, then
//                      KERNEL_TILE * CHANNEL_TILE int8 weights, then (QC8
//                      only) CHANNEL_TILE float per-channel scales.
//   output           - int8 output; advanced by output_increment bytes after
//                      each pixel.
//   input_offset     - byte offset added to every row pointer that is not the
//                      `zero` buffer.
//   zero             - pointer to the zero padding buffer; rows equal to it
//                      are used as-is (not offset).
//   params           - requantization and clamping parameters.
26void xnn_${DATATYPE.lower()}_dwconv_minmax_${REQUANTIZATION.lower()}_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__${ISA}_${"mla8" if MLA else "mul8"}_${LOAD_VARIANT.lower()}(
27    size_t channels,
28    size_t output_width,
29    const int8_t** input,
30    const void* weights,
31    int8_t* output,
32    size_t input_stride,
33    size_t output_increment,
34    size_t input_offset,
35    const int8_t* zero,
36    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
37{
38  assert(channels != 0);
39  assert(output_width != 0);
40
  // Broadcast the requantization and clamping parameters once; they are
  // invariant for the whole call.
41  $if REQUANTIZATION == "RNDNU":
42    const int32x4_t vright_pre_shift = vld1q_dup_s32(&params->${PARAMS_STRUCT}.right_pre_shift);
43    const int32x4_t vmultiplier = vld1q_dup_s32(&params->${PARAMS_STRUCT}.multiplier);
44    const int32x4_t vright_post_shift = vld1q_dup_s32(&params->${PARAMS_STRUCT}.right_post_shift);
45  $elif REQUANTIZATION == "FP32":
46    $if DATATYPE != "QC8":
47      const float32x4_t vscale = vld1q_dup_f32(&params->${PARAMS_STRUCT}.scale);
48    $if not ARMV8:
49      const float32x4_t vmagic_bias = vld1q_dup_f32(&params->${PARAMS_STRUCT}.magic_bias);
50      const int32x4_t vmagic_bias_less_output_zero_point = vld1q_dup_s32(&params->${PARAMS_STRUCT}.magic_bias_less_output_zero_point);
  // In the FP32 non-ARMV8 path the output zero point is folded into the
  // magic-bias constant, so voutput_zero_point is not needed there.
51  $if REQUANTIZATION != "FP32" or ARMV8:
52    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->${PARAMS_STRUCT}.output_zero_point);
53  $if CHANNEL_TILE == 8:
54    const int8x8_t voutput_min = vld1_dup_s8(&params->${PARAMS_STRUCT}.output_min);
55    const int8x8_t voutput_max = vld1_dup_s8(&params->${PARAMS_STRUCT}.output_max);
56  $else:
57    const int8x16_t voutput_min = vld1q_dup_s8(&params->${PARAMS_STRUCT}.output_min);
58    const int8x16_t voutput_max = vld1q_dup_s8(&params->${PARAMS_STRUCT}.output_max);
  // Outer loop: one iteration per output pixel.
59  do {
    // Set up the KERNEL_TILE input row pointers for this pixel. A row that
    // points at the zero padding buffer is used as-is; real rows are shifted
    // by input_offset.
60    $for K in range(KERNEL_TILE):
61      const int8_t* i${K} = input[${K}];
62      assert(i${K} != NULL);
63      if XNN_UNPREDICTABLE(i${K} != zero) {
64        i${K} = (const int8_t*) ((uintptr_t) i${K} + input_offset);
65      }
66    input = (const int8_t**) ((uintptr_t) input + input_stride);
67
68    size_t c = channels;
69    const void* w = weights;
    // Main loop: CHANNEL_TILE channels per iteration. Load the int32 biases,
    // accumulate KERNEL_TILE taps via 8-bit widening multiplies, requantize,
    // clamp, and store.
70    for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
71      $for C in range(0, CHANNEL_TILE, 4):
72        int32x4_t vacc${ABC[C:C+4]} = vld1q_s32(w); w = (const void*) ((const int32_t*) w + 4);
73
      // Per tap: multiply int8 inputs by int8 weights into int16 products.
      // With MLA, an even tap's VMULL is fused with the following odd tap's
      // VMLAL before the widening add into the int32 accumulators; otherwise
      // every tap is added individually.
74      $for K in range(KERNEL_TILE):
75        $if LOAD_VARIANT == "LD128":
76          $for C in range(0, CHANNEL_TILE, 16):
77            const int8x16_t vi${K}x${ABC[C:C+16]} = vld1q_s8(i${K}); i${K} += 16;
78            const int8x16_t vk${K}x${ABC[C:C+16]} = vld1q_s8(w); w = (const void*) ((const int8_t*) w + 16);
79
80          $if K == 0:
81            $for C in range(0, CHANNEL_TILE, 16):
82              int16x8_t vprod${ABC[C:C+8]} = vmull_s8(vget_low_s8(vi${K}x${ABC[C:C+16]}), vget_low_s8(vk${K}x${ABC[C:C+16]}));
83              int16x8_t vprod${ABC[C+8:C+16]} = vmull_s8(vget_high_s8(vi${K}x${ABC[C:C+16]}), vget_high_s8(vk${K}x${ABC[C:C+16]}));
84          $elif K % 2 == 0 or K + 1 == KERNEL_TILE or not MLA:
85            $for C in range(0, CHANNEL_TILE, 16):
86              vprod${ABC[C:C+8]} = vmull_s8(vget_low_s8(vi${K}x${ABC[C:C+16]}), vget_low_s8(vk${K}x${ABC[C:C+16]}));
87              vprod${ABC[C+8:C+16]} = vmull_s8(vget_high_s8(vi${K}x${ABC[C:C+16]}), vget_high_s8(vk${K}x${ABC[C:C+16]}));
88          $else:
89            $for C in range(0, CHANNEL_TILE, 16):
90              vprod${ABC[C:C+8]} = vmlal_s8(vprod${ABC[C:C+8]}, vget_low_s8(vi${K}x${ABC[C:C+16]}), vget_low_s8(vk${K}x${ABC[C:C+16]}));
91              vprod${ABC[C+8:C+16]} = vmlal_s8(vprod${ABC[C+8:C+16]}, vget_high_s8(vi${K}x${ABC[C:C+16]}), vget_high_s8(vk${K}x${ABC[C:C+16]}));
92        $else:
93          $for C in range(0, CHANNEL_TILE, 8):
94            const int8x8_t vi${K}x${ABC[C:C+8]} = vld1_s8(i${K}); i${K} += 8;
95            const int8x8_t vk${K}x${ABC[C:C+8]} = vld1_s8(w); w = (const void*) ((const int8_t*) w + 8);
96
97          $if K == 0:
98            $for C in range(0, CHANNEL_TILE, 8):
99              int16x8_t vprod${ABC[C:C+8]} = vmull_s8(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]});
100          $elif K % 2 == 0 or K + 1 == KERNEL_TILE or not MLA:
101            $for C in range(0, CHANNEL_TILE, 8):
102              vprod${ABC[C:C+8]} = vmull_s8(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]});
103          $else:
104            $for C in range(0, CHANNEL_TILE, 8):
105              vprod${ABC[C:C+8]} = vmlal_s8(vprod${ABC[C:C+8]}, vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]});
106
107        $if not MLA or K % 2 == 1 or K + 1 == KERNEL_TILE:
108          $for C in range(0, CHANNEL_TILE, 8):
109            vacc${ABC[C:C+4]} = vaddw_s16(vacc${ABC[C:C+4]}, vget_low_s16(vprod${ABC[C:C+8]}));
110            vacc${ABC[C+4:C+8]} = vaddw_s16(vacc${ABC[C+4:C+8]}, vget_high_s16(vprod${ABC[C:C+8]}));
111
      // Requantize the int32 accumulators to the int8 output scale.
112      $if REQUANTIZATION == "RNDNU":
113        $for C in range(0, CHANNEL_TILE, 4):
114          vacc${ABC[C:C+4]} = vqshlq_s32(vacc${ABC[C:C+4]}, vright_pre_shift);
115
116        $for C in range(0, CHANNEL_TILE, 4):
117          vacc${ABC[C:C+4]} = vqdmulhq_s32(vacc${ABC[C:C+4]}, vmultiplier);
118
119        $for C in range(0, CHANNEL_TILE, 4):
120          vacc${ABC[C:C+4]} = vrshlq_s32(vacc${ABC[C:C+4]}, vright_post_shift);
121      $elif REQUANTIZATION == "FP32":
122        $for C in range(0, CHANNEL_TILE, 4):
123          float32x4_t vfpacc${ABC[C:C+4]} = vcvtq_f32_s32(vacc${ABC[C:C+4]});
124
        // QC8: per-channel scales are packed in the weight blob right after
        // this group's weights; QS8 uses the single per-tensor vscale.
125        $if DATATYPE == "QC8":
126          $for C in range(0, CHANNEL_TILE, 4):
127            const float32x4_t vscale${ABC[C:C+4]} = vld1q_f32((const float*) w); w = (const void*) ((const float*) w + 4);
128
129          $for C in range(0, CHANNEL_TILE, 4):
130            vfpacc${ABC[C:C+4]} = vmulq_f32(vfpacc${ABC[C:C+4]}, vscale${ABC[C:C+4]});
131        $else:
132          $for C in range(0, CHANNEL_TILE, 4):
133            vfpacc${ABC[C:C+4]} = vmulq_f32(vfpacc${ABC[C:C+4]}, vscale);
134
        // ARMv8 has round-to-nearest conversion; older NEON emulates it by
        // adding a magic bias and reinterpreting, then subtracting the bias
        // combined with the output zero point.
135        $if ARMV8:
136          $for C in range(0, CHANNEL_TILE, 4):
137            vacc${ABC[C:C+4]} = vcvtnq_s32_f32(vfpacc${ABC[C:C+4]});
138        $else:
139          $for C in range(0, CHANNEL_TILE, 4):
140            vacc${ABC[C:C+4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[C:C+4]}, vmagic_bias));
141
142          $for C in range(0, CHANNEL_TILE, 4):
143            vacc${ABC[C:C+4]} = vqsubq_s32(vacc${ABC[C:C+4]}, vmagic_bias_less_output_zero_point);
144
      // Saturating narrow int32 -> int16 -> int8; ARM64 has the *_high forms
      // that narrow straight into the upper half of a q-register.
145#if XNN_ARCH_ARM64
146      $for C in range(0, CHANNEL_TILE, 8):
147        int16x8_t vacc${ABC[C:C+8]} = vqmovn_high_s32(vqmovn_s32(vacc${ABC[C:C+4]}), vacc${ABC[C+4:C+8]});
148
149      $if REQUANTIZATION != "FP32" or ARMV8:
150        $for C in range(0, CHANNEL_TILE, 8):
151          vacc${ABC[C:C+8]} = vqaddq_s16(vacc${ABC[C:C+8]}, voutput_zero_point);
152
153      $for C in range(0, CHANNEL_TILE, 16):
154        $if C + 8 < CHANNEL_TILE:
155          int8x16_t vout${ABC[C:C+16]} = vqmovn_high_s16(vqmovn_s16(vacc${ABC[C:C+8]}), vacc${ABC[C+8:C+16]});
156        $else:
157          int8x8_t vout${ABC[C:C+8]} = vqmovn_s16(vacc${ABC[C:C+8]});
158#else  // !XNN_ARCH_ARM64
159      $for C in range(0, CHANNEL_TILE, 8):
160        int16x8_t vacc${ABC[C:C+8]} = vcombine_s16(vqmovn_s32(vacc${ABC[C:C+4]}), vqmovn_s32(vacc${ABC[C+4:C+8]}));
161
162      $if REQUANTIZATION != "FP32" or ARMV8:
163        $for C in range(0, CHANNEL_TILE, 8):
164          vacc${ABC[C:C+8]} = vqaddq_s16(vacc${ABC[C:C+8]}, voutput_zero_point);
165
166      $for C in range(0, CHANNEL_TILE, 16):
167        $if C + 8 < CHANNEL_TILE:
168          int8x16_t vout${ABC[C:C+16]} = vcombine_s8(vqmovn_s16(vacc${ABC[C:C+8]}), vqmovn_s16(vacc${ABC[C+8:C+16]}));
169        $else:
170          int8x8_t vout${ABC[C:C+8]} = vqmovn_s16(vacc${ABC[C:C+8]});
171#endif  // !XNN_ARCH_ARM64
172
      // Clamp the int8 results to [output_min, output_max].
173      $for C in range(0, CHANNEL_TILE, 16):
174        $if C + 8 < CHANNEL_TILE:
175          vout${ABC[C:C+16]} = vmaxq_s8(vout${ABC[C:C+16]}, voutput_min);
176        $elif CHANNEL_TILE == 8:
177          vout${ABC[C:C+8]} = vmax_s8(vout${ABC[C:C+8]}, voutput_min);
178        $else:
179          vout${ABC[C:C+8]} = vmax_s8(vout${ABC[C:C+8]}, vget_low_s8(voutput_min));
180
181      $for C in range(0, CHANNEL_TILE, 16):
182        $if C + 8 < CHANNEL_TILE:
183          vout${ABC[C:C+16]} = vminq_s8(vout${ABC[C:C+16]}, voutput_max);
184        $elif CHANNEL_TILE == 8:
185          vout${ABC[C:C+8]} = vmin_s8(vout${ABC[C:C+8]}, voutput_max);
186        $else:
187          vout${ABC[C:C+8]} = vmin_s8(vout${ABC[C:C+8]}, vget_low_s8(voutput_max));
188
189      $for C in range(0, CHANNEL_TILE, 16):
190        $if C + 8 < CHANNEL_TILE:
191          vst1q_s8(output, vout${ABC[C:C+16]}); output += 16;
192        $else:
193          vst1_s8(output, vout${ABC[C:C+8]}); output += 8;
194    }
    // Remainder: process the last (channels % CHANNEL_TILE) channels in
    // groups of up to 8. When CHANNEL_TILE > 8, `k` walks the int8 weights
    // directly (skipping the CHANNEL_TILE int32 biases) while `w` keeps
    // supplying per-group biases.
195    if XNN_UNLIKELY(c != 0) {
196      $if CHANNEL_TILE > 8:
197        const int8_t* k = (const int8_t*) ((const int32_t*) w + ${CHANNEL_TILE});
198      ${"do " if CHANNEL_TILE > 8 else ""}{
199        int32x4_t vacc${ABC[0:4]} = vld1q_s32(w); w = (const void*) ((const int32_t*) w + 4);
200        int32x4_t vacc${ABC[4:8]} = vld1q_s32(w); w = (const void*) ((const int32_t*) w + 4);
201
202        $for K in range(KERNEL_TILE):
203          $if CHANNEL_TILE > 8:
204            const int8x8_t vi${K}x${ABC[0:8]} = vld1_s8(i${K}); i${K} += 8;
205          $else:
206            const int8x8_t vi${K}x${ABC[0:8]} = vld1_s8(i${K});
207          $if CHANNEL_TILE > 8:
208            $if K == 0:
209              const int8x8_t vk${K}x${ABC[0:8]} = vld1_s8(k); k += 8;
210            $else:
211              const int8x8_t vk${K}x${ABC[0:8]} = vld1_s8((const void*) (k + ${K * CHANNEL_TILE - 8}));
212          $else:
213            $if K == 0:
214              const int8x8_t vk${K}x${ABC[0:8]} = vld1_s8(w);
215            $else:
216              const int8x8_t vk${K}x${ABC[0:8]} = vld1_s8((const void*) ((const int8_t*) w + ${K * CHANNEL_TILE}));
217
218          $if K == 0:
219            int16x8_t vprod${ABC[0:8]} = vmull_s8(vi${K}x${ABC[0:8]}, vk${K}x${ABC[0:8]});
220          $elif K % 2 == 0 or K + 1 == KERNEL_TILE or not MLA:
221            vprod${ABC[0:8]} = vmull_s8(vi${K}x${ABC[0:8]}, vk${K}x${ABC[0:8]});
222          $else:
223            vprod${ABC[0:8]} = vmlal_s8(vprod${ABC[0:8]}, vi${K}x${ABC[0:8]}, vk${K}x${ABC[0:8]});
224
225          $if not MLA or K % 2 == 1 or K + 1 == KERNEL_TILE:
226            vacc${ABC[0:4]} = vaddw_s16(vacc${ABC[0:4]}, vget_low_s16(vprod${ABC[0:8]}));
227            vacc${ABC[4:8]} = vaddw_s16(vacc${ABC[4:8]}, vget_high_s16(vprod${ABC[0:8]}));
228
        // Requantization for the remainder group (mirrors the main loop).
229        $if REQUANTIZATION == "RNDNU":
230          vacc${ABC[0:4]} = vqshlq_s32(vacc${ABC[0:4]}, vright_pre_shift);
231          vacc${ABC[4:8]} = vqshlq_s32(vacc${ABC[4:8]}, vright_pre_shift);
232
233          vacc${ABC[0:4]} = vqdmulhq_s32(vacc${ABC[0:4]}, vmultiplier);
234          vacc${ABC[4:8]} = vqdmulhq_s32(vacc${ABC[4:8]}, vmultiplier);
235
236          vacc${ABC[0:4]} = vrshlq_s32(vacc${ABC[0:4]}, vright_post_shift);
237          vacc${ABC[4:8]} = vrshlq_s32(vacc${ABC[4:8]}, vright_post_shift);
238        $elif REQUANTIZATION == "FP32":
239          float32x4_t vfpacc${ABC[0:4]} = vcvtq_f32_s32(vacc${ABC[0:4]});
240          float32x4_t vfpacc${ABC[4:8]} = vcvtq_f32_s32(vacc${ABC[4:8]});
241
242          $if DATATYPE == "QC8":
243            const float32x4_t vscale${ABC[0:4]} = vld1q_f32((const float*) ((uintptr_t) w + ${CHANNEL_TILE - 8} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(int8_t)));
244            const float32x4_t vscale${ABC[4:8]} = vld1q_f32((const float*) ((uintptr_t) w + ${CHANNEL_TILE - 8} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(int8_t) + 4 * sizeof(float)));
245            vfpacc${ABC[0:4]} = vmulq_f32(vfpacc${ABC[0:4]}, vscale${ABC[0:4]});
246            vfpacc${ABC[4:8]} = vmulq_f32(vfpacc${ABC[4:8]}, vscale${ABC[4:8]});
247          $else:
248            vfpacc${ABC[0:4]} = vmulq_f32(vfpacc${ABC[0:4]}, vscale);
249            vfpacc${ABC[4:8]} = vmulq_f32(vfpacc${ABC[4:8]}, vscale);
250
251          $if ARMV8:
252            vacc${ABC[0:4]} = vcvtnq_s32_f32(vfpacc${ABC[0:4]});
253            vacc${ABC[4:8]} = vcvtnq_s32_f32(vfpacc${ABC[4:8]});
254          $else:
255            vacc${ABC[0:4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[0:4]}, vmagic_bias));
256            vacc${ABC[4:8]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[4:8]}, vmagic_bias));
257
258            vacc${ABC[0:4]} = vqsubq_s32(vacc${ABC[0:4]}, vmagic_bias_less_output_zero_point);
259            vacc${ABC[4:8]} = vqsubq_s32(vacc${ABC[4:8]}, vmagic_bias_less_output_zero_point);
260
261#if XNN_ARCH_ARM64
262        int16x8_t vacc${ABC[0:8]} = vqmovn_high_s32(vqmovn_s32(vacc${ABC[0:4]}), vacc${ABC[4:8]});
263#else
264        int16x8_t vacc${ABC[0:8]} = vcombine_s16(vqmovn_s32(vacc${ABC[0:4]}), vqmovn_s32(vacc${ABC[4:8]}));
265#endif
266        $if REQUANTIZATION != "FP32" or ARMV8:
267          vacc${ABC[0:8]} = vqaddq_s16(vacc${ABC[0:8]}, voutput_zero_point);
268
269        int8x8_t vout${ABC[0:8]} = vqmovn_s16(vacc${ABC[0:8]});
270        $if CHANNEL_TILE == 8:
271          vout${ABC[0:8]} = vmax_s8(vout${ABC[0:8]}, voutput_min);
272          vout${ABC[0:8]} = vmin_s8(vout${ABC[0:8]}, voutput_max);
273        $else:
274          vout${ABC[0:8]} = vmax_s8(vout${ABC[0:8]}, vget_low_s8(voutput_min));
275          vout${ABC[0:8]} = vmin_s8(vout${ABC[0:8]}, vget_low_s8(voutput_max));
276
        // Store: full 8-wide groups with vst1; a final partial group is
        // written 4, 2, then 1 byte(s), rotating the vector with vext after
        // each partial store.
277        $if CHANNEL_TILE > 8:
278          if XNN_LIKELY(c >= 8) {
279            vst1_s8(output, vout${ABC[0:8]}); output += 8;
280            c -= 8;
281          } else {
282            if (c & 4) {
283              vst1_lane_u32((void*) output, vreinterpret_u32_s8(vout${ABC[0:8]}), 0); output += 4;
284              vout${ABC[0:8]} = vext_s8(vout${ABC[0:8]}, vout${ABC[0:8]}, 4);
285            }
286            if (c & 2) {
287              vst1_lane_u16((void*) output, vreinterpret_u16_s8(vout${ABC[0:8]}), 0); output += 2;
288              vout${ABC[0:8]} = vext_s8(vout${ABC[0:8]}, vout${ABC[0:8]}, 2);
289            }
290            if (c & 1) {
291              vst1_lane_s8(output, vout${ABC[0:8]}, 0); output += 1;
292            }
293            c = 0;
294          }
295        $else:
296          if (c & 4) {
297            vst1_lane_u32((void*) output, vreinterpret_u32_s8(vout${ABC[0:8]}), 0); output += 4;
298            vout${ABC[0:8]} = vext_s8(vout${ABC[0:8]}, vout${ABC[0:8]}, 4);
299          }
300          if (c & 2) {
301            vst1_lane_u16((void*) output, vreinterpret_u16_s8(vout${ABC[0:8]}), 0); output += 2;
302            vout${ABC[0:8]} = vext_s8(vout${ABC[0:8]}, vout${ABC[0:8]}, 2);
303          }
304          if (c & 1) {
305            vst1_lane_s8(output, vout${ABC[0:8]}, 0); output += 1;
306          }
307      }${" while (c != 0);" if CHANNEL_TILE > 8 else ""}
308    }
309
310    output = (int8_t*) ((uintptr_t) output + output_increment);
311  } while (--output_width != 0);
312}
313