// Auto-generated file. Do not edit!
// Template: src/s8-ibilinear/neon.c.in
// Generator: tools/xngen
//
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/common.h>
#include <xnnpack/ibilinear.h>


xnn_s8_ibilinear_ukernel__neon_c16(size_t output_pixels,size_t channels,const int8_t ** restrict input,size_t input_offset,const int16_t * restrict weights,int8_t * restrict output,size_t output_increment)18 void xnn_s8_ibilinear_ukernel__neon_c16(
19 size_t output_pixels,
20 size_t channels,
21 const int8_t**restrict input,
22 size_t input_offset,
23 const int16_t*restrict weights,
24 int8_t*restrict output,
25 size_t output_increment) XNN_OOB_READS
26 {
27 assert(output_pixels != 0);
28 assert(channels != 0);
29
30 do {
31 const int8_t* i0 = (const int8_t*) ((uintptr_t) input[0] + input_offset);
32 const int8_t* i1 = (const int8_t*) ((uintptr_t) input[1] + input_offset);
33 const int8_t* i2 = (const int8_t*) ((uintptr_t) input[2] + input_offset);
34 const int8_t* i3 = (const int8_t*) ((uintptr_t) input[3] + input_offset);
35 input += 4;
36
37 #if XNN_ARCH_ARM64
38 const int16x8_t valphah = vld1q_dup_s16(weights); weights += 1;
39 #else
40 const int16x4_t valphah = vld1_dup_s16(weights); weights += 1;
41 #endif
42 const int32x4_t valphav = vmovl_s16(vld1_dup_s16(weights)); weights += 1;
43
44 size_t c = channels;
45 for (; c >= 16 * sizeof(int8_t); c -= 16 * sizeof(int8_t)) {
46 const int8x8_t vtl01234567 = vld1_s8(i0); i0 += 8;
47 const int8x8_t vtr01234567 = vld1_s8(i1); i1 += 8;
48 const int8x8_t vbl01234567 = vld1_s8(i2); i2 += 8;
49 const int8x8_t vbr01234567 = vld1_s8(i3); i3 += 8;
50 const int8x8_t vtl89ABCDEF = vld1_s8(i0); i0 += 8;
51 const int8x8_t vtr89ABCDEF = vld1_s8(i1); i1 += 8;
52 const int8x8_t vbl89ABCDEF = vld1_s8(i2); i2 += 8;
53 const int8x8_t vbr89ABCDEF = vld1_s8(i3); i3 += 8;
54
55 const int16x8_t vtd01234567 = vsubl_s8(vtr01234567, vtl01234567);
56 const int16x8_t vbd01234567 = vsubl_s8(vbr01234567, vbl01234567);
57 const int16x8_t vdl01234567 = vsubl_s8(vbl01234567, vtl01234567);
58 const int16x8_t vxtl01234567 = vmovl_s8(vtl01234567);
59 const int16x8_t vtd89ABCDEF = vsubl_s8(vtr89ABCDEF, vtl89ABCDEF);
60 const int16x8_t vbd89ABCDEF = vsubl_s8(vbr89ABCDEF, vbl89ABCDEF);
61 const int16x8_t vdl89ABCDEF = vsubl_s8(vbl89ABCDEF, vtl89ABCDEF);
62 const int16x8_t vxtl89ABCDEF = vmovl_s8(vtl89ABCDEF);
63
64 const int16x8_t vdd01234567 = vsubq_s16(vbd01234567, vtd01234567);
65 const int16x8_t vdd89ABCDEF = vsubq_s16(vbd89ABCDEF, vtd89ABCDEF);
66
67 #if XNN_ARCH_ARM64
68 const int32x4_t vt0123 = vmlal_s16(vshll_n_s16(vget_low_s16(vxtl01234567), 11), vget_low_s16(vtd01234567), vget_low_s16(valphah));
69 const int32x4_t vt4567 = vmlal_high_s16(vshll_n_s16(vget_high_s16(vxtl01234567), 11), vtd01234567, valphah);
70 const int32x4_t vt89AB = vmlal_s16(vshll_n_s16(vget_low_s16(vxtl89ABCDEF), 11), vget_low_s16(vtd89ABCDEF), vget_low_s16(valphah));
71 const int32x4_t vtCDEF = vmlal_high_s16(vshll_n_s16(vget_high_s16(vxtl89ABCDEF), 11), vtd89ABCDEF, valphah);
72
73 const int32x4_t vd0123 = vmlal_s16(vshll_n_s16(vget_low_s16(vdl01234567), 11), vget_low_s16(vdd01234567), vget_low_s16(valphah));
74 const int32x4_t vd4567 = vmlal_high_s16(vshll_n_s16(vget_high_s16(vdl01234567), 11), vdd01234567, valphah);
75 const int32x4_t vd89AB = vmlal_s16(vshll_n_s16(vget_low_s16(vdl89ABCDEF), 11), vget_low_s16(vdd89ABCDEF), vget_low_s16(valphah));
76 const int32x4_t vdCDEF = vmlal_high_s16(vshll_n_s16(vget_high_s16(vdl89ABCDEF), 11), vdd89ABCDEF, valphah);
77 #else // !XNN_ARCH_ARM64
78 const int32x4_t vt0123 = vmlal_s16(vshll_n_s16(vget_low_s16(vxtl01234567), 11), vget_low_s16(vtd01234567), valphah);
79 const int32x4_t vt4567 = vmlal_s16(vshll_n_s16(vget_high_s16(vxtl01234567), 11), vget_high_s16(vtd01234567), valphah);
80 const int32x4_t vt89AB = vmlal_s16(vshll_n_s16(vget_low_s16(vxtl89ABCDEF), 11), vget_low_s16(vtd89ABCDEF), valphah);
81 const int32x4_t vtCDEF = vmlal_s16(vshll_n_s16(vget_high_s16(vxtl89ABCDEF), 11), vget_high_s16(vtd89ABCDEF), valphah);
82
83 const int32x4_t vd0123 = vmlal_s16(vshll_n_s16(vget_low_s16(vdl01234567), 11), vget_low_s16(vdd01234567), valphah);
84 const int32x4_t vd4567 = vmlal_s16(vshll_n_s16(vget_high_s16(vdl01234567), 11), vget_high_s16(vdd01234567), valphah);
85 const int32x4_t vd89AB = vmlal_s16(vshll_n_s16(vget_low_s16(vdl89ABCDEF), 11), vget_low_s16(vdd89ABCDEF), valphah);
86 const int32x4_t vdCDEF = vmlal_s16(vshll_n_s16(vget_high_s16(vdl89ABCDEF), 11), vget_high_s16(vdd89ABCDEF), valphah);
87 #endif // !XNN_ARCH_ARM64
88
89 const int32x4_t vacc0123 = vmlaq_s32(vshlq_n_s32(vt0123, 11), vd0123, valphav);
90 const int32x4_t vacc4567 = vmlaq_s32(vshlq_n_s32(vt4567, 11), vd4567, valphav);
91 const int32x4_t vacc89AB = vmlaq_s32(vshlq_n_s32(vt89AB, 11), vd89AB, valphav);
92 const int32x4_t vaccCDEF = vmlaq_s32(vshlq_n_s32(vtCDEF, 11), vdCDEF, valphav);
93
94 #if XNN_ARCH_ARM64
95 const int16x8_t vacc01234567 = vuzp2q_s16(vreinterpretq_s16_s32(vacc0123), vreinterpretq_s16_s32(vacc4567));
96 const int16x8_t vacc89ABCDEF = vuzp2q_s16(vreinterpretq_s16_s32(vacc89AB), vreinterpretq_s16_s32(vaccCDEF));
97 #else // !XNN_ARCH_ARM64
98 const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
99 const int16x8_t vacc89ABCDEF = vcombine_s16(vshrn_n_s32(vacc89AB, 16), vshrn_n_s32(vaccCDEF, 16));
100 #endif // !XNN_ARCH_ARM64
101
102 const int8x8_t vo01234567 = vrshrn_n_s16(vacc01234567, 6);
103 const int8x8_t vo89ABCDEF = vrshrn_n_s16(vacc89ABCDEF, 6);
104
105 vst1_s8(output, vo01234567); output += 8;
106 vst1_s8(output, vo89ABCDEF); output += 8;
107 }
108 for (; c >= 8 * sizeof(int8_t); c -= 8 * sizeof(int8_t)) {
109 const int8x8_t vtl01234567 = vld1_s8(i0); i0 += 8;
110 const int8x8_t vtr01234567 = vld1_s8(i1); i1 += 8;
111 const int8x8_t vbl01234567 = vld1_s8(i2); i2 += 8;
112 const int8x8_t vbr01234567 = vld1_s8(i3); i3 += 8;
113
114 const int16x8_t vtd01234567 = vsubl_s8(vtr01234567, vtl01234567);
115 const int16x8_t vbd01234567 = vsubl_s8(vbr01234567, vbl01234567);
116 const int16x8_t vdl01234567 = vsubl_s8(vbl01234567, vtl01234567);
117 const int16x8_t vxtl01234567 = vmovl_s8(vtl01234567);
118
119 const int16x8_t vdd01234567 = vsubq_s16(vbd01234567, vtd01234567);
120
121 #if XNN_ARCH_ARM64
122 const int32x4_t vt0123 = vmlal_s16(vshll_n_s16(vget_low_s16(vxtl01234567), 11), vget_low_s16(vtd01234567), vget_low_s16(valphah));
123 const int32x4_t vt4567 = vmlal_high_s16(vshll_n_s16(vget_high_s16(vxtl01234567), 11), vtd01234567, valphah);
124
125 const int32x4_t vd0123 = vmlal_s16(vshll_n_s16(vget_low_s16(vdl01234567), 11), vget_low_s16(vdd01234567), vget_low_s16(valphah));
126 const int32x4_t vd4567 = vmlal_high_s16(vshll_n_s16(vget_high_s16(vdl01234567), 11), vdd01234567, valphah);
127 #else // !XNN_ARCH_ARM64
128 const int32x4_t vt0123 = vmlal_s16(vshll_n_s16(vget_low_s16(vxtl01234567), 11), vget_low_s16(vtd01234567), valphah);
129 const int32x4_t vt4567 = vmlal_s16(vshll_n_s16(vget_high_s16(vxtl01234567), 11), vget_high_s16(vtd01234567), valphah);
130
131 const int32x4_t vd0123 = vmlal_s16(vshll_n_s16(vget_low_s16(vdl01234567), 11), vget_low_s16(vdd01234567), valphah);
132 const int32x4_t vd4567 = vmlal_s16(vshll_n_s16(vget_high_s16(vdl01234567), 11), vget_high_s16(vdd01234567), valphah);
133 #endif // !XNN_ARCH_ARM64
134
135 const int32x4_t vacc0123 = vmlaq_s32(vshlq_n_s32(vt0123, 11), vd0123, valphav);
136 const int32x4_t vacc4567 = vmlaq_s32(vshlq_n_s32(vt4567, 11), vd4567, valphav);
137
138 #if XNN_ARCH_ARM64
139 const int16x8_t vacc01234567 = vuzp2q_s16(vreinterpretq_s16_s32(vacc0123), vreinterpretq_s16_s32(vacc4567));
140 #else // !XNN_ARCH_ARM64
141 const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
142 #endif // !XNN_ARCH_ARM64
143
144 const int8x8_t vo01234567 = vrshrn_n_s16(vacc01234567, 6);
145
146 vst1_s8(output, vo01234567); output += 8;
147 }
148 if XNN_UNLIKELY(c != 0) {
149 const int8x8_t vtl01234567 = vld1_s8(i0);
150 const int8x8_t vtr01234567 = vld1_s8(i1);
151 const int8x8_t vbl01234567 = vld1_s8(i2);
152 const int8x8_t vbr01234567 = vld1_s8(i3);
153
154 const int16x8_t vtd01234567 = vsubl_s8(vtr01234567, vtl01234567);
155 const int16x8_t vbd01234567 = vsubl_s8(vbr01234567, vbl01234567);
156 const int16x8_t vdl01234567 = vsubl_s8(vbl01234567, vtl01234567);
157 const int16x8_t vxtl01234567 = vmovl_s8(vtl01234567);
158
159 const int16x8_t vdd01234567 = vsubq_s16(vbd01234567, vtd01234567);
160
161 #if XNN_ARCH_ARM64
162 const int32x4_t vt0123 = vmlal_s16(vshll_n_s16(vget_low_s16(vxtl01234567), 11), vget_low_s16(vtd01234567), vget_low_s16(valphah));
163 const int32x4_t vt4567 = vmlal_high_s16(vshll_n_s16(vget_high_s16(vxtl01234567), 11), vtd01234567, valphah);
164
165 const int32x4_t vd0123 = vmlal_s16(vshll_n_s16(vget_low_s16(vdl01234567), 11), vget_low_s16(vdd01234567), vget_low_s16(valphah));
166 const int32x4_t vd4567 = vmlal_high_s16(vshll_n_s16(vget_high_s16(vdl01234567), 11), vdd01234567, valphah);
167 #else // !XNN_ARCH_ARM64
168 const int32x4_t vt0123 = vmlal_s16(vshll_n_s16(vget_low_s16(vxtl01234567), 11), vget_low_s16(vtd01234567), valphah);
169 const int32x4_t vt4567 = vmlal_s16(vshll_n_s16(vget_high_s16(vxtl01234567), 11), vget_high_s16(vtd01234567), valphah);
170
171 const int32x4_t vd0123 = vmlal_s16(vshll_n_s16(vget_low_s16(vdl01234567), 11), vget_low_s16(vdd01234567), valphah);
172 const int32x4_t vd4567 = vmlal_s16(vshll_n_s16(vget_high_s16(vdl01234567), 11), vget_high_s16(vdd01234567), valphah);
173 #endif // !XNN_ARCH_ARM64
174
175 const int32x4_t vacc0123 = vmlaq_s32(vshlq_n_s32(vt0123, 11), vd0123, valphav);
176 const int32x4_t vacc4567 = vmlaq_s32(vshlq_n_s32(vt4567, 11), vd4567, valphav);
177
178 #if XNN_ARCH_ARM64
179 const int16x8_t vacc01234567 = vuzp2q_s16(vreinterpretq_s16_s32(vacc0123), vreinterpretq_s16_s32(vacc4567));
180 #else // !XNN_ARCH_ARM64
181 const int16x8_t vacc01234567 = vcombine_s16(vshrn_n_s32(vacc0123, 16), vshrn_n_s32(vacc4567, 16));
182 #endif // !XNN_ARCH_ARM64
183
184 int8x8_t vo01234567 = vrshrn_n_s16(vacc01234567, 6);
185
186 if (c & (4 * sizeof(int8_t))) {
187 vst1_lane_u32((void*) output, vreinterpret_u32_s8(vo01234567), 0); output += 4;
188 vo01234567 = vext_s8(vo01234567, vo01234567, 4);
189 }
190 if (c & (2 * sizeof(int8_t))) {
191 vst1_lane_u16((void*) output, vreinterpret_u16_s8(vo01234567), 0); output += 2;
192 vo01234567 = vext_s8(vo01234567, vo01234567, 2);
193 }
194 if (c & (1 * sizeof(int8_t))) {
195 vst1_lane_s8(output, vo01234567, 0); output += 1;
196 }
197 }
198
199 output = (int8_t*) ((uintptr_t) output + output_increment);
200 } while (--output_pixels != 0);
201 }