1// Copyright 2022 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert NR % 4 == 0 7$assert BFOPT in ["BFDOT", "BFMLAL"] 8$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 9 10#include <assert.h> 11 12#include <arm_neon.h> 13 14#include <xnnpack/gemm.h> 15 16 17void xnn_bf16_gemm_minmax_ukernel_${MR}x${NR}c8__neonbf16_${BFOPT.lower()}( 18 size_t mr, 19 size_t nc, 20 size_t kc, 21 const void* restrict a, 22 size_t a_stride, 23 const void* restrict w_ptr, 24 void* restrict c, 25 size_t cm_stride, 26 size_t cn_stride, 27 const union xnn_bf16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) 28{ 29 assert(mr != 0); 30 assert(mr <= ${MR}); 31 assert(nc != 0); 32 assert(kc != 0); 33 assert(kc % sizeof(bfloat16_t) == 0); 34 assert(a != NULL); 35 assert(w_ptr != NULL); 36 assert(c != NULL); 37 38 const bfloat16_t* a0 = (const bfloat16_t*) a; 39 bfloat16_t* c0 = (bfloat16_t*) c; 40 $for M in range(1, MR): 41 const bfloat16_t* a${M} = (const bfloat16_t*) ((uintptr_t) a${M-1} + a_stride); 42 bfloat16_t* c${M} = (bfloat16_t*) ((uintptr_t) c${M-1} + cm_stride); 43 $if M % 2 == 0: 44 if XNN_UNPREDICTABLE(mr <= ${M}) { 45 a${M} = a${M-1}; 46 c${M} = c${M-1}; 47 } 48 $elif M + 1 == MR: 49 if XNN_UNPREDICTABLE(mr != ${M+1}) { 50 a${M} = a${M-1}; 51 c${M} = c${M-1}; 52 } 53 $else: 54 if XNN_UNPREDICTABLE(mr < ${M+1}) { 55 a${M} = a${M-1}; 56 c${M} = c${M-1}; 57 } 58 59 const bfloat16_t* w = (const bfloat16_t*) w_ptr; 60 do { 61 $for N in range(NR): 62 float32x4_t vacc0x${ABC[N]} = vcvt_f32_bf16(vld1_lane_bf16(w, vreinterpret_bf16_u16(vdup_n_u16(0)), 0)); w += 1; 63 $for M in range(1, MR): 64 $for N in range(NR): 65 float32x4_t vacc${M}x${ABC[N]} = vacc0x${ABC[N]}; 66 67 size_t k = kc; 68 for (; k >= 8 * sizeof(bfloat16_t); k -= 8 * sizeof(bfloat16_t)) { 69 $for M in range(MR): 70 const bfloat16x8_t va${M} = vld1q_bf16(a${M}); a${M} += 8; 71 72 $for N in range(NR): 73 const bfloat16x8_t vb${ABC[N]} = vld1q_bf16(w); w += 8; 74 75 $if BFOPT == "BFDOT": 76 $for N in range(NR): 77 $for M in range(MR): 78 vacc${M}x${ABC[N]} = vbfdotq_f32(vacc${M}x${ABC[N]}, va${M}, vb${ABC[N]}); 79 $elif BFOPT == "BFMLAL": 80 $for N in range(NR): 81 $for M in range(MR): 82 vacc${M}x${ABC[N]} = vbfmlalbq_f32(vacc${M}x${ABC[N]}, va${M}, vb${ABC[N]}); 83 84 $for N in range(NR): 85 $for M in range(MR): 86 vacc${M}x${ABC[N]} = vbfmlaltq_f32(vacc${M}x${ABC[N]}, va${M}, vb${ABC[N]}); 87 } 88 if XNN_UNLIKELY(k != 0) { 89 $for M in range(MR): 90 const bfloat16x8_t va${M} = vld1q_bf16(a${M}); a${M} = (const bfloat16_t*) ((uintptr_t) a${M} + k); 91 92 $for N in range(NR): 93 const bfloat16x8_t vb${ABC[N]} = vld1q_bf16(w); w += 8; 94 95 $for N in range(NR): 96 const uint16x8_t vm${ABC[N]} = vceqq_u16(vreinterpretq_u16_bf16(vb${ABC[N]}), vmovq_n_u16(0)); 97 98 $for N in range(NR): 99 $for M in range(MR): 100 const bfloat16x8_t va${M}x${ABC[N]} = vreinterpretq_bf16_u16(vbicq_u16(vreinterpretq_u16_bf16(va${M}), vm${ABC[N]})); 101 $if BFOPT == "BFDOT": 102 vacc${M}x${ABC[N]} = vbfdotq_f32(vacc${M}x${ABC[N]}, va${M}x${ABC[N]}, vb${ABC[N]}); 103 $elif BFOPT == "BFMLAL": 104 vacc${M}x${ABC[N]} = vbfmlalbq_f32(vacc${M}x${ABC[N]}, va${M}x${ABC[N]}, vb${ABC[N]}); 105 vacc${M}x${ABC[N]} = vbfmlaltq_f32(vacc${M}x${ABC[N]}, va${M}x${ABC[N]}, vb${ABC[N]}); 106 } 107 108#if XNN_ARCH_ARM64 109 $for N in range(0, NR, 2): 110 $for M in range(MR): 111 const float32x4_t vacc${M}x${ABC[N:N+2]} = vpaddq_f32(vacc${M}x${ABC[N]}, vacc${M}x${ABC[N+1]}); 112 113 $for N in range(0, NR, 4): 114 $for M in range(MR): 115 float32x4_t vacc${M}x${ABC[N:N+4]} = vpaddq_f32(vacc${M}x${ABC[N:N+2]}, vacc${M}x${ABC[N+2:N+4]}); 116#else 117 $for N in range(NR): 118 $for M in range(MR): 119 const float32x2_t vsum${M}x${ABC[N]} = vadd_f32(vget_low_f32(vacc${M}x${ABC[N]}), vget_high_f32(vacc${M}x${ABC[N]})); 120 121 $for N in range(0, NR, 4): 122 $for M in range(MR): 123 float32x4_t vacc${M}x${ABC[N:N+4]} = vcombine_f32(vpadd_f32(vsum${M}x${ABC[N]}, vsum${M}x${ABC[N+1]}), vpadd_f32(vsum${M}x${ABC[N+2]}, vsum${M}x${ABC[N+3]})); 124#endif 125 126 const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); 127 $for N in range(0, NR, 4): 128 $for M in range(MR): 129 vacc${M}x${ABC[N:N+4]} = vminq_f32(vacc${M}x${ABC[N:N+4]}, vmax); 130 131 const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); 132 $for N in range(0, NR, 4): 133 $for M in range(MR): 134 vacc${M}x${ABC[N:N+4]} = vmaxq_f32(vacc${M}x${ABC[N:N+4]}, vmin); 135 136 $for N in range(0, NR, 4): 137 $for M in range(MR): 138 bfloat16x4_t vout${M}x${ABC[N:N+4]} = vcvt_bf16_f32(vacc${M}x${ABC[N:N+4]}); 139 140 if XNN_LIKELY(nc >= ${NR}) { 141 $for M in range(MR): 142 vst1_bf16(c${M}, vout${M}x${ABC[0:4]}); 143 $for N in range(4, NR, 4): 144 vst1_bf16(c${M} + ${N}, vout${M}x${ABC[N:N+4]}); 145 c${M} = (bfloat16_t*) ((uintptr_t) c${M} + cn_stride); 146 147 $for M in range(MR): 148 a${M} = (const bfloat16_t*) ((uintptr_t) a${M} - kc); 149 150 nc -= ${NR}; 151 } else { 152 $for LOG2N in reversed(range(NR.bit_length())): 153 $if NR != 1 << LOG2N: 154 if (nc & ${1 << LOG2N}) { 155 $if LOG2N >= 2: 156 $for N in range(0, 1 << LOG2N, 4): 157 $for M in range(MR): 158 vst1_bf16(c${M}, vout${M}x${ABC[N:N+4]}); c${M} += 4; 159 160 $for M in range(MR): 161 $for N in range(0, 1 << (LOG2N - 1), 4): 162 vout${M}x${ABC[N:N+4]} = vout${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+4]}; 163 $elif LOG2N == 1: 164 $for M in range(MR): 165 vst1_lane_u32((void*) c${M}, vreinterpret_u32_bf16(vout${M}x${ABC[0:4]}), 0); c${M} += 2; 166 167 $for M in range(MR): 168 vout${M}x${ABC[0:4]} = vreinterpret_bf16_u16(vext_u16(vreinterpret_u16_bf16(vout${M}x${ABC[0:4]}), vreinterpret_u16_bf16(vout${M}x${ABC[0:4]}), 2)); 169 $elif LOG2N == 0: 170 $for M in range(MR): 171 vst1_lane_bf16(c${M}, vout${M}x${ABC[0:4]}, 0); 172 } 173 174 nc = 0; 175 } 176 } while (nc != 0); 177} 178