// Auto-generated file. Do not edit!
//   Template: src/f16-gemm/neonfp16arith-ld64.c.in
//   Generator: tools/xngen
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.


#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/common.h>

#include <xnnpack/gemm.h>

void xnn_f16_gemminc_minmax_ukernel_4x16__neonfp16arith_ld64(
    size_t mr,
    size_t nc,
    size_t kc,
    const void* restrict a,
    size_t a_stride,
    const void* restrict w,
    void* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const void* restrict acc,
    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mr != 0);
  assert(mr <= 4);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(__fp16) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);
  assert(acc != NULL);

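  // Set up one input (A) and one output (C) pointer per row of the 4-row
  // tile. When fewer than 4 rows remain (mr < 4), the pointers of the unused
  // rows alias the last valid row, so the redundant computation and stores
  // harmlessly overwrite that row's results.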
  const __fp16* a0 = (const __fp16*) a;
  __fp16* c0 = (__fp16*) c;
  const __fp16* a1 = (const __fp16*) ((uintptr_t) a0 + a_stride);
  __fp16* c1 = (__fp16*) ((uintptr_t) c0 + cm_stride);
  if XNN_UNPREDICTABLE(mr < 2) {
    a1 = a0;
    c1 = c0;
  }
  const __fp16* a2 = (const __fp16*) ((uintptr_t) a1 + a_stride);
  __fp16* c2 = (__fp16*) ((uintptr_t) c1 + cm_stride);
  if XNN_UNPREDICTABLE(mr <= 2) {
    a2 = a1;
    c2 = c1;
  }
  const __fp16* a3 = (const __fp16*) ((uintptr_t) a2 + a_stride);
  __fp16* c3 = (__fp16*) ((uintptr_t) c2 + cm_stride);
  if XNN_UNPREDICTABLE(mr != 4) {
    a3 = a2;
    c3 = c2;
  }

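  // Outer loop: produce up to 16 output columns per iteration until all nc
  // columns are written. This "gemminc" variant initializes the accumulators
  // from a caller-provided buffer of partial sums (acc) rather than from the
  // packed bias, so K can be split across multiple kernel invocations.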
  do {
    float16x8_t vacc0x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t));
    float16x8_t vacc0x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t));
    float16x8_t vacc1x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t));
    float16x8_t vacc1x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t));
    float16x8_t vacc2x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t));
    float16x8_t vacc2x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t));
    float16x8_t vacc3x01234567 = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t));
    float16x8_t vacc3x89ABCDEF = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t));

    size_t k = kc;
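    // Main loop: consume 4 half-precision elements of K per iteration. The
    // "ld64" suffix refers to the 64-bit (4 x fp16) loads of each A row below.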
    while (k >= 4 * sizeof(__fp16)) {
      const float16x4_t va0 = vld1_f16(a0); a0 += 4;
      const float16x4_t va1 = vld1_f16(a1); a1 += 4;
      const float16x4_t va2 = vld1_f16(a2); a2 += 4;
      const float16x4_t va3 = vld1_f16(a3); a3 += 4;

      const float16x8_t vb01234567c0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
      const float16x8_t vb89ABCDEFc0 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));

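      // On AArch64, vfmaq_lane_f16 multiplies by a single lane of the A
      // vector directly. AArch32 lacks this form, so the fallback duplicates
      // the lane with vdupq_lane_f16 and uses a plain vfmaq_f16. The same
      // pattern repeats below for lanes 1, 2, and 3 of the loaded A values.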
      #if XNN_ARCH_ARM64
        vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c0, va0, 0);
        vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c0, va1, 0);
        vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c0, va2, 0);
        vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c0, va3, 0);
        vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc0, va0, 0);
        vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc0, va1, 0);
        vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc0, va2, 0);
        vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc0, va3, 0);
      #else
        const float16x8_t va0c0 = vdupq_lane_f16(va0, 0);
        const float16x8_t va1c0 = vdupq_lane_f16(va1, 0);
        const float16x8_t va2c0 = vdupq_lane_f16(va2, 0);
        const float16x8_t va3c0 = vdupq_lane_f16(va3, 0);

        vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c0, vb01234567c0);
        vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c0, vb01234567c0);
        vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c0, vb01234567c0);
        vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c0, vb01234567c0);
        vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c0, vb89ABCDEFc0);
        vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c0, vb89ABCDEFc0);
        vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c0, vb89ABCDEFc0);
        vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c0, vb89ABCDEFc0);
      #endif
      const float16x8_t vb01234567c1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
      const float16x8_t vb89ABCDEFc1 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));

      #if XNN_ARCH_ARM64
        vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c1, va0, 1);
        vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c1, va1, 1);
        vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c1, va2, 1);
        vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c1, va3, 1);
        vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc1, va0, 1);
        vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc1, va1, 1);
        vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc1, va2, 1);
        vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc1, va3, 1);
      #else
        const float16x8_t va0c1 = vdupq_lane_f16(va0, 1);
        const float16x8_t va1c1 = vdupq_lane_f16(va1, 1);
        const float16x8_t va2c1 = vdupq_lane_f16(va2, 1);
        const float16x8_t va3c1 = vdupq_lane_f16(va3, 1);

        vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c1, vb01234567c1);
        vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c1, vb01234567c1);
        vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c1, vb01234567c1);
        vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c1, vb01234567c1);
        vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c1, vb89ABCDEFc1);
        vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c1, vb89ABCDEFc1);
        vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c1, vb89ABCDEFc1);
        vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c1, vb89ABCDEFc1);
      #endif
      const float16x8_t vb01234567c2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
      const float16x8_t vb89ABCDEFc2 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));

      #if XNN_ARCH_ARM64
        vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c2, va0, 2);
        vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c2, va1, 2);
        vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c2, va2, 2);
        vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c2, va3, 2);
        vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc2, va0, 2);
        vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc2, va1, 2);
        vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc2, va2, 2);
        vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc2, va3, 2);
      #else
        const float16x8_t va0c2 = vdupq_lane_f16(va0, 2);
        const float16x8_t va1c2 = vdupq_lane_f16(va1, 2);
        const float16x8_t va2c2 = vdupq_lane_f16(va2, 2);
        const float16x8_t va3c2 = vdupq_lane_f16(va3, 2);

        vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c2, vb01234567c2);
        vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c2, vb01234567c2);
        vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c2, vb01234567c2);
        vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c2, vb01234567c2);
        vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c2, vb89ABCDEFc2);
        vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c2, vb89ABCDEFc2);
        vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c2, vb89ABCDEFc2);
        vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c2, vb89ABCDEFc2);
      #endif
      const float16x8_t vb01234567c3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
      const float16x8_t vb89ABCDEFc3 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));

      #if XNN_ARCH_ARM64
        vacc0x01234567 = vfmaq_lane_f16(vacc0x01234567, vb01234567c3, va0, 3);
        vacc1x01234567 = vfmaq_lane_f16(vacc1x01234567, vb01234567c3, va1, 3);
        vacc2x01234567 = vfmaq_lane_f16(vacc2x01234567, vb01234567c3, va2, 3);
        vacc3x01234567 = vfmaq_lane_f16(vacc3x01234567, vb01234567c3, va3, 3);
        vacc0x89ABCDEF = vfmaq_lane_f16(vacc0x89ABCDEF, vb89ABCDEFc3, va0, 3);
        vacc1x89ABCDEF = vfmaq_lane_f16(vacc1x89ABCDEF, vb89ABCDEFc3, va1, 3);
        vacc2x89ABCDEF = vfmaq_lane_f16(vacc2x89ABCDEF, vb89ABCDEFc3, va2, 3);
        vacc3x89ABCDEF = vfmaq_lane_f16(vacc3x89ABCDEF, vb89ABCDEFc3, va3, 3);
      #else
        const float16x8_t va0c3 = vdupq_lane_f16(va0, 3);
        const float16x8_t va1c3 = vdupq_lane_f16(va1, 3);
        const float16x8_t va2c3 = vdupq_lane_f16(va2, 3);
        const float16x8_t va3c3 = vdupq_lane_f16(va3, 3);

        vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0c3, vb01234567c3);
        vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1c3, vb01234567c3);
        vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2c3, vb01234567c3);
        vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3c3, vb01234567c3);
        vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0c3, vb89ABCDEFc3);
        vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1c3, vb89ABCDEFc3);
        vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2c3, vb89ABCDEFc3);
        vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3c3, vb89ABCDEFc3);
      #endif

      k -= 4 * sizeof(__fp16);
    }
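    // Remainder loop: handle the last 1-3 elements of K one at a time,
    // broadcasting each A element across all 8 lanes.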
    if XNN_UNLIKELY(k != 0) {
      do {
        const float16x8_t va0 = vld1q_dup_f16(a0); a0 += 1;
        const float16x8_t va1 = vld1q_dup_f16(a1); a1 += 1;
        const float16x8_t va2 = vld1q_dup_f16(a2); a2 += 1;
        const float16x8_t va3 = vld1q_dup_f16(a3); a3 += 1;

        const float16x8_t vb01234567 = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));
        const float16x8_t vb89ABCDEF = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t));

        vacc0x01234567 = vfmaq_f16(vacc0x01234567, va0, vb01234567);
        vacc1x01234567 = vfmaq_f16(vacc1x01234567, va1, vb01234567);
        vacc2x01234567 = vfmaq_f16(vacc2x01234567, va2, vb01234567);
        vacc3x01234567 = vfmaq_f16(vacc3x01234567, va3, vb01234567);
        vacc0x89ABCDEF = vfmaq_f16(vacc0x89ABCDEF, va0, vb89ABCDEF);
        vacc1x89ABCDEF = vfmaq_f16(vacc1x89ABCDEF, va1, vb89ABCDEF);
        vacc2x89ABCDEF = vfmaq_f16(vacc2x89ABCDEF, va2, vb89ABCDEF);
        vacc3x89ABCDEF = vfmaq_f16(vacc3x89ABCDEF, va3, vb89ABCDEF);

        k -= sizeof(__fp16);
      } while (k != 0);
    }


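    // Clamp the accumulators to the [min, max] range from params before
    // storing, implementing the fused minmax activation.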
    const float16x8_t vmax = vreinterpretq_f16_u16(vld1q_dup_u16(&params->neon.max));
    vacc0x01234567 = vminq_f16(vacc0x01234567, vmax);
    vacc1x01234567 = vminq_f16(vacc1x01234567, vmax);
    vacc2x01234567 = vminq_f16(vacc2x01234567, vmax);
    vacc3x01234567 = vminq_f16(vacc3x01234567, vmax);
    vacc0x89ABCDEF = vminq_f16(vacc0x89ABCDEF, vmax);
    vacc1x89ABCDEF = vminq_f16(vacc1x89ABCDEF, vmax);
    vacc2x89ABCDEF = vminq_f16(vacc2x89ABCDEF, vmax);
    vacc3x89ABCDEF = vminq_f16(vacc3x89ABCDEF, vmax);

    const float16x8_t vmin = vreinterpretq_f16_u16(vld1q_dup_u16(&params->neon.min));
    vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin);
    vacc1x01234567 = vmaxq_f16(vacc1x01234567, vmin);
    vacc2x01234567 = vmaxq_f16(vacc2x01234567, vmin);
    vacc3x01234567 = vmaxq_f16(vacc3x01234567, vmin);
    vacc0x89ABCDEF = vmaxq_f16(vacc0x89ABCDEF, vmin);
    vacc1x89ABCDEF = vmaxq_f16(vacc1x89ABCDEF, vmin);
    vacc2x89ABCDEF = vmaxq_f16(vacc2x89ABCDEF, vmin);
    vacc3x89ABCDEF = vmaxq_f16(vacc3x89ABCDEF, vmin);

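    // Full-tile store: write all 16 columns for each row, advance the C
    // pointers to the next column block, and rewind the A pointers by kc
    // bytes so the same rows can be reused for the next block of columns.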
    if XNN_LIKELY(nc >= 16) {
      vst1q_f16(c0, vacc0x01234567);
      vst1q_f16(c0 + 8, vacc0x89ABCDEF);
      c0 = (__fp16*) ((uintptr_t) c0 + cn_stride);
      vst1q_f16(c1, vacc1x01234567);
      vst1q_f16(c1 + 8, vacc1x89ABCDEF);
      c1 = (__fp16*) ((uintptr_t) c1 + cn_stride);
      vst1q_f16(c2, vacc2x01234567);
      vst1q_f16(c2 + 8, vacc2x89ABCDEF);
      c2 = (__fp16*) ((uintptr_t) c2 + cn_stride);
      vst1q_f16(c3, vacc3x01234567);
      vst1q_f16(c3 + 8, vacc3x89ABCDEF);
      c3 = (__fp16*) ((uintptr_t) c3 + cn_stride);

      a0 = (const __fp16*) ((uintptr_t) a0 - kc);
      a1 = (const __fp16*) ((uintptr_t) a1 - kc);
      a2 = (const __fp16*) ((uintptr_t) a2 - kc);
      a3 = (const __fp16*) ((uintptr_t) a3 - kc);

      nc -= 16;
    } else {
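      // Partial tile: store the remaining 1-15 columns in chunks of 8, 4, 2,
      // and 1, shifting the surviving lanes down after each chunk.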
      if (nc & 8) {
        vst1q_f16(c0, vacc0x01234567); c0 += 8;
        vst1q_f16(c1, vacc1x01234567); c1 += 8;
        vst1q_f16(c2, vacc2x01234567); c2 += 8;
        vst1q_f16(c3, vacc3x01234567); c3 += 8;

        vacc0x01234567 = vacc0x89ABCDEF;
        vacc1x01234567 = vacc1x89ABCDEF;
        vacc2x01234567 = vacc2x89ABCDEF;
        vacc3x01234567 = vacc3x89ABCDEF;
      }
      float16x4_t vacc0x0123 = vget_low_f16(vacc0x01234567);
      float16x4_t vacc1x0123 = vget_low_f16(vacc1x01234567);
      float16x4_t vacc2x0123 = vget_low_f16(vacc2x01234567);
      float16x4_t vacc3x0123 = vget_low_f16(vacc3x01234567);
      if (nc & 4) {
        vst1_f16(c0, vacc0x0123); c0 += 4;
        vst1_f16(c1, vacc1x0123); c1 += 4;
        vst1_f16(c2, vacc2x0123); c2 += 4;
        vst1_f16(c3, vacc3x0123); c3 += 4;

        vacc0x0123 = vget_high_f16(vacc0x01234567);
        vacc1x0123 = vget_high_f16(vacc1x01234567);
        vacc2x0123 = vget_high_f16(vacc2x01234567);
        vacc3x0123 = vget_high_f16(vacc3x01234567);
      }
      if (nc & 2) {
        vst1_lane_u32((void*) c0, vreinterpret_u32_f16(vacc0x0123), 0); c0 += 2;
        vst1_lane_u32((void*) c1, vreinterpret_u32_f16(vacc1x0123), 0); c1 += 2;
        vst1_lane_u32((void*) c2, vreinterpret_u32_f16(vacc2x0123), 0); c2 += 2;
        vst1_lane_u32((void*) c3, vreinterpret_u32_f16(vacc3x0123), 0); c3 += 2;

        vacc0x0123 = vext_f16(vacc0x0123, vacc0x0123, 2);
        vacc1x0123 = vext_f16(vacc1x0123, vacc1x0123, 2);
        vacc2x0123 = vext_f16(vacc2x0123, vacc2x0123, 2);
        vacc3x0123 = vext_f16(vacc3x0123, vacc3x0123, 2);
      }
      if (nc & 1) {
        vst1_lane_f16(c0, vacc0x0123, 0);
        vst1_lane_f16(c1, vacc1x0123, 0);
        vst1_lane_f16(c2, vacc2x0123, 0);
        vst1_lane_f16(c3, vacc3x0123, 0);
      }

      nc = 0;
    }
  } while (nc != 0);
}