1// Copyright 2022 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert CHANNEL_TILE % 8 == 0 7$assert CHANNEL_TILE >= 8 8$assert ROW_TILE >= 3 9$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 10#include <assert.h> 11 12#include <arm_neon.h> 13 14#include <xnnpack/gavgpool.h> 15 16 17void xnn_f16_gavgpool_minmax_ukernel_${ROW_TILE}x__neonfp16arith_c${CHANNEL_TILE}( 18 size_t rows, 19 size_t channels, 20 const void* input, 21 size_t input_stride, 22 const void* zero, 23 void* output, 24 const union xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS 25{ 26 assert(rows != 0); 27 assert(rows <= ${ROW_TILE}); 28 assert(channels != 0); 29 30 const __fp16* i0 = input; 31 $for M in range(1, ROW_TILE): 32 const __fp16* i${M} = (const __fp16*) ((uintptr_t) i${M-1} + input_stride); 33 $if M % 2 == 1: 34 if XNN_UNPREDICTABLE(rows < ${M+1}) { 35 i${M} = (const __fp16*) zero; 36 } 37 $else: 38 if XNN_UNPREDICTABLE(rows <= ${M}) { 39 i${M} = (const __fp16*) zero; 40 } 41 42 const float16x8_t vscale = vreinterpretq_f16_u16(vld1q_dup_u16(¶ms->neon.scale)); 43 const float16x8_t vmin = vreinterpretq_f16_u16(vld1q_dup_u16(¶ms->neon.min)); 44 const float16x8_t vmax = vreinterpretq_f16_u16(vld1q_dup_u16(¶ms->neon.max)); 45 for (; channels >= ${CHANNEL_TILE}; channels -= ${CHANNEL_TILE}) { 46 $for M in range(2): 47 $for C in range(0, CHANNEL_TILE, 8): 48 const float16x8_t vi${M}x${ABC[C:C+8]} = vld1q_f16(i${M}); i${M} += 8; 49 50 $for C in range(0, CHANNEL_TILE, 8): 51 const float16x8_t vi2x${ABC[C:C+8]} = vld1q_f16(i2); i2 += 8; 52 float16x8_t vacc${ABC[C:C+8]} = vaddq_f16(vi0x${ABC[C:C+8]}, vi1x${ABC[C:C+8]}); 53 54 $for M in range(2, ROW_TILE): 55 $for C in range(0, CHANNEL_TILE, 8): 56 $if M + 1 != ROW_TILE: 57 const float16x8_t vi${M+1}x${ABC[C:C+8]} = vld1q_f16(i${M+1}); i${M+1} += 8; 58 vacc${ABC[C:C+8]} = vaddq_f16(vacc${ABC[C:C+8]}, vi${M}x${ABC[C:C+8]}); 59 60 $for C in range(0, CHANNEL_TILE, 8): 61 vacc${ABC[C:C+8]} = vmulq_f16(vacc${ABC[C:C+8]}, vscale); 62 63 $for C in range(0, CHANNEL_TILE, 8): 64 vacc${ABC[C:C+8]} = vmaxq_f16(vacc${ABC[C:C+8]}, vmin); 65 66 $for C in range(0, CHANNEL_TILE, 8): 67 vacc${ABC[C:C+8]} = vminq_f16(vacc${ABC[C:C+8]}, vmax); 68 69 $for C in range(0, CHANNEL_TILE, 8): 70 vst1q_f16(output, vacc${ABC[C:C+8]}); output = (__fp16*) output + 8; 71 } 72 if XNN_UNLIKELY(channels != 0) { 73 ${"do " if CHANNEL_TILE > 8 else ""}{ 74 $for M in range(2): 75 const float16x8_t vi${M}x${ABC[0:8]} = vld1q_f16(i${M}); i${M} += 8; 76 77 const float16x8_t vi2x${ABC[0:8]} = vld1q_f16(i2); i2 += 8; 78 float16x8_t vacc${ABC[0:8]} = vaddq_f16(vi0x${ABC[0:8]}, vi1x${ABC[0:8]}); 79 80 $for M in range(2, ROW_TILE): 81 $if M + 1 != ROW_TILE: 82 const float16x8_t vi${M+1}x${ABC[0:8]} = vld1q_f16(i${M+1}); i${M+1} += 8; 83 vacc${ABC[0:8]} = vaddq_f16(vacc${ABC[0:8]}, vi${M}x${ABC[0:8]}); 84 85 vacc${ABC[0:8]} = vmulq_f16(vacc${ABC[0:8]}, vscale); 86 vacc${ABC[0:8]} = vmaxq_f16(vacc${ABC[0:8]}, vmin); 87 vacc${ABC[0:8]} = vminq_f16(vacc${ABC[0:8]}, vmax); 88 89 $if CHANNEL_TILE > 8: 90 if XNN_LIKELY(channels >= 8) { 91 vst1q_f16(output, vacc${ABC[0:8]}); output = (__fp16*) output + 8; 92 channels -= 8; 93 } else { 94 float16x4_t vacc${ABC[0:4]} = vget_low_f16(vacc${ABC[0:8]}); 95 if (channels & 4) { 96 vst1_f16(output, vacc${ABC[0:4]}); output = (__fp16*) output + 4; 97 vacc${ABC[0:4]} = vget_high_f16(vacc${ABC[0:8]}); 98 } 99 if (channels & 2) { 100 vst1_lane_u32(output, vreinterpret_u32_f16(vacc${ABC[0:4]}), 0); output = (__fp16*) output + 2; 101 vacc${ABC[0:4]} = vext_f16(vacc${ABC[0:4]}, vacc${ABC[0:4]}, 2); 102 } 103 if (channels & 1) { 104 vst1_lane_f16(output, vacc${ABC[0:4]}, 0); output = (__fp16*) output + 1; 105 } 106 channels = 0; 107 } 108 $else: 109 float16x4_t vacc${ABC[0:4]} = vget_low_f16(vacc${ABC[0:8]}); 110 if (channels & 4) { 111 vst1_f16(output, vacc${ABC[0:4]}); output = (__fp16*) output + 4; 112 vacc${ABC[0:4]} = vget_high_f16(vacc${ABC[0:8]}); 113 } 114 if (channels & 2) { 115 vst1_lane_u32(output, vreinterpret_u32_f16(vacc${ABC[0:4]}), 0); output = (__fp16*) output + 2; 116 vacc${ABC[0:4]} = vext_f16(vacc${ABC[0:4]}, vacc${ABC[0:4]}, 2); 117 } 118 if (channels & 1) { 119 vst1_lane_f16(output, vacc${ABC[0:4]}, 0); output = (__fp16*) output + 1; 120 } 121 }${" while (channels != 0);" if CHANNEL_TILE > 8 else ""} 122 } 123} 124