1 // Auto-generated file. Do not edit!
2 // Template: src/f16-spmm/neonfp16arith.c.in
3 // Generator: tools/xngen
4 //
5 // Copyright 2019 Google LLC
6 //
7 // This source code is licensed under the BSD-style license found in the
8 // LICENSE file in the root directory of this source tree.
9
10 #include <assert.h>
11
12 #include <arm_neon.h>
13
14 #include <xnnpack/spmm.h>
15
16
// SpMM (sparse-weights x dense-input) microkernel, fp16, 8 rows of input
// processed per pass, 1 output channel per inner iteration, with the
// non-zero loop unrolled by 2 ("x2").
//
//   mc            - number of input/output "pixels" to process, in BYTES
//                   (asserted to be a multiple of sizeof(__fp16)).
//   nc            - number of output channels; must be non-zero.
//   input/output  - dense activation tensors (fp16).
//   weights       - packed sparse weights: per output channel, one bias
//                   value followed by one value per non-zero weight.
//   widx_dmap     - byte-offset deltas applied to the input pointer after
//                   each non-zero weight is consumed (sparse index map).
//   nidx_nnzmap   - number of non-zero weights per output channel.
//   output_stride - byte stride between consecutive output channels.
//   params        - min/max clamping bounds (fp16 bit patterns in .neon).
void xnn_f16_spmm_minmax_ukernel_8x1__neonfp16arith_x2(
    size_t mc,
    size_t nc,
    const void*restrict input,
    const void*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    void*restrict output,
    size_t output_stride,
    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mc != 0);
  assert(mc % sizeof(__fp16) == 0);
  assert(nc != 0);

  const __fp16*restrict i = (const __fp16*) input;
  __fp16*restrict o = (__fp16*) output;

  // Clamping bounds are stored as raw uint16 bit patterns; load and
  // reinterpret them as fp16 lanes broadcast across the vector.
  const float16x8_t vmax = vreinterpretq_f16_u16(vld1q_dup_u16(&params->neon.max));
  const float16x8_t vmin = vreinterpretq_f16_u16(vld1q_dup_u16(&params->neon.min));

  // After walking all nc output channels, rewind o back to the start of the
  // channel-0 row and step forward to the next group of 8 pixels.
  size_t output_decrement = output_stride * nc - 8 * sizeof(__fp16);
  while XNN_LIKELY(mc >= 8 * sizeof(__fp16)) {
    const __fp16*restrict w = (const __fp16*) weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t n = nc;
    do {
      uint32_t nnz = *nnzmap++;
      // First packed weight value is the channel bias; seed accumulator 0
      // with it, accumulator 1 starts at zero (the two are summed later).
      float16x8_t vacc01234567x0 = vld1q_dup_f16(w); w += 1;
      float16x8_t vacc01234567x1 = vmovq_n_f16(0.0f);
      // Main loop unrolled by 2: two non-zero weights per iteration, each
      // feeding an independent accumulator to hide FMA latency.
      for (; nnz >= 2; nnz -= 2) {
        const intptr_t diff0 = dmap[0];
        const intptr_t diff1 = dmap[1];
        dmap += 2;
        const float16x8_t va01234567x0 = vld1q_f16(i);
        // dmap entries are byte deltas; advance the input pointer to the
        // row of the next non-zero weight.
        i = (const __fp16*restrict) ((uintptr_t) i + (uintptr_t) diff0);
        const float16x8_t vb0 = vld1q_dup_f16(w); w += 1;
        vacc01234567x0 = vfmaq_f16(vacc01234567x0, va01234567x0, vb0);
        const float16x8_t va01234567x1 = vld1q_f16(i);
        i = (const __fp16*restrict) ((uintptr_t) i + (uintptr_t) diff1);
        const float16x8_t vb1 = vld1q_dup_f16(w); w += 1;
        vacc01234567x1 = vfmaq_f16(vacc01234567x1, va01234567x1, vb1);
      }
      // Merge the two partial accumulators.
      float16x8_t vacc01234567 = vacc01234567x0;
      vacc01234567 = vaddq_f16(vacc01234567, vacc01234567x1);
      // Remainder loop: at most one leftover non-zero weight after the x2
      // unroll (plus all of them when nnz < 2 to begin with).
      if XNN_LIKELY(nnz != 0) {
        do {
          const intptr_t diff = *dmap++;
          const float16x8_t va01234567 = vld1q_f16(i);
          i = (const __fp16*restrict) ((uintptr_t) i + (uintptr_t) diff);
          const float16x8_t vb = vld1q_dup_f16(w); w += 1;
          vacc01234567 = vfmaq_f16(vacc01234567, va01234567, vb);
        } while (--nnz != 0);
      }
      // Clamp to [min, max] and store this channel's 8 outputs.
      float16x8_t vout01234567 = vminq_f16(vacc01234567, vmax);
      vout01234567 = vmaxq_f16(vout01234567, vmin);
      vst1q_f16(o, vout01234567);
      o = (__fp16*restrict) ((uintptr_t) o + output_stride);
    } while (--n != 0);
    o = (__fp16*restrict) ((uintptr_t) o - output_decrement);
    i += 8;
    mc -= 8 * sizeof(__fp16);
  }
  // Tail handling: process remaining pixels in progressively narrower
  // blocks of 4, 2, and 1 using the same sparse traversal.
  if XNN_UNLIKELY(mc != 0) {
    output_decrement += 4 * sizeof(__fp16);
    if (mc & (4 * sizeof(__fp16))) {
      const __fp16*restrict w = (const __fp16*) weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      do {
        uint32_t nnz = *nnzmap++;
        float16x4_t vacc0123 = vld1_dup_f16(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float16x4_t va0123 = vld1_f16(i);
            i = (const __fp16*restrict) ((uintptr_t) i + (uintptr_t) diff);
            const float16x4_t vb = vld1_dup_f16(w); w += 1;
            vacc0123 = vfma_f16(vacc0123, va0123, vb);
          } while (--nnz != 0);
        }
        float16x4_t vout0123 = vmin_f16(vacc0123, vget_low_f16(vmax));
        vout0123 = vmax_f16(vout0123, vget_low_f16(vmin));
        vst1_f16(o, vout0123);
        o = (__fp16*restrict) ((uintptr_t) o + output_stride);
      } while (--n != 0);
      o = (__fp16*restrict) ((uintptr_t) o - output_decrement);
      i += 4;
    }
    output_decrement += 2 * sizeof(__fp16);
    if (mc & (2 * sizeof(__fp16))) {
      const __fp16*restrict w = (const __fp16*) weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      do {
        uint32_t nnz = *nnzmap++;
        float16x4_t vacc01 = vld1_dup_f16(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            // Load 2 fp16 values as one 32-bit lane, then reinterpret.
            const float16x4_t va01 = vreinterpret_f16_f32(vld1_dup_f32((const void*) i));
            i = (const __fp16*restrict) ((uintptr_t) i + (uintptr_t) diff);
            const float16x4_t vb = vld1_dup_f16(w); w += 1;
            vacc01 = vfma_f16(vacc01, va01, vb);
          } while (--nnz != 0);
        }
        float16x4_t vout01 = vmin_f16(vacc01, vget_low_f16(vmax));
        vout01 = vmax_f16(vout01, vget_low_f16(vmin));
        // Store 2 fp16 values as a single 32-bit lane.
        vst1_lane_f32((void*) o, vreinterpret_f32_f16(vout01), 0);
        o = (__fp16*restrict) ((uintptr_t) o + output_stride);
      } while (--n != 0);
      o = (__fp16*restrict) ((uintptr_t) o - output_decrement);
      i += 2;
    }
    output_decrement += 1 * sizeof(__fp16);
    if (mc & (1 * sizeof(__fp16))) {
      const __fp16*restrict w = (const __fp16*) weights;
      const int32_t* dmap = widx_dmap;
      const uint32_t* nnzmap = nidx_nnzmap;
      size_t n = nc;
      do {
        uint32_t nnz = *nnzmap++;
        float16x4_t vacc0 = vld1_dup_f16(w); w += 1;
        if XNN_LIKELY(nnz != 0) {
          do {
            const intptr_t diff = *dmap++;
            const float16x4_t va0 = vld1_dup_f16(i);
            i = (const __fp16*restrict) ((uintptr_t) i + (uintptr_t) diff);
            const float16x4_t vb = vld1_dup_f16(w); w += 1;
            vacc0 = vfma_f16(vacc0, va0, vb);
          } while (--nnz != 0);
        }
        float16x4_t vout0 = vmin_f16(vacc0, vget_low_f16(vmax));
        vout0 = vmax_f16(vout0, vget_low_f16(vmin));
        vst1_lane_f16(o, vout0, 0);
        o = (__fp16*restrict) ((uintptr_t) o + output_stride);
      } while (--n != 0);
      o = (__fp16*restrict) ((uintptr_t) o - output_decrement);
      i += 1;
    }
  }
}
162