1// Copyright 2022 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert CHANNEL_TILE % 8 == 0 7$assert CHANNEL_TILE >= 8 8$assert PIXEL_TILE == 1 9$ABC = "456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 10#include <assert.h> 11 12#include <immintrin.h> 13 14#include <xnnpack/common.h> 15#include <xnnpack/ibilinear.h> 16#include <xnnpack/intrinsics-polyfill.h> 17 18 19void xnn_f16_ibilinear_ukernel__fma3_c${CHANNEL_TILE}${"" if PIXEL_TILE == 1 else "x%d" % PIXEL_TILE}( 20 size_t output_pixels, 21 size_t channels, 22 const void**restrict input, 23 size_t input_offset, 24 const void*restrict weights, 25 void*restrict output, 26 size_t output_increment) XNN_OOB_READS 27{ 28 assert(output_pixels != 0); 29 assert(channels != 0); 30 assert(channels % sizeof(uint16_t) == 0); 31 32 uint16_t* o = (uint16_t*) output; 33 do { 34 const uint16_t* i0 = (const uint16_t*) ((uintptr_t) input[0] + input_offset); 35 const uint16_t* i1 = (const uint16_t*) ((uintptr_t) input[1] + input_offset); 36 const uint16_t* i2 = (const uint16_t*) ((uintptr_t) input[2] + input_offset); 37 const uint16_t* i3 = (const uint16_t*) ((uintptr_t) input[3] + input_offset); 38 input += 4; 39 40 const __m256 valphahv = _mm256_cvtph_ps(_mm_castps_si128(_mm_broadcast_ss(weights))); 41 const __m256 valphah = _mm256_permute_ps(valphahv, _MM_SHUFFLE(2, 0, 2, 0)); 42 const __m256 valphav = _mm256_permute_ps(valphahv, _MM_SHUFFLE(3, 1, 3, 1)); 43 weights = (const uint16_t*) weights + 2; 44 45 size_t c = channels; 46 $if CHANNEL_TILE > 8: 47 for (; c >= ${CHANNEL_TILE} * sizeof(uint16_t); c -= ${CHANNEL_TILE} * sizeof(uint16_t)) { 48 const __m256 vtl${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0)); 49 const __m256 vtr${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1)); 50 const __m256 vbl${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2)); 51 const __m256 vbr${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3)); 52 $for C in range(8, CHANNEL_TILE, 8): 53 const __m256 vtl${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i0 + ${C}))); 54 const __m256 vtr${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i1 + ${C}))); 55 const __m256 vbl${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i2 + ${C}))); 56 const __m256 vbr${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i3 + ${C}))); 57 i0 += ${CHANNEL_TILE}; 58 i1 += ${CHANNEL_TILE}; 59 i2 += ${CHANNEL_TILE}; 60 i3 += ${CHANNEL_TILE}; 61 62 $for C in range(0, CHANNEL_TILE, 8): 63 const __m256 vtd${ABC[C:C+8]} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_sub_ps(vtr${ABC[C:C+8]}, vtl${ABC[C:C+8]}), _MM_FROUND_NO_EXC)); 64 const __m256 vbd${ABC[C:C+8]} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_sub_ps(vbr${ABC[C:C+8]}, vbl${ABC[C:C+8]}), _MM_FROUND_NO_EXC)); 65 66 $for C in range(0, CHANNEL_TILE, 8): 67 const __m256 vt${ABC[C:C+8]} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vtd${ABC[C:C+8]}, valphah, vtl${ABC[C:C+8]}), _MM_FROUND_NO_EXC)); 68 const __m256 vb${ABC[C:C+8]} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vbd${ABC[C:C+8]}, valphah, vbl${ABC[C:C+8]}), _MM_FROUND_NO_EXC)); 69 70 $for C in range(0, CHANNEL_TILE, 8): 71 const __m256 vd${ABC[C:C+8]} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_sub_ps(vb${ABC[C:C+8]}, vt${ABC[C:C+8]}), _MM_FROUND_NO_EXC)); 72 73 $for C in range(0, CHANNEL_TILE, 8): 74 const __m128i vo${ABC[C:C+8]} = _mm256_cvtps_ph(_mm256_fmadd_ps(vd${ABC[C:C+8]}, valphav, vt${ABC[C:C+8]}), _MM_FROUND_NO_EXC); 75 76 _mm_storeu_si128((__m128i*) o, vo${ABC[0:8]}); 77 $for C in range(8, CHANNEL_TILE, 8): 78 _mm_storeu_si128((__m128i*) (o + ${C}), vo${ABC[C:C+8]}); 79 o += ${CHANNEL_TILE}; 80 } 81 for (; c >= 8 * sizeof(uint16_t); c -= 8 * sizeof(uint16_t)) { 82 const __m256 vtl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0)); 83 i0 += 8; 84 const __m256 vtr = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1)); 85 i1 += 8; 86 const __m256 vbl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2)); 87 i2 += 8; 88 const __m256 vbr = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3)); 89 i3 += 8; 90 91 const __m256 vtd = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_sub_ps(vtr, vtl), _MM_FROUND_NO_EXC)); 92 const __m256 vbd = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_sub_ps(vbr, vbl), _MM_FROUND_NO_EXC)); 93 94 const __m256 vt = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vtd, valphah, vtl), _MM_FROUND_NO_EXC)); 95 const __m256 vb = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vbd, valphah, vbl), _MM_FROUND_NO_EXC)); 96 97 const __m256 vd = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_sub_ps(vb, vt), _MM_FROUND_NO_EXC)); 98 99 const __m128i vo = _mm256_cvtps_ph(_mm256_fmadd_ps(vd, valphav, vt), _MM_FROUND_NO_EXC); 100 101 _mm_storeu_si128((__m128i*) o, vo); 102 o += 8; 103 } 104 if XNN_UNLIKELY(c != 0) { 105 const __m256 vtl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0)); 106 i0 += 8; 107 const __m256 vtr = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1)); 108 i1 += 8; 109 const __m256 vbl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2)); 110 i2 += 8; 111 const __m256 vbr = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3)); 112 i3 += 8; 113 114 const __m256 vtd = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_sub_ps(vtr, vtl), _MM_FROUND_NO_EXC)); 115 const __m256 vbd = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_sub_ps(vbr, vbl), _MM_FROUND_NO_EXC)); 116 117 const __m256 vt = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vtd, valphah, vtl), _MM_FROUND_NO_EXC)); 118 const __m256 vb = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vbd, valphah, vbl), _MM_FROUND_NO_EXC)); 119 120 const __m256 vd = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_sub_ps(vb, vt), _MM_FROUND_NO_EXC)); 121 122 __m128i vo = _mm256_cvtps_ph(_mm256_fmadd_ps(vd, valphav, vt), _MM_FROUND_NO_EXC); 123 if (c & (4 * sizeof(uint16_t))) { 124 _mm_storel_epi64((__m128i*) o, vo); 125 vo = _mm_unpackhi_epi64(vo, vo); 126 o += 4; 127 } 128 if (c & (2 * sizeof(uint16_t))) { 129 _mm_storeu_si32(o, vo); 130 vo = _mm_srli_epi64(vo, 32); 131 o += 2; 132 } 133 if (c & (1 * sizeof(uint16_t))) { 134 *o = (uint16_t) _mm_extract_epi16(vo, 0); 135 o += 1; 136 } 137 } 138 139 o = (uint16_t*) ((uintptr_t) o + output_increment); 140 } while (--output_pixels != 0); 141} 142