1// Copyright 2020 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert ROW_TILE >= 1 7$assert ACCUMULATORS >= 1 8#include <assert.h> 9 10#include <arm_neon.h> 11 12#include <xnnpack/dwconv.h> 13#include <xnnpack/math.h> 14 15 16void xnn_f16_dwconv2d_chw_ukernel_3x3p1__neonfp16arith_${ROW_TILE}x8${"_acc%d" % ACCUMULATORS if ACCUMULATORS > 1 else ""}( 17 size_t input_height, 18 size_t input_width, 19 const void* input, 20 const void* weights, 21 const void* zero, 22 void* output, 23 uint32_t padding_top, 24 const union xnn_f16_chw_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS 25{ 26 assert(input_height != 0); 27 assert(input_width != 0); 28 assert(input_width % sizeof(__fp16) == 0); 29 assert(padding_top == 1); 30 31 const uint16x8_t vmask = vld1q_u16(params->neonfp16arith.maskx8); 32 const float16x8_t vmax = vld1q_dup_f16(¶ms->neonfp16arith.max); 33 const float16x8_t vmin = vld1q_dup_f16(¶ms->neonfp16arith.min); 34 35 const __fp16* w0 = (const __fp16*)weights; 36 const float16x8_t vw01234567 = vld1q_f16(w0); 37 const float16x4_t vw89 = vreinterpret_f16_u32(vld1_lane_u32((const void*)(w0 + 8), vmov_n_u32(0), 0)); 38 39 const size_t input_decrement = round_up_po2(input_width, 8 * sizeof(__fp16)); 40 41 const __fp16* i0 = zero; 42 const __fp16* i1 = input; 43 $for M in range(2, 2 + ROW_TILE): 44 const __fp16* i${M} = (const __fp16*) ((uintptr_t) i${M-1} + input_width); 45 46 __fp16* o0 = output; 47 $for M in range(1, ROW_TILE): 48 __fp16* o${M} = (__fp16*) ((uintptr_t) o${M-1} + input_width); 49 50 size_t output_height = input_height; 51 do { 52 $for M in range(2, 2 + ROW_TILE): 53 if XNN_UNPREDICTABLE(output_height < ${M}) { 54 i${M} = zero; 55 $if M <= ROW_TILE: 56 o${M-1} = o${M-2}; 57 } 58 59 $for M in range(2 + ROW_TILE): 60 float16x8_t vi${M}x01234567 = vmovq_n_f16(0); 61 62 $for M in range(2 + ROW_TILE): 63 float16x8_t vi${M}x89ABCDEF = vld1q_f16(i${M}); i${M} += 8; 64 65 size_t w = input_width; 66 for (; w > 8 * sizeof(__fp16); w -= 8 * sizeof(__fp16)) { 67 $for M in range(ROW_TILE): 68 float16x8_t vo${M}p0 = vdupq_lane_f16(vget_low_f16(vw01234567), 0); 69 70 $for M in range(2 + ROW_TILE): 71 const float16x8_t vi${M}xGHIJKLMN = vld1q_f16(i${M}); i${M} += 8; 72 73 // Center column 74 $for M in range(ROW_TILE): 75 vo${M}p0 = vfmaq_lane_f16(vo${M}p0, vi${M}x89ABCDEF, vget_low_f16(vw01234567), 2); 76 77 $for M in range(ROW_TILE): 78 $if ACCUMULATORS >= 2: 79 float16x8_t vo${M}p1 = vmulq_lane_f16(vi${M+1}x89ABCDEF, vget_high_f16(vw01234567), 1); 80 $else: 81 vo${M}p0 = vfmaq_lane_f16(vo${M}p0, vi${M+1}x89ABCDEF, vget_high_f16(vw01234567), 1); 82 83 $for M in range(ROW_TILE): 84 $if ACCUMULATORS >= 3: 85 float16x8_t vo${M}p2 = vmulq_lane_f16(vi${M+2}x89ABCDEF, vw89, 0); 86 $else: 87 vo${M}p0 = vfmaq_lane_f16(vo${M}p0, vi${M+2}x89ABCDEF, vw89, 0); 88 89 // Left column 90 $for M in range(2 + ROW_TILE): 91 const float16x8_t vi${M}x789ABCDE = vextq_f16(vi${M}x01234567, vi${M}x89ABCDEF, 7); 92 93 $for M in range(ROW_TILE): 94 $if ACCUMULATORS >= 4: 95 float16x8_t vo${M}p3 = vmulq_lane_f16(vi${M}x789ABCDE, vget_low_f16(vw01234567), 1); 96 $else: 97 vo${M}p${3 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${3 % ACCUMULATORS}, vi${M}x789ABCDE, vget_low_f16(vw01234567), 1); 98 99 $for M in range(ROW_TILE): 100 vo${M}p${4 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${4 % ACCUMULATORS}, vi${M+1}x789ABCDE, vget_high_f16(vw01234567), 0); 101 102 $for M in range(ROW_TILE): 103 vo${M}p${5 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${5 % ACCUMULATORS}, vi${M+2}x789ABCDE, vget_high_f16(vw01234567), 3); 104 105 $for M in range(2 + ROW_TILE): 106 vi${M}x01234567 = vi${M}x89ABCDEF; 107 108 // Right column 109 $for M in range(2 + ROW_TILE): 110 const float16x8_t vi${M}x9ABCDEFG = vextq_f16(vi${M}x89ABCDEF, vi${M}xGHIJKLMN, 1); 111 112 $for M in range(ROW_TILE): 113 vo${M}p${6 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${6 % ACCUMULATORS}, vi${M}x9ABCDEFG, vget_low_f16(vw01234567), 3); 114 115 $for M in range(ROW_TILE): 116 vo${M}p${7 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${7 % ACCUMULATORS}, vi${M+1}x9ABCDEFG, vget_high_f16(vw01234567), 2); 117 118 $for M in range(ROW_TILE): 119 vo${M}p${8 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${8 % ACCUMULATORS}, vi${M+2}x9ABCDEFG, vw89, 1); 120 121 $for M in range(2 + ROW_TILE): 122 vi${M}x89ABCDEF = vi${M}xGHIJKLMN; 123 124 $if ACCUMULATORS > 1: 125 $ACC_SLICE = 1 126 $while ACC_SLICE < ACCUMULATORS: 127 $for A in range(0, ACCUMULATORS, ACC_SLICE * 2): 128 $if A + ACC_SLICE < ACCUMULATORS: 129 $for M in range(ROW_TILE): 130 vo${M}p${A} = vaddq_f16(vo${M}p${A}, vo${M}p${A + ACC_SLICE}); 131 $ACC_SLICE *= 2 132 133 $for M in range(ROW_TILE): 134 float16x8_t vo${M} = vmaxq_f16(vo${M}p0, vmin); 135 136 $for M in range(ROW_TILE): 137 vo${M} = vminq_f16(vo${M}, vmax); 138 139 $for M in reversed(range(ROW_TILE)): 140 vst1q_f16(o${M}, vo${M}); o${M} += 8; 141 } 142 // Always process the last block of 1..8 pixels. 143 assert(w >= 1 * sizeof(__fp16)); 144 assert(w <= 8 * sizeof(__fp16)); 145 { 146 $for M in range(ROW_TILE): 147 float16x8_t vo${M}p0 = vdupq_lane_f16(vget_low_f16(vw01234567), 0); 148 149 $for M in range(2 + ROW_TILE): 150 vi${M}x89ABCDEF = vreinterpretq_f16_u16(vandq_u16(vmask, vreinterpretq_u16_f16(vi${M}x89ABCDEF))); 151 152 // Center column 153 $for M in range(ROW_TILE): 154 vo${M}p0 = vfmaq_lane_f16(vo${M}p0, vi${M}x89ABCDEF, vget_low_f16(vw01234567), 2); 155 156 $for M in range(ROW_TILE): 157 $if ACCUMULATORS >= 2: 158 float16x8_t vo${M}p1 = vmulq_lane_f16(vi${M+1}x89ABCDEF, vget_high_f16(vw01234567), 1); 159 $else: 160 vo${M}p0 = vfmaq_lane_f16(vo${M}p0, vi${M+1}x89ABCDEF, vget_high_f16(vw01234567), 1); 161 162 $for M in range(ROW_TILE): 163 $if ACCUMULATORS >= 3: 164 float16x8_t vo${M}p2 = vmulq_lane_f16(vi${M+2}x89ABCDEF, vw89, 0); 165 $else: 166 vo${M}p0 = vfmaq_lane_f16(vo${M}p0, vi${M+2}x89ABCDEF, vw89, 0); 167 168 // Left column 169 $for M in range(2 + ROW_TILE): 170 const float16x8_t vi${M}x789ABCDE = vextq_f16(vi${M}x01234567, vi${M}x89ABCDEF, 7); 171 172 $for M in range(ROW_TILE): 173 $if ACCUMULATORS >= 4: 174 float16x8_t vo${M}p3 = vmulq_lane_f16(vi${M}x789ABCDE, vget_low_f16(vw01234567), 1); 175 $else: 176 vo${M}p${3 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${3 % ACCUMULATORS}, vi${M}x789ABCDE, vget_low_f16(vw01234567), 1); 177 178 $for M in range(ROW_TILE): 179 vo${M}p${4 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${4 % ACCUMULATORS}, vi${M+1}x789ABCDE, vget_high_f16(vw01234567), 0); 180 181 $for M in range(ROW_TILE): 182 vo${M}p${5 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${5 % ACCUMULATORS}, vi${M+2}x789ABCDE, vget_high_f16(vw01234567), 3); 183 184 // Right column 185 const float16x8_t vzero = vmovq_n_f16(0); 186 $for M in range(2 + ROW_TILE): 187 const float16x8_t vi${M}x9ABCDEFG = vextq_f16(vi${M}x89ABCDEF, vzero, 1); 188 189 $for M in range(ROW_TILE): 190 vo${M}p${6 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${6 % ACCUMULATORS}, vi${M}x9ABCDEFG, vget_low_f16(vw01234567), 3); 191 192 $for M in range(ROW_TILE): 193 vo${M}p${7 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${7 % ACCUMULATORS}, vi${M+1}x9ABCDEFG, vget_high_f16(vw01234567), 2); 194 195 $for M in range(ROW_TILE): 196 vo${M}p${8 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${8 % ACCUMULATORS}, vi${M+2}x9ABCDEFG, vw89, 1); 197 198 $if ACCUMULATORS > 1: 199 $ACC_SLICE = 1 200 $while ACC_SLICE < ACCUMULATORS: 201 $for A in range(0, ACCUMULATORS, ACC_SLICE * 2): 202 $if A + ACC_SLICE < ACCUMULATORS: 203 $for M in range(ROW_TILE): 204 vo${M}p${A} = vaddq_f16(vo${M}p${A}, vo${M}p${A + ACC_SLICE}); 205 $ACC_SLICE *= 2 206 207 $for M in range(ROW_TILE): 208 float16x8_t vo${M} = vmaxq_f16(vo${M}p0, vmin); 209 210 $for M in range(ROW_TILE): 211 vo${M} = vminq_f16(vo${M}, vmax); 212 213 if XNN_LIKELY(w == 8 * sizeof(__fp16)) { 214 $for M in reversed(range(ROW_TILE)): 215 vst1q_f16(o${M}, vo${M}); o${M} += 8; 216 } else { 217 $for M in reversed(range(ROW_TILE)): 218 float16x4_t vo${M}_lo = vget_low_f16(vo${M}); 219 220 if (w & (4 * sizeof(__fp16))) { 221 $for M in reversed(range(ROW_TILE)): 222 vst1_f16(o${M}, vo${M}_lo); o${M} += 4; 223 224 $for M in reversed(range(ROW_TILE)): 225 vo${M}_lo = vget_high_f16(vo${M}); 226 } 227 if (w & (2 * sizeof(__fp16))) { 228 $for M in reversed(range(ROW_TILE)): 229 vst1_lane_u32((void*) o${M}, vreinterpret_u32_f16(vo${M}_lo), 0); o${M} += 2; 230 231 $for M in range(ROW_TILE): 232 vo${M}_lo = vext_f16(vo${M}_lo, vo${M}_lo, 2); 233 } 234 if (w & (1 * sizeof(__fp16))) { 235 $for M in reversed(range(ROW_TILE)): 236 vst1_lane_f16(o${M}, vo${M}_lo, 0); o${M} += 1; 237 } 238 } 239 } 240 241 i0 = (const __fp16*) ((uintptr_t) i${ROW_TILE} - input_decrement); 242 i1 = (const __fp16*) ((uintptr_t) i${ROW_TILE+1} - input_decrement); 243 $for M in range(2, 2 + ROW_TILE): 244 i${M} = (const __fp16*) ((uintptr_t) i${M-1} + input_width); 245 246 $if ROW_TILE > 1: 247 o0 = o${ROW_TILE - 1}; 248 $for M in range(1, ROW_TILE): 249 o${M} = (__fp16*) ((uintptr_t) o${M-1} + input_width); 250 251 $if ROW_TILE > 1: 252 output_height = doz(output_height, ${ROW_TILE}); 253 } while (${"--" if ROW_TILE == 1 else ""}output_height != 0); 254} 255