// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert NR % 4 == 0
$assert EXTOPT in ["SHLAND", "ZIP"]
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/gemm.h>


void xnn_bf16_gemm_minmax_ukernel_${MR}x${NR}c8__neonfma_${EXTOPT.lower()}(
    size_t mr,
    size_t nc,
    size_t kc,
    const void* restrict a,
    size_t a_stride,
    const void* restrict w_ptr,
    void* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(uint16_t) == 0);
  assert(a != NULL);
  assert(w_ptr != NULL);
  assert(c != NULL);

  const uint16_t* a0 = (const uint16_t*) a;
  uint16_t* c0 = (uint16_t*) c;
  $for M in range(1, MR):
    const uint16_t* a${M} = (const uint16_t*) ((uintptr_t) a${M-1} + a_stride);
    uint16_t* c${M} = (uint16_t*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

  const uint16_t* w = (const uint16_t*) w_ptr;
  $if EXTOPT == "SHLAND":
    // Mask that selects the high (odd-index bf16) 16 bits of every 32-bit lane.
    const uint16x8_t vmask = vreinterpretq_u16_u32(vmovq_n_u32(UINT32_C(0xFFFF0000)));
  $elif EXTOPT == "ZIP":
    const uint16x8_t vzero = vmovq_n_u16(0);
  do {
    // Initialize the accumulators with the bf16 bias, widened into lane 0;
    // the other lanes start at zero and fold in during the final reduction.
    $for N in range(NR):
      float32x4_t vacc0x${ABC[N]} = vreinterpretq_f32_u32(vshll_n_u16(vld1_lane_u16(w, vdup_n_u16(0), 0), 16)); w += 1;
    $for M in range(1, MR):
      $for N in range(NR):
        float32x4_t vacc${M}x${ABC[N]} = vacc0x${ABC[N]};

    size_t k = kc;
    for (; k >= 8 * sizeof(uint16_t); k -= 8 * sizeof(uint16_t)) {
      $for M in range(MR):
        const uint16x8_t va${M} = vld1q_u16(a${M}); a${M} += 8;

      $for N in range(NR):
        const uint16x8_t vb${ABC[N]} = vld1q_u16(w); w += 8;

      // bf16 -> f32 extension: placing a bf16 value in the high 16 bits of a
      // 32-bit lane (with zeroes below) yields the exactly-equal f32 value.
      // SHLAND and ZIP differ only in which half of the 8 elements each step
      // extends; the dot product is insensitive to the summation order.
      $for M in range(MR):
        $if EXTOPT == "SHLAND":
          const float32x4_t va${M}e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va${M}), 16));
        $elif EXTOPT == "ZIP":
          const float32x4_t va${M}e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va${M}));

      $for N in range(NR):
        $if EXTOPT == "SHLAND":
          const float32x4_t vb${ABC[N]}e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb${ABC[N]}), 16));
        $elif EXTOPT == "ZIP":
          const float32x4_t vb${ABC[N]}e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb${ABC[N]}));

      $for N in range(NR):
        $for M in range(MR):
          vacc${M}x${ABC[N]} = vfmaq_f32(vacc${M}x${ABC[N]}, va${M}e, vb${ABC[N]}e);

      // Extend and accumulate the other half of the elements.
      $for M in range(MR):
        $if EXTOPT == "SHLAND":
          const float32x4_t va${M}o = vreinterpretq_f32_u16(vandq_u16(va${M}, vmask));
        $elif EXTOPT == "ZIP":
          const float32x4_t va${M}o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va${M}));

      $for N in range(NR):
        $if EXTOPT == "SHLAND":
          const float32x4_t vb${ABC[N]}o = vreinterpretq_f32_u16(vandq_u16(vb${ABC[N]}, vmask));
        $elif EXTOPT == "ZIP":
          const float32x4_t vb${ABC[N]}o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb${ABC[N]}));

      $for N in range(NR):
        $for M in range(MR):
          vacc${M}x${ABC[N]} = vfmaq_f32(vacc${M}x${ABC[N]}, va${M}o, vb${ABC[N]}o);
    }
    if XNN_UNLIKELY(k != 0) {
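      // KC remainder: fewer than 8 bf16 elements are left. Full vectors are
      // still loaded; lanes whose packed weight is zero (packing zero-pads K
      // to a multiple of 8) are cleared in the activations below, so stray
      // activation bits cannot contribute spurious NaN/Inf to the accumulators.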
      $for M in range(MR):
        const uint16x8_t va${M} = vld1q_u16(a${M}); a${M} = (const uint16_t*) ((uintptr_t) a${M} + k);

      $for N in range(NR):
        const uint16x8_t vb${ABC[N]} = vld1q_u16(w); w += 8;

      // Per-lane mask of zero weights, used to clear the matching activations.
      $for N in range(NR):
        const uint16x8_t vm${ABC[N]} = vceqq_u16(vb${ABC[N]}, vmovq_n_u16(0));

      $for N in range(NR):
        $if EXTOPT == "SHLAND":
          const float32x4_t vb${ABC[N]}e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(vb${ABC[N]}), 16));
        $elif EXTOPT == "ZIP":
          const float32x4_t vb${ABC[N]}e = vreinterpretq_f32_u16(vzip1q_u16(vzero, vb${ABC[N]}));

      $for N in range(NR):
        $for M in range(MR):
          const uint16x8_t va${M}x${ABC[N]} = vbicq_u16(va${M}, vm${ABC[N]});

      $for N in range(NR):
        $for M in range(MR):
          $if EXTOPT == "SHLAND":
            const float32x4_t va${M}x${ABC[N]}e = vreinterpretq_f32_u32(vshlq_n_u32(vreinterpretq_u32_u16(va${M}x${ABC[N]}), 16));
          $elif EXTOPT == "ZIP":
            const float32x4_t va${M}x${ABC[N]}e = vreinterpretq_f32_u16(vzip1q_u16(vzero, va${M}x${ABC[N]}));

      $for N in range(NR):
        $for M in range(MR):
          vacc${M}x${ABC[N]} = vfmaq_f32(vacc${M}x${ABC[N]}, va${M}x${ABC[N]}e, vb${ABC[N]}e);

      $for N in range(NR):
        $if EXTOPT == "SHLAND":
          const float32x4_t vb${ABC[N]}o = vreinterpretq_f32_u16(vandq_u16(vb${ABC[N]}, vmask));
        $elif EXTOPT == "ZIP":
          const float32x4_t vb${ABC[N]}o = vreinterpretq_f32_u16(vzip2q_u16(vzero, vb${ABC[N]}));

      $for N in range(NR):
        $for M in range(MR):
          $if EXTOPT == "SHLAND":
            const float32x4_t va${M}x${ABC[N]}o = vreinterpretq_f32_u16(vandq_u16(va${M}x${ABC[N]}, vmask));
          $elif EXTOPT == "ZIP":
            const float32x4_t va${M}x${ABC[N]}o = vreinterpretq_f32_u16(vzip2q_u16(vzero, va${M}x${ABC[N]}));

      $for N in range(NR):
        $for M in range(MR):
          vacc${M}x${ABC[N]} = vfmaq_f32(vacc${M}x${ABC[N]}, va${M}x${ABC[N]}o, vb${ABC[N]}o);
    }

    // Each accumulator holds 4 partial sums for one output column; two rounds
    // of pairwise addition reduce groups of four accumulators into a single
    // register of 4 consecutive column results.
#if XNN_ARCH_ARM64
    $for N in range(0, NR, 2):
      $for M in range(MR):
        const float32x4_t vacc${M}x${ABC[N:N+2]} = vpaddq_f32(vacc${M}x${ABC[N]}, vacc${M}x${ABC[N+1]});

    $for N in range(0, NR, 4):
      $for M in range(MR):
        float32x4_t vacc${M}x${ABC[N:N+4]} = vpaddq_f32(vacc${M}x${ABC[N:N+2]}, vacc${M}x${ABC[N+2:N+4]});
#else
    $for N in range(NR):
      $for M in range(MR):
        const float32x2_t vsum${M}x${ABC[N]} = vadd_f32(vget_low_f32(vacc${M}x${ABC[N]}), vget_high_f32(vacc${M}x${ABC[N]}));

    $for N in range(0, NR, 4):
      $for M in range(MR):
        float32x4_t vacc${M}x${ABC[N:N+4]} = vcombine_f32(vpadd_f32(vsum${M}x${ABC[N]}, vsum${M}x${ABC[N+1]}), vpadd_f32(vsum${M}x${ABC[N+2]}, vsum${M}x${ABC[N+3]}));
#endif

    const float32x4_t vmax = vld1q_dup_f32(&params->scalar.max);
    $for N in range(0, NR, 4):
      $for M in range(MR):
        vacc${M}x${ABC[N:N+4]} = vminq_f32(vacc${M}x${ABC[N:N+4]}, vmax);

    const float32x4_t vmin = vld1q_dup_f32(&params->scalar.min);
    $for N in range(0, NR, 4):
      $for M in range(MR):
        vacc${M}x${ABC[N:N+4]} = vmaxq_f32(vacc${M}x${ABC[N:N+4]}, vmin);

    // Convert f32 back to bf16 by truncation: keep the high 16 bits of each lane.
    $for N in range(0, NR, 4):
      $for M in range(MR):
        uint16x4_t vout${M}x${ABC[N:N+4]} = vshrn_n_u32(vreinterpretq_u32_f32(vacc${M}x${ABC[N:N+4]}), 16);
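    // Output store: the fast path writes the full ${NR}-column tile per row and
    // rewinds the A pointers for the next tile; the tail path stores nc columns
    // with progressively narrower stores and shifts the remaining lanes down.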
    if XNN_LIKELY(nc >= ${NR}) {
      $for M in range(MR):
        vst1_u16(c${M}, vout${M}x${ABC[0:4]});
        $for N in range(4, NR, 4):
          vst1_u16(c${M} + ${N}, vout${M}x${ABC[N:N+4]});
        c${M} = (uint16_t*) ((uintptr_t) c${M} + cn_stride);

      $for M in range(MR):
        a${M} = (const uint16_t*) ((uintptr_t) a${M} - kc);

      nc -= ${NR};
    } else {
      $for LOG2N in reversed(range(NR.bit_length())):
        $if NR != 1 << LOG2N:
          if (nc & ${1 << LOG2N}) {
            $if LOG2N >= 2:
              $for N in range(0, 1 << LOG2N, 4):
                $for M in range(MR):
                  vst1_u16(c${M}, vout${M}x${ABC[N:N+4]}); c${M} += 4;

              $for M in range(MR):
                $for N in range(0, 1 << (LOG2N - 1), 4):
                  vout${M}x${ABC[N:N+4]} = vout${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N) + 4]};
            $elif LOG2N == 1:
              $for M in range(MR):
                vst1_lane_u32((void*) c${M}, vreinterpret_u32_u16(vout${M}x${ABC[0:4]}), 0); c${M} += 2;

              $for M in range(MR):
                vout${M}x${ABC[0:4]} = vext_u16(vout${M}x${ABC[0:4]}, vout${M}x${ABC[0:4]}, 2);
            $elif LOG2N == 0:
              $for M in range(MR):
                vst1_lane_u16(c${M}, vout${M}x${ABC[0:4]}, 0);
          }

      nc = 0;
    }
  } while (nc != 0);
}