1// Copyright 2022 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert PIXEL_TILE >= 1 7$assert PIXEL_TILE % 4 == 0 8$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 9#include <assert.h> 10 11#include <arm_neon.h> 12 13#include <xnnpack/ibilinear.h> 14 15 16void xnn_f16_ibilinear_chw_ukernel__neonfp16arith_p${PIXEL_TILE}( 17 size_t output_pixels, 18 size_t channels, 19 const void**restrict input, 20 size_t input_offset, 21 const void*restrict weights, 22 void*restrict output, 23 size_t input_increment) XNN_OOB_READS 24{ 25 assert(output_pixels != 0); 26 assert(channels != 0); 27 assert(input_increment % sizeof(__fp16) == 0); 28 29 __fp16* o = (__fp16*) output; 30 do { 31 const __fp16** i = (const __fp16**)input; 32 const __fp16* w = weights; 33 size_t p = output_pixels; 34 35 $if PIXEL_TILE > 4: 36 for (; p >= ${PIXEL_TILE}; p -= ${PIXEL_TILE}) { 37 $for P in range(PIXEL_TILE): 38 const __fp16* itl${ABC[P]} = (const __fp16*) ((uintptr_t) i[${2 * P}] + input_offset); 39 const __fp16* ibl${ABC[P]} = (const __fp16*) ((uintptr_t) i[${2 * P + 1}] + input_offset); 40 i += 2 * ${PIXEL_TILE}; 41 42 $for P in range(0, PIXEL_TILE, 4): 43 const float16x4x2_t vw${ABC[P:P+4]} = vld2_f16(w + ${2 * P}); 44 w += 2 * ${PIXEL_TILE}; 45 46 $for P in range(0, PIXEL_TILE, 4): 47 float16x8_t vtltr${ABC[P:P+4]} = vmovq_n_f16(0); // vmov for uninitialized var warning 48 float16x8_t vblbr${ABC[P:P+4]} = vmovq_n_f16(0); 49 $for L in range(0, 4): 50 vtltr${ABC[P:P+4]} = vreinterpretq_f16_u32(vld1q_lane_u32((const void*) itl${ABC[P+L]}, vreinterpretq_u32_f16(vtltr${ABC[P:P+4]}), ${L})); 51 vblbr${ABC[P:P+4]} = vreinterpretq_f16_u32(vld1q_lane_u32((const void*) ibl${ABC[P+L]}, vreinterpretq_u32_f16(vblbr${ABC[P:P+4]}), ${L})); 52 53 $for P in range(0, PIXEL_TILE, 8): 54 const float16x8_t valphah${ABC[P:P+8]} = vcombine_f16(vw${ABC[P:P+4]}.val[0], vw${ABC[P+4:P+8]}.val[0]); 55 const float16x8_t valphav${ABC[P:P+8]} = vcombine_f16(vw${ABC[P:P+4]}.val[1], vw${ABC[P+4:P+8]}.val[1]); 56 57 $for P in range(0, PIXEL_TILE, 4): 58 const float16x8_t vldrd${ABC[P:P+4]} = vsubq_f16(vblbr${ABC[P:P+4]}, vtltr${ABC[P:P+4]}); 59 60 $for P in range(0, PIXEL_TILE, 8): 61 const float16x8x2_t vld_t${ABC[P:P+8]} = vuzpq_f16(vldrd${ABC[P:P+4]}, vldrd${ABC[P+4:P+8]}); 62 const float16x8_t vld${ABC[P:P+8]} = vld_t${ABC[P:P+8]}.val[0]; 63 const float16x8_t vrd${ABC[P:P+8]} = vld_t${ABC[P:P+8]}.val[1]; 64 65 $for P in range(0, PIXEL_TILE, 8): 66 const float16x8x2_t vtl_t${ABC[P:P+8]} = vuzpq_f16(vtltr${ABC[P:P+4]}, vtltr${ABC[P+4:P+8]}); 67 const float16x8_t vtl${ABC[P:P+8]} = vtl_t${ABC[P:P+8]}.val[0]; 68 const float16x8_t vtr${ABC[P:P+8]} = vtl_t${ABC[P:P+8]}.val[1]; 69 70 $for P in range(0, PIXEL_TILE, 8): 71 const float16x8_t vl${ABC[P:P+8]} = vfmaq_f16(vtl${ABC[P:P+8]}, vld${ABC[P:P+8]}, valphav${ABC[P:P+8]}); 72 const float16x8_t vr${ABC[P:P+8]} = vfmaq_f16(vtr${ABC[P:P+8]}, vrd${ABC[P:P+8]}, valphav${ABC[P:P+8]}); 73 74 $for P in range(0, PIXEL_TILE, 8): 75 const float16x8_t vd${ABC[P:P+8]} = vsubq_f16(vr${ABC[P:P+8]}, vl${ABC[P:P+8]}); 76 $for P in range(0, PIXEL_TILE, 8): 77 const float16x8_t vo${ABC[P:P+8]} = vfmaq_f16(vl${ABC[P:P+8]}, vd${ABC[P:P+8]}, valphah${ABC[P:P+8]}); 78 79 $for P in range(0, PIXEL_TILE, 8): 80 vst1q_f16(o + ${P}, vo${ABC[P:P+8]}); 81 o += ${PIXEL_TILE}; 82 } 83 84 for (; p >= 4; p -= 4) { 85 $for P in range(4): 86 const __fp16* itl${ABC[P]} = (const __fp16*) ((uintptr_t) i[${2 * P}] + input_offset); 87 const __fp16* ibl${ABC[P]} = (const __fp16*) ((uintptr_t) i[${2 * P + 1}] + input_offset); 88 i += 8; 89 90 const float16x4x2_t vw = vld2_f16(w); 91 w += 8; 92 93 float16x8_t vtltr = vmovq_n_f16(0); // vmov for uninitialized var warning 94 float16x8_t vblbr = vmovq_n_f16(0); 95 $for P in range(0, 4): 96 vtltr = vreinterpretq_f16_u32(vld1q_lane_u32((const void*) itl${ABC[P]}, vreinterpretq_u32_f16(vtltr), ${P})); 97 vblbr = vreinterpretq_f16_u32(vld1q_lane_u32((const void*) ibl${ABC[P]}, vreinterpretq_u32_f16(vblbr), ${P})); 98 99 const float16x4_t valphah = vw.val[0]; 100 const float16x4_t valphav = vw.val[1]; 101 102 const float16x8_t vldrd = vsubq_f16(vblbr, vtltr); 103 104 const float16x4x2_t vld_t = vuzp_f16(vget_low_f16(vldrd), vget_high_f16(vldrd)); 105 const float16x4_t vld = vld_t.val[0]; 106 const float16x4_t vrd = vld_t.val[1]; 107 108 const float16x4x2_t vtl_t = vuzp_f16(vget_low_f16(vtltr), vget_high_f16(vtltr)); 109 const float16x4_t vtl = vtl_t.val[0]; 110 const float16x4_t vtr = vtl_t.val[1]; 111 112 const float16x4_t vl = vfma_f16(vtl, vld, valphav); 113 const float16x4_t vr = vfma_f16(vtr, vrd, valphav); 114 115 const float16x4_t vd = vsub_f16(vr, vl); 116 const float16x4_t vo = vfma_f16(vl, vd, valphah); 117 118 vst1_f16(o, vo); 119 o += 4; 120 } 121 122 if XNN_UNLIKELY(p != 0) { 123 if (p & 2) { 124 $for P in range(2): 125 const __fp16* itl${ABC[P]} = (const __fp16*) ((uintptr_t) i[${2 * P}] + input_offset); 126 const __fp16* ibl${ABC[P]} = (const __fp16*) ((uintptr_t) i[${2 * P + 1}] + input_offset); 127 i += 4; 128 129 const float16x4_t vw = vld1_f16(w); 130 w += 4; 131 132 const float16x4x2_t vwhv = vuzp_f16(vw, vw); 133 const float16x4_t valphah = vwhv.val[0]; 134 const float16x4_t valphav = vwhv.val[1]; 135 136 float16x4_t vtltr = vmov_n_f16(0); // vmov for uninitialized var warning 137 float16x4_t vblbr = vmov_n_f16(0); 138 139 $for P in range(0, 2): 140 vtltr = vreinterpret_f16_u32(vld1_lane_u32((const void*) itl${ABC[P]}, vreinterpret_u32_f16(vtltr), ${P})); 141 vblbr = vreinterpret_f16_u32(vld1_lane_u32((const void*) ibl${ABC[P]}, vreinterpret_u32_f16(vblbr), ${P})); 142 143 const float16x4_t vldrd = vsub_f16(vblbr, vtltr); 144 145 const float16x4x2_t vld_t = vuzp_f16(vldrd, vldrd); 146 const float16x4_t vld = vld_t.val[0]; 147 const float16x4_t vrd = vld_t.val[1]; 148 149 const float16x4x2_t vtl_t = vuzp_f16(vtltr, vtltr); 150 const float16x4_t vtl = vtl_t.val[0]; 151 const float16x4_t vtr = vtl_t.val[1]; 152 153 const float16x4_t vl = vfma_f16(vtl, vld, valphav); 154 const float16x4_t vr = vfma_f16(vtr, vrd, valphav); 155 156 const float16x4_t vd = vsub_f16(vr, vl); 157 const float16x4_t vo = vfma_f16(vl, vd, valphah); 158 159 vst1_lane_u32((void*) o, vreinterpret_u32_f16(vo), 0); 160 o += 2; 161 } 162 163 if (p & 1) { 164 // We are computing the following formula: 165 // result = (1 - alpha_h) * (1 - alpha_v) * top_left + 166 // alpha_h * (1 - alpha_v) * top_right + 167 // (1 - alpha_h) * alpha_v * bottom_left + 168 // alpha_h * alpha_v * bottom_right. 169 // 170 // Rearranging gives 171 // result = left + alpha_h * (right - left), 172 // where 173 // left = top_left + alpha_v * (bottom_left - top_left), 174 // right = top_right + alpha_v * (bottom_right - top_right). 175 176 const __fp16* itl = (const __fp16*) ((uintptr_t) i[0] + input_offset); 177 const __fp16* ibl = (const __fp16*) ((uintptr_t) i[1] + input_offset); 178 i += 2; 179 180 float16x4_t vw = vmov_n_f16(0); 181 vw = vreinterpret_f16_u32(vld1_lane_u32((const void*) w, vreinterpret_u32_f16(vw), 0)); 182 w += 2; 183 184 const float16x4x2_t vwhv = vuzp_f16(vw, vw); 185 const float16x4_t valphah = vwhv.val[0]; 186 const float16x4_t valphav = vwhv.val[1]; 187 188 float16x4_t vtltr = vmov_n_f16(0); // vmov for uninitialized var warning 189 float16x4_t vblbr = vmov_n_f16(0); 190 191 vtltr = vreinterpret_f16_u32(vld1_lane_u32((const void*) itl, vreinterpret_u32_f16(vtltr), 0)); 192 vblbr = vreinterpret_f16_u32(vld1_lane_u32((const void*) ibl, vreinterpret_u32_f16(vblbr), 0)); 193 194 const float16x4_t vldrd = vsub_f16(vblbr, vtltr); 195 196 const float16x4x2_t vld_t = vuzp_f16(vldrd, vldrd); 197 const float16x4_t vld = vld_t.val[0]; 198 const float16x4_t vrd = vld_t.val[1]; 199 200 const float16x4x2_t vtl_t = vuzp_f16(vtltr, vtltr); 201 const float16x4_t vtl = vtl_t.val[0]; 202 const float16x4_t vtr = vtl_t.val[1]; 203 204 const float16x4_t vl = vfma_f16(vtl, vld, valphav); 205 const float16x4_t vr = vfma_f16(vtr, vrd, valphav); 206 207 const float16x4_t vd = vsub_f16(vr, vl); 208 const float16x4_t vo = vfma_f16(vl, vd, valphah); 209 210 vst1_lane_f16(o, vo, 0); 211 o += 1; 212 } 213 } 214 215 input_offset += input_increment; 216 } while (--channels != 0); 217} 218