xref: /aosp_15_r20/external/XNNPACK/src/s16-window/neon.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker// Copyright 2022 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker//
3*4bdc9457SAndroid Build Coastguard Worker// This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
6*4bdc9457SAndroid Build Coastguard Worker$assert BATCH_TILE % 8 == 0
7*4bdc9457SAndroid Build Coastguard Worker$assert BATCH_TILE >= 8
8*4bdc9457SAndroid Build Coastguard Worker$SIMD_TILE = BATCH_TILE // 8
9*4bdc9457SAndroid Build Coastguard Worker#include <assert.h>
10*4bdc9457SAndroid Build Coastguard Worker#include <stddef.h>
11*4bdc9457SAndroid Build Coastguard Worker#include <stdint.h>
12*4bdc9457SAndroid Build Coastguard Worker
13*4bdc9457SAndroid Build Coastguard Worker#include <arm_neon.h>
14*4bdc9457SAndroid Build Coastguard Worker
15*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/math.h>
16*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/window.h>
17*4bdc9457SAndroid Build Coastguard Worker
18*4bdc9457SAndroid Build Coastguard Worker$SHIFT_VARIANT = "_shift%s" % SHIFT if SHIFT else ""
19*4bdc9457SAndroid Build Coastguard Worker
// NEON s16 "window" microkernel template.
//
// For each of `rows` rows, multiplies `batch_size` int16 elements of `input`
// lane-by-lane with the matching int16 `weights`, shifts the 32-bit product
// right by `shift` bits, narrows it to int16 with saturation and stores it to
// `output`. `input` and `output` keep advancing from one row to the next, while
// the weight pointer is reset to `weights` at the start of every row.
//
// Template parameters:
//   BATCH_TILE - elements processed per main-loop iteration (multiple of 8).
//   SHIFT      - non-zero: the shift is fixed at generation time and the `shift`
//                argument must equal it; zero: `shift` is a run-time value
//                below 32.
//
// NOTE(review): the 1-7 element tail loads a full 8-lane vector from both
// `input` and `weights`, i.e. it reads past the logical end of the data.
// Presumably callers guarantee that padding is readable - confirm.
20*4bdc9457SAndroid Build Coastguard Workervoid xnn_s16_window${SHIFT_VARIANT}_ukernel__neon_x${BATCH_TILE}(
21*4bdc9457SAndroid Build Coastguard Worker    size_t rows,
22*4bdc9457SAndroid Build Coastguard Worker    size_t batch_size,
23*4bdc9457SAndroid Build Coastguard Worker    const int16_t* input,
24*4bdc9457SAndroid Build Coastguard Worker    const int16_t* weights,
25*4bdc9457SAndroid Build Coastguard Worker    int16_t* output,
26*4bdc9457SAndroid Build Coastguard Worker    uint32_t shift)
27*4bdc9457SAndroid Build Coastguard Worker{
28*4bdc9457SAndroid Build Coastguard Worker  assert(rows != 0);
29*4bdc9457SAndroid Build Coastguard Worker  assert(batch_size != 0);
30*4bdc9457SAndroid Build Coastguard Worker  assert(input != NULL);
31*4bdc9457SAndroid Build Coastguard Worker  assert(weights != NULL);
32*4bdc9457SAndroid Build Coastguard Worker  assert(output != NULL);
33*4bdc9457SAndroid Build Coastguard Worker  $if SHIFT != 0:
34*4bdc9457SAndroid Build Coastguard Worker    assert(shift == ${SHIFT});
35*4bdc9457SAndroid Build Coastguard Worker  $else:
36*4bdc9457SAndroid Build Coastguard Worker    assert(shift < 32);
37*4bdc9457SAndroid Build Coastguard Worker
  // Generic variant only: broadcast the run-time shift amount to all lanes.
38*4bdc9457SAndroid Build Coastguard Worker  $if SHIFT == 0:
39*4bdc9457SAndroid Build Coastguard Worker    const int32x4_t vshift = vdupq_n_s32(-(int32_t)shift);  // negative to shift right.
40*4bdc9457SAndroid Build Coastguard Worker
41*4bdc9457SAndroid Build Coastguard Worker  do {
    // Every row starts again at the first weight; n counts the bytes of the
    // row that are still to be processed.
42*4bdc9457SAndroid Build Coastguard Worker    const int16_t* w = weights;
43*4bdc9457SAndroid Build Coastguard Worker    size_t n = batch_size * sizeof(int16_t);
    // Main loop, emitted only when BATCH_TILE > 8: handles BATCH_TILE elements
    // (SIMD_TILE vectors of 8 lanes each) per iteration.
44*4bdc9457SAndroid Build Coastguard Worker    $if BATCH_TILE > 8:
45*4bdc9457SAndroid Build Coastguard Worker      for (; n >= ${BATCH_TILE} * sizeof(int16_t); n -= ${BATCH_TILE} * sizeof(int16_t)) {
46*4bdc9457SAndroid Build Coastguard Worker        $for N in range(SIMD_TILE):
47*4bdc9457SAndroid Build Coastguard Worker          const int16x8_t vi${N} = vld1q_s16(input); input += 8;
48*4bdc9457SAndroid Build Coastguard Worker
49*4bdc9457SAndroid Build Coastguard Worker        $for N in range(SIMD_TILE):
50*4bdc9457SAndroid Build Coastguard Worker          const int16x8_t vw${N} = vld1q_s16(w); w += 8;
51*4bdc9457SAndroid Build Coastguard Worker
        // SHIFT == 15 uses one saturating doubling multiply-high per vector.
        // All other variants widen to 32-bit products, shift right (by the
        // immediate SHIFT or by the run-time vshift) and narrow with saturation.
52*4bdc9457SAndroid Build Coastguard Worker        $if SHIFT == 15:
53*4bdc9457SAndroid Build Coastguard Worker          $for N in range(SIMD_TILE):
54*4bdc9457SAndroid Build Coastguard Worker            const int16x8_t vout${N} = vqdmulhq_s16(vi${N}, vw${N});
55*4bdc9457SAndroid Build Coastguard Worker        $else:
56*4bdc9457SAndroid Build Coastguard Worker          $for N in range(SIMD_TILE):
57*4bdc9457SAndroid Build Coastguard Worker            int32x4_t vacc${N}_lo = vmull_s16(vget_low_s16(vi${N}), vget_low_s16(vw${N}));
58*4bdc9457SAndroid Build Coastguard Worker            int32x4_t vacc${N}_hi = vmull_s16(vget_high_s16(vi${N}), vget_high_s16(vw${N}));
59*4bdc9457SAndroid Build Coastguard Worker
60*4bdc9457SAndroid Build Coastguard Worker          $if SHIFT != 0:
61*4bdc9457SAndroid Build Coastguard Worker            $for N in range(SIMD_TILE):
62*4bdc9457SAndroid Build Coastguard Worker              const int16x4_t vshift${N}_lo = vqshrn_n_s32(vacc${N}_lo, ${SHIFT});
63*4bdc9457SAndroid Build Coastguard Worker              const int16x4_t vshift${N}_hi = vqshrn_n_s32(vacc${N}_hi, ${SHIFT});
64*4bdc9457SAndroid Build Coastguard Worker
65*4bdc9457SAndroid Build Coastguard Worker            $for N in range(SIMD_TILE):
66*4bdc9457SAndroid Build Coastguard Worker              const int16x8_t vout${N} = vcombine_s16(vshift${N}_lo, vshift${N}_hi);
67*4bdc9457SAndroid Build Coastguard Worker          $else:
68*4bdc9457SAndroid Build Coastguard Worker            $for N in range(SIMD_TILE):
69*4bdc9457SAndroid Build Coastguard Worker              vacc${N}_lo = vshlq_s32(vacc${N}_lo, vshift);
70*4bdc9457SAndroid Build Coastguard Worker              vacc${N}_hi = vshlq_s32(vacc${N}_hi, vshift);
71*4bdc9457SAndroid Build Coastguard Worker
72*4bdc9457SAndroid Build Coastguard Worker            $for N in range(SIMD_TILE):
73*4bdc9457SAndroid Build Coastguard Worker              const int16x8_t vout${N} = vcombine_s16(vqmovn_s32(vacc${N}_lo), vqmovn_s32(vacc${N}_hi));
74*4bdc9457SAndroid Build Coastguard Worker
75*4bdc9457SAndroid Build Coastguard Worker        $for N in range(SIMD_TILE):
76*4bdc9457SAndroid Build Coastguard Worker          vst1q_s16(output, vout${N}); output += 8;
77*4bdc9457SAndroid Build Coastguard Worker      }
78*4bdc9457SAndroid Build Coastguard Worker
79*4bdc9457SAndroid Build Coastguard Worker    // Remainder of full vectors
80*4bdc9457SAndroid Build Coastguard Worker    for (; n >= 8 * sizeof(int16_t); n -= 8 * sizeof(int16_t)) {
81*4bdc9457SAndroid Build Coastguard Worker      const int16x8_t vi = vld1q_s16(input); input += 8;
82*4bdc9457SAndroid Build Coastguard Worker      const int16x8_t vw = vld1q_s16(w); w += 8;
83*4bdc9457SAndroid Build Coastguard Worker      $if SHIFT == 15:
84*4bdc9457SAndroid Build Coastguard Worker        const int16x8_t vout = vqdmulhq_s16(vi, vw);
85*4bdc9457SAndroid Build Coastguard Worker      $else:
86*4bdc9457SAndroid Build Coastguard Worker        int32x4_t vacc_lo = vmull_s16(vget_low_s16(vi), vget_low_s16(vw));
87*4bdc9457SAndroid Build Coastguard Worker        int32x4_t vacc_hi = vmull_s16(vget_high_s16(vi), vget_high_s16(vw));
88*4bdc9457SAndroid Build Coastguard Worker        $if SHIFT != 0:
89*4bdc9457SAndroid Build Coastguard Worker          const int16x4_t vshift_lo = vqshrn_n_s32(vacc_lo, ${SHIFT});
90*4bdc9457SAndroid Build Coastguard Worker          const int16x4_t vshift_hi = vqshrn_n_s32(vacc_hi, ${SHIFT});
91*4bdc9457SAndroid Build Coastguard Worker          const int16x8_t vout = vcombine_s16(vshift_lo, vshift_hi);
92*4bdc9457SAndroid Build Coastguard Worker        $else:
93*4bdc9457SAndroid Build Coastguard Worker          vacc_lo = vshlq_s32(vacc_lo, vshift);
94*4bdc9457SAndroid Build Coastguard Worker          vacc_hi = vshlq_s32(vacc_hi, vshift);
95*4bdc9457SAndroid Build Coastguard Worker          const int16x8_t vout = vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi));
96*4bdc9457SAndroid Build Coastguard Worker      vst1q_s16(output, vout); output += 8;
97*4bdc9457SAndroid Build Coastguard Worker    }
98*4bdc9457SAndroid Build Coastguard Worker
99*4bdc9457SAndroid Build Coastguard Worker    assert(n % 2 == 0);
100*4bdc9457SAndroid Build Coastguard Worker    // Remainder of 1 to 7 batch_size
    // A full 8-lane vector is loaded from input and from w, but only the first
    // n / sizeof(int16_t) results are stored. input advances by exactly n bytes.
