1*4bdc9457SAndroid Build Coastguard Worker// Copyright 2020 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker// 3*4bdc9457SAndroid Build Coastguard Worker// This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker// LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker$assert NR % 8 == 0 7*4bdc9457SAndroid Build Coastguard Worker$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 8*4bdc9457SAndroid Build Coastguard Worker 9*4bdc9457SAndroid Build Coastguard Worker#include <assert.h> 10*4bdc9457SAndroid Build Coastguard Worker 11*4bdc9457SAndroid Build Coastguard Worker#include <arm_neon.h> 12*4bdc9457SAndroid Build Coastguard Worker 13*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/common.h> 14*4bdc9457SAndroid Build Coastguard Worker 15*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/gemm.h> 16*4bdc9457SAndroid Build Coastguard Worker 17*4bdc9457SAndroid Build Coastguard Worker 18*4bdc9457SAndroid Build Coastguard Workervoid xnn_f16_gemm${"inc" if INC else ""}_minmax_ukernel_${MR}x${NR}__neonfp16arith_ld64( 19*4bdc9457SAndroid Build Coastguard Worker size_t mr, 20*4bdc9457SAndroid Build Coastguard Worker size_t nc, 21*4bdc9457SAndroid Build Coastguard Worker size_t kc, 22*4bdc9457SAndroid Build Coastguard Worker const void* restrict a, 23*4bdc9457SAndroid Build Coastguard Worker size_t a_stride, 24*4bdc9457SAndroid Build Coastguard Worker const void* restrict w, 25*4bdc9457SAndroid Build Coastguard Worker void* restrict c, 26*4bdc9457SAndroid Build Coastguard Worker size_t cm_stride, 27*4bdc9457SAndroid Build Coastguard Worker size_t cn_stride, 28*4bdc9457SAndroid Build Coastguard Worker $if INC: 29*4bdc9457SAndroid Build Coastguard Worker const void*restrict acc, 30*4bdc9457SAndroid Build Coastguard Worker const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) 31*4bdc9457SAndroid Build Coastguard Worker{ 32*4bdc9457SAndroid Build Coastguard Worker assert(mr != 0); 33*4bdc9457SAndroid Build Coastguard Worker assert(mr <= ${MR}); 34*4bdc9457SAndroid Build Coastguard Worker assert(nc != 0); 35*4bdc9457SAndroid Build Coastguard Worker assert(kc != 0); 36*4bdc9457SAndroid Build Coastguard Worker assert(kc % sizeof(__fp16) == 0); 37*4bdc9457SAndroid Build Coastguard Worker assert(a != NULL); 38*4bdc9457SAndroid Build Coastguard Worker assert(w != NULL); 39*4bdc9457SAndroid Build Coastguard Worker assert(c != NULL); 40*4bdc9457SAndroid Build Coastguard Worker $if INC: 41*4bdc9457SAndroid Build Coastguard Worker assert(acc != NULL); 42*4bdc9457SAndroid Build Coastguard Worker 43*4bdc9457SAndroid Build Coastguard Worker const __fp16* a0 = (const __fp16*) a; 44*4bdc9457SAndroid Build Coastguard Worker __fp16* c0 = (__fp16*) c; 45*4bdc9457SAndroid Build Coastguard Worker $for M in range(1, MR): 46*4bdc9457SAndroid Build Coastguard Worker const __fp16* a${M} = (const __fp16*) ((uintptr_t) a${M-1} + a_stride); 47*4bdc9457SAndroid Build Coastguard Worker __fp16* c${M} = (__fp16*) ((uintptr_t) c${M-1} + cm_stride); 48*4bdc9457SAndroid Build Coastguard Worker $if M % 2 == 0: 49*4bdc9457SAndroid Build Coastguard Worker if XNN_UNPREDICTABLE(mr <= ${M}) { 50*4bdc9457SAndroid Build Coastguard Worker a${M} = a${M-1}; 51*4bdc9457SAndroid Build Coastguard Worker c${M} = c${M-1}; 52*4bdc9457SAndroid Build Coastguard Worker } 53*4bdc9457SAndroid Build Coastguard Worker $elif M + 1 == MR: 54*4bdc9457SAndroid Build Coastguard Worker if XNN_UNPREDICTABLE(mr != ${M+1}) { 55*4bdc9457SAndroid Build Coastguard Worker a${M} = a${M-1}; 56*4bdc9457SAndroid Build Coastguard Worker c${M} = c${M-1}; 57*4bdc9457SAndroid Build Coastguard Worker } 58*4bdc9457SAndroid Build Coastguard Worker $else: 59*4bdc9457SAndroid Build Coastguard Worker if XNN_UNPREDICTABLE(mr < ${M+1}) { 60*4bdc9457SAndroid Build Coastguard Worker a${M} = a${M-1}; 61*4bdc9457SAndroid Build Coastguard Worker c${M} = c${M-1}; 62*4bdc9457SAndroid Build Coastguard Worker } 63*4bdc9457SAndroid Build Coastguard Worker 64*4bdc9457SAndroid Build Coastguard Worker do { 65*4bdc9457SAndroid Build Coastguard Worker $if INC: 66*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 67*4bdc9457SAndroid Build Coastguard Worker $for N in range(0, NR, 8): 68*4bdc9457SAndroid Build Coastguard Worker float16x8_t vacc${M}x${ABC[N:N+8]} = vld1q_f16(acc); acc = (const void*) ((uintptr_t) acc + sizeof(float16x8_t)); 69*4bdc9457SAndroid Build Coastguard Worker $else: 70*4bdc9457SAndroid Build Coastguard Worker $for N in range(0, NR, 8): 71*4bdc9457SAndroid Build Coastguard Worker float16x8_t vacc0x${ABC[N:N+8]} = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); 72*4bdc9457SAndroid Build Coastguard Worker $for M in range(1, MR): 73*4bdc9457SAndroid Build Coastguard Worker $for N in range(0, NR, 8): 74*4bdc9457SAndroid Build Coastguard Worker float16x8_t vacc${M}x${ABC[N:N+8]} = vacc0x${ABC[N:N+8]}; 75*4bdc9457SAndroid Build Coastguard Worker 76*4bdc9457SAndroid Build Coastguard Worker size_t k = kc; 77*4bdc9457SAndroid Build Coastguard Worker while (k >= 4 * sizeof(__fp16)) { 78*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 79*4bdc9457SAndroid Build Coastguard Worker const float16x4_t va${M} = vld1_f16(a${M}); a${M} += 4; 80*4bdc9457SAndroid Build Coastguard Worker 81*4bdc9457SAndroid Build Coastguard Worker $for L in range(4): 82*4bdc9457SAndroid Build Coastguard Worker $for N in range(0, NR, 8): 83*4bdc9457SAndroid Build Coastguard Worker const float16x8_t vb${ABC[N:N+8]}c${L} = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); 84*4bdc9457SAndroid Build Coastguard Worker 85*4bdc9457SAndroid Build Coastguard Worker #if XNN_ARCH_ARM64 86*4bdc9457SAndroid Build Coastguard Worker $for N in range(0, NR, 8): 87*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 88*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${ABC[N:N+8]} = vfmaq_lane_f16(vacc${M}x${ABC[N:N+8]}, vb${ABC[N:N+8]}c${L}, va${M}, ${L}); 89*4bdc9457SAndroid Build Coastguard Worker #else 90*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 91*4bdc9457SAndroid Build Coastguard Worker const float16x8_t va${M}c${L} = vdupq_lane_f16(va${M}, ${L}); 92*4bdc9457SAndroid Build Coastguard Worker 93*4bdc9457SAndroid Build Coastguard Worker $for N in range(0, NR, 8): 94*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 95*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${ABC[N:N+8]} = vfmaq_f16(vacc${M}x${ABC[N:N+8]}, va${M}c${L}, vb${ABC[N:N+8]}c${L}); 96*4bdc9457SAndroid Build Coastguard Worker #endif 97*4bdc9457SAndroid Build Coastguard Worker 98*4bdc9457SAndroid Build Coastguard Worker k -= 4 * sizeof(__fp16); 99*4bdc9457SAndroid Build Coastguard Worker } 100*4bdc9457SAndroid Build Coastguard Worker if XNN_UNLIKELY(k != 0) { 101*4bdc9457SAndroid Build Coastguard Worker do { 102*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 103*4bdc9457SAndroid Build Coastguard Worker const float16x8_t va${M} = vld1q_dup_f16(a${M}); a${M} += 1; 104*4bdc9457SAndroid Build Coastguard Worker 105*4bdc9457SAndroid Build Coastguard Worker $for N in range(0, NR, 8): 106*4bdc9457SAndroid Build Coastguard Worker const float16x8_t vb${ABC[N:N+8]} = vld1q_f16(w); w = (const void*) ((uintptr_t) w + sizeof(float16x8_t)); 107*4bdc9457SAndroid Build Coastguard Worker 108*4bdc9457SAndroid Build Coastguard Worker $for N in range(0, NR, 8): 109*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 110*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${ABC[N:N+8]} = vfmaq_f16(vacc${M}x${ABC[N:N+8]}, va${M}, vb${ABC[N:N+8]}); 111*4bdc9457SAndroid Build Coastguard Worker 112*4bdc9457SAndroid Build Coastguard Worker k -= sizeof(__fp16); 113*4bdc9457SAndroid Build Coastguard Worker } while (k != 0); 114*4bdc9457SAndroid Build Coastguard Worker } 115*4bdc9457SAndroid Build Coastguard Worker 116*4bdc9457SAndroid Build Coastguard Worker 117*4bdc9457SAndroid Build Coastguard Worker const float16x8_t vmax = vreinterpretq_f16_u16(vld1q_dup_u16(¶ms->neon.max)); 118*4bdc9457SAndroid Build Coastguard Worker $for N in range(0, NR, 8): 119*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 120*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${ABC[N:N+8]} = vminq_f16(vacc${M}x${ABC[N:N+8]}, vmax); 121*4bdc9457SAndroid Build Coastguard Worker 122*4bdc9457SAndroid Build Coastguard Worker const float16x8_t vmin = vreinterpretq_f16_u16(vld1q_dup_u16(¶ms->neon.min)); 123*4bdc9457SAndroid Build Coastguard Worker $for N in range(0, NR, 8): 124*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 125*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${ABC[N:N+8]} = vmaxq_f16(vacc${M}x${ABC[N:N+8]}, vmin); 126*4bdc9457SAndroid Build Coastguard Worker 127*4bdc9457SAndroid Build Coastguard Worker if XNN_LIKELY(nc >= ${NR}) { 128*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 129*4bdc9457SAndroid Build Coastguard Worker vst1q_f16(c${M}, vacc${M}x${ABC[0:8]}); 130*4bdc9457SAndroid Build Coastguard Worker $for N in range(8, NR, 8): 131*4bdc9457SAndroid Build Coastguard Worker vst1q_f16(c${M} + ${N}, vacc${M}x${ABC[N:N+8]}); 132*4bdc9457SAndroid Build Coastguard Worker c${M} = (__fp16*) ((uintptr_t) c${M} + cn_stride); 133*4bdc9457SAndroid Build Coastguard Worker 134*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 135*4bdc9457SAndroid Build Coastguard Worker a${M} = (const __fp16*) ((uintptr_t) a${M} - kc); 136*4bdc9457SAndroid Build Coastguard Worker 137*4bdc9457SAndroid Build Coastguard Worker nc -= ${NR}; 138*4bdc9457SAndroid Build Coastguard Worker } else { 139*4bdc9457SAndroid Build Coastguard Worker $for LOG2N in reversed(range(NR.bit_length())): 140*4bdc9457SAndroid Build Coastguard Worker $if NR != 1 << LOG2N: 141*4bdc9457SAndroid Build Coastguard Worker if (nc & ${1 << LOG2N}) { 142*4bdc9457SAndroid Build Coastguard Worker $if LOG2N >= 3: 143*4bdc9457SAndroid Build Coastguard Worker $for N in range(0, 1 << LOG2N, 8): 144*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 145*4bdc9457SAndroid Build Coastguard Worker vst1q_f16(c${M}, vacc${M}x${ABC[N:N+8]}); c${M} += 8; 146*4bdc9457SAndroid Build Coastguard Worker 147*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 148*4bdc9457SAndroid Build Coastguard Worker $for N in range(0, 1 << (LOG2N - 1), 8): 149*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${ABC[N:N+8]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+8]}; 150*4bdc9457SAndroid Build Coastguard Worker $elif LOG2N == 2: 151*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 152*4bdc9457SAndroid Build Coastguard Worker vst1_f16(c${M}, vacc${M}x${ABC[0:4]}); c${M} += 4; 153*4bdc9457SAndroid Build Coastguard Worker 154*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 155*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${ABC[0:4]} = vget_high_f16(vacc${M}x${ABC[0:8]}); 156*4bdc9457SAndroid Build Coastguard Worker $elif LOG2N == 1: 157*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 158*4bdc9457SAndroid Build Coastguard Worker vst1_lane_u32((void*) c${M}, vreinterpret_u32_f16(vacc${M}x${ABC[0:4]}), 0); c${M} += 2; 159*4bdc9457SAndroid Build Coastguard Worker 160*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 161*4bdc9457SAndroid Build Coastguard Worker vacc${M}x${ABC[0:4]} = vext_f16(vacc${M}x${ABC[0:4]}, vacc${M}x${ABC[0:4]}, 2); 162*4bdc9457SAndroid Build Coastguard Worker $elif LOG2N == 0: 163*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 164*4bdc9457SAndroid Build Coastguard Worker vst1_lane_f16(c${M}, vacc${M}x${ABC[0:4]}, 0); 165*4bdc9457SAndroid Build Coastguard Worker } 166*4bdc9457SAndroid Build Coastguard Worker $if LOG2N == 3: 167*4bdc9457SAndroid Build Coastguard Worker $for M in range(MR): 168*4bdc9457SAndroid Build Coastguard Worker float16x4_t vacc${M}x${ABC[0:4]} = vget_low_f16(vacc${M}x${ABC[0:8]}); 169*4bdc9457SAndroid Build Coastguard Worker 170*4bdc9457SAndroid Build Coastguard Worker nc = 0; 171*4bdc9457SAndroid Build Coastguard Worker } 172*4bdc9457SAndroid Build Coastguard Worker } while (nc != 0); 173*4bdc9457SAndroid Build Coastguard Worker} 174