// Auto-generated file. Do not edit!
//   Template: src/bf16-gemm/c2-neonbf16-bfdot-lane-ld128.c.in
//   Generator: tools/xngen
//
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.


#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/gemm.h>

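// GEMM microkernel computing a 1x8 tile of output: one row of bf16
// activations times 8 columns of packed weights, accumulated in f32 with the
// NEON BF16 BFDOT instruction. Per XNNPACK naming conventions, "c2" means K
// is packed in pairs of bf16 values (matching the 2-element dot product
// BFDOT performs per 32-bit lane), "lane" refers to the lane-broadcast BFDOT
// form used in the main loop, and "ld128" to the 128-bit weight loads.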
void xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128(
    size_t mr,
    size_t nc,
    size_t kc,
    const void* restrict a,
    size_t a_stride,
    const void* restrict w_ptr,
    void* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mr != 0);
  assert(mr <= 1);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(bfloat16_t) == 0);
  assert(a != NULL);
  assert(w_ptr != NULL);
  assert(c != NULL);

  const bfloat16_t* a0 = (const bfloat16_t*) a;
  bfloat16_t* c0 = (bfloat16_t*) c;

  const bfloat16_t* w = (const bfloat16_t*) w_ptr;
  do {
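    // Initialize the accumulators from the start of the packed weights,
    // where XNNPACK's weight packing places the per-column bias values
    // (stored as bf16, widened here to f32).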
    float32x4_t vacc0x0123 = vcvt_f32_bf16(vld1_bf16(w)); w += 4;
    float32x4_t vacc0x4567 = vcvt_f32_bf16(vld1_bf16(w)); w += 4;

    size_t k = kc;
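    // Main loop: process 8 elements of K per iteration. Each
    // vbfdotq_laneq_f32 broadcasts one 32-bit lane of va0 (a pair of
    // adjacent bf16 activations) and accumulates its 2-element dot product
    // with the pair-interleaved packed weights into every f32 lane.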
    for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) {
      const bfloat16x8_t va0 = vld1q_bf16(a0); a0 += 8;

      const bfloat16x8_t vb0123c01 = vld1q_bf16(w); w += 8;
      const bfloat16x8_t vb4567c01 = vld1q_bf16(w); w += 8;

      vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c01, va0, 0);
      vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c01, va0, 0);
      const bfloat16x8_t vb0123c23 = vld1q_bf16(w); w += 8;
      const bfloat16x8_t vb4567c23 = vld1q_bf16(w); w += 8;

      vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c23, va0, 1);
      vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c23, va0, 1);
      const bfloat16x8_t vb0123c45 = vld1q_bf16(w); w += 8;
      const bfloat16x8_t vb4567c45 = vld1q_bf16(w); w += 8;

      vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c45, va0, 2);
      vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c45, va0, 2);
      const bfloat16x8_t vb0123c67 = vld1q_bf16(w); w += 8;
      const bfloat16x8_t vb4567c67 = vld1q_bf16(w); w += 8;

      vacc0x0123 = vbfdotq_laneq_f32(vacc0x0123, vb0123c67, va0, 3);
      vacc0x4567 = vbfdotq_laneq_f32(vacc0x4567, vb4567c67, va0, 3);
    }
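    // Remainder: 1 to 7 bf16 elements of K are left, but a full 16-byte
    // activation vector is still loaded, so the trailing lanes hold data
    // from beyond the valid range. The packed weights are zero-padded along
    // K; the vceqq/vbicq masking below zeroes each activation element whose
    // matching weight element is zero, so that out-of-range activation bits
    // (which could encode NaN or Inf) never reach the accumulators
    // (NaN * 0.0 would still be NaN).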
    if XNN_UNLIKELY(k != 0) {
      const bfloat16x8_t va0 = vld1q_bf16(a0); a0 = (const bfloat16_t*) ((uintptr_t) a0 + k);

      const bfloat16x8_t vb0123c01 = vld1q_bf16(w); w += 8;
      const bfloat16x8_t vb4567c01 = vld1q_bf16(w); w += 8;

      const uint32x4_t va0c01 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va0)), 0);

      const uint32x4_t vm0123c01 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c01), vmovq_n_u16(0)));
      const uint32x4_t vm4567c01 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c01), vmovq_n_u16(0)));

      const uint32x4_t va0x0123c01 = vbicq_u32(va0c01, vm0123c01);
      vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c01, vreinterpretq_bf16_u32(va0x0123c01));
      const uint32x4_t va0x4567c01 = vbicq_u32(va0c01, vm4567c01);
      vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c01, vreinterpretq_bf16_u32(va0x4567c01));

      if (k > 2 * sizeof(bfloat16_t)) {
        const bfloat16x8_t vb0123c23 = vld1q_bf16(w); w += 8;
        const bfloat16x8_t vb4567c23 = vld1q_bf16(w); w += 8;

        const uint32x4_t va0c23 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_low_bf16(va0)), 1);

        const uint32x4_t vm0123c23 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c23), vmovq_n_u16(0)));
        const uint32x4_t vm4567c23 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c23), vmovq_n_u16(0)));

        const uint32x4_t va0x0123c23 = vbicq_u32(va0c23, vm0123c23);
        vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c23, vreinterpretq_bf16_u32(va0x0123c23));
        const uint32x4_t va0x4567c23 = vbicq_u32(va0c23, vm4567c23);
        vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c23, vreinterpretq_bf16_u32(va0x4567c23));

        if (k > 4 * sizeof(bfloat16_t)) {
          const bfloat16x8_t vb0123c45 = vld1q_bf16(w); w += 8;
          const bfloat16x8_t vb4567c45 = vld1q_bf16(w); w += 8;

          const uint32x4_t va0c45 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va0)), 0);

          const uint32x4_t vm0123c45 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c45), vmovq_n_u16(0)));
          const uint32x4_t vm4567c45 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c45), vmovq_n_u16(0)));

          const uint32x4_t va0x0123c45 = vbicq_u32(va0c45, vm0123c45);
          vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c45, vreinterpretq_bf16_u32(va0x0123c45));
          const uint32x4_t va0x4567c45 = vbicq_u32(va0c45, vm4567c45);
          vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c45, vreinterpretq_bf16_u32(va0x4567c45));

          if (k > 6 * sizeof(bfloat16_t)) {
            const bfloat16x8_t vb0123c67 = vld1q_bf16(w); w += 8;
            const bfloat16x8_t vb4567c67 = vld1q_bf16(w); w += 8;

            const uint32x4_t va0c67 = vdupq_lane_u32(vreinterpret_u32_bf16(vget_high_bf16(va0)), 1);

            const uint32x4_t vm0123c67 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb0123c67), vmovq_n_u16(0)));
            const uint32x4_t vm4567c67 = vreinterpretq_u32_u16(vceqq_u16(vreinterpretq_u16_bf16(vb4567c67), vmovq_n_u16(0)));

            const uint32x4_t va0x0123c67 = vbicq_u32(va0c67, vm0123c67);
            vacc0x0123 = vbfdotq_f32(vacc0x0123, vb0123c67, vreinterpretq_bf16_u32(va0x0123c67));
            const uint32x4_t va0x4567c67 = vbicq_u32(va0c67, vm4567c67);
            vacc0x4567 = vbfdotq_f32(vacc0x4567, vb4567c67, vreinterpretq_bf16_u32(va0x4567c67));
          }
        }
      }
    }

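    // Clamp the accumulators to the [min, max] output range from params.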
    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
    vacc0x0123 = vminq_f32(vacc0x0123, vmax);
    vacc0x4567 = vminq_f32(vacc0x4567, vmax);

    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
    vacc0x0123 = vmaxq_f32(vacc0x0123, vmin);
    vacc0x4567 = vmaxq_f32(vacc0x4567, vmin);

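    // Narrow the f32 accumulators back to bf16 for storage.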
    bfloat16x4_t vout0x0123 = vcvt_bf16_f32(vacc0x0123);
    bfloat16x4_t vout0x4567 = vcvt_bf16_f32(vacc0x4567);

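    // Store a full 8-column row when possible; on a partial final tile,
    // write 4/2/1-column pieces, shifting the remaining results down
    // (vext_u16 rotates the already-stored pair out of the vector).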
    if XNN_LIKELY(nc >= 8) {
      vst1_bf16(c0, vout0x0123);
      vst1_bf16(c0 + 4, vout0x4567);
      c0 = (bfloat16_t*) ((uintptr_t) c0 + cn_stride);

      a0 = (const bfloat16_t*) ((uintptr_t) a0 - kc);

      nc -= 8;
    } else {
      if (nc & 4) {
        vst1_bf16(c0, vout0x0123); c0 += 4;

        vout0x0123 = vout0x4567;
      }
      if (nc & 2) {
        vst1_lane_u32((void*) c0, vreinterpret_u32_bf16(vout0x0123), 0); c0 += 2;

        vout0x0123 = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout0x0123), vreinterpret_u16_bf16(vout0x0123), 2));
      }
      if (nc & 1) {
        vst1_lane_bf16(c0, vout0x0123, 0);
      }

      nc = 0;
    }
  } while (nc != 0);
}