101*4bdc9457SAndroid Build Coastguard Worker    if XNN_UNLIKELY(n != 0) {
102*4bdc9457SAndroid Build Coastguard Worker      const int16x8_t vi = vld1q_s16(input); input = (const int16_t*) ((uintptr_t) input + n);
103*4bdc9457SAndroid Build Coastguard Worker      const int16x8_t vw = vld1q_s16(w);
104*4bdc9457SAndroid Build Coastguard Worker      $if SHIFT == 15:
105*4bdc9457SAndroid Build Coastguard Worker        int16x4_t vout = vqdmulh_s16(vget_low_s16(vi), vget_low_s16(vw));
106*4bdc9457SAndroid Build Coastguard Worker      $else:
107*4bdc9457SAndroid Build Coastguard Worker        int32x4_t vacc = vmull_s16(vget_low_s16(vi), vget_low_s16(vw));
108*4bdc9457SAndroid Build Coastguard Worker        $if SHIFT != 0:
109*4bdc9457SAndroid Build Coastguard Worker          int16x4_t vout = vqshrn_n_s32(vacc, ${SHIFT});
110*4bdc9457SAndroid Build Coastguard Worker        $else:
111*4bdc9457SAndroid Build Coastguard Worker          vacc = vshlq_s32(vacc, vshift);
112*4bdc9457SAndroid Build Coastguard Worker          int16x4_t vout = vqmovn_s32(vacc);
      // vout holds the results for lanes 0-3 at this point. If at least four
      // elements remain, store them and recompute vout from the high halves.
113*4bdc9457SAndroid Build Coastguard Worker      if (n & (4 * sizeof(int16_t))) {
114*4bdc9457SAndroid Build Coastguard Worker        vst1_s16(output, vout); output += 4;
115*4bdc9457SAndroid Build Coastguard Worker        $if SHIFT == 15:
116*4bdc9457SAndroid Build Coastguard Worker          vout = vqdmulh_s16(vget_high_s16(vi), vget_high_s16(vw));
117*4bdc9457SAndroid Build Coastguard Worker        $else:
118*4bdc9457SAndroid Build Coastguard Worker          vacc = vmull_s16(vget_high_s16(vi), vget_high_s16(vw));
119*4bdc9457SAndroid Build Coastguard Worker          $if SHIFT != 0:
120*4bdc9457SAndroid Build Coastguard Worker            vout = vqshrn_n_s32(vacc, ${SHIFT});
121*4bdc9457SAndroid Build Coastguard Worker          $else:
122*4bdc9457SAndroid Build Coastguard Worker            vacc = vshlq_s32(vacc, vshift);
123*4bdc9457SAndroid Build Coastguard Worker            vout = vqmovn_s32(vacc);
124*4bdc9457SAndroid Build Coastguard Worker      }
      // Store two lanes as a single 32-bit word, then rotate vout so that the
      // next pending pair sits in lanes 0-1.
125*4bdc9457SAndroid Build Coastguard Worker      if (n & (2 * sizeof(int16_t))) {
126*4bdc9457SAndroid Build Coastguard Worker        vst1_lane_u32((void*) output, vreinterpret_u32_s16(vout), 0); output += 2;
127*4bdc9457SAndroid Build Coastguard Worker        vout = vext_s16(vout, vout, 2);
128*4bdc9457SAndroid Build Coastguard Worker      }
      // Store the last odd element, if any, from lane 0.
129*4bdc9457SAndroid Build Coastguard Worker      if (n & (1 * sizeof(int16_t))) {
130*4bdc9457SAndroid Build Coastguard Worker        vst1_lane_s16(output, vout, 0); output += 1;
131*4bdc9457SAndroid Build Coastguard Worker      }
132*4bdc9457SAndroid Build Coastguard Worker    }
133*4bdc9457SAndroid Build Coastguard Worker
134*4bdc9457SAndroid Build Coastguard Worker  } while (--rows != 0);
135*4bdc9457SAndroid Build Coastguard Worker}
136