1// Copyright 2019 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert CHANNEL_TILE % 8 == 0 7$assert KERNEL_TILE >= 2 8$assert ACCUMULATORS >= 1 9$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 10#include <assert.h> 11 12#include <immintrin.h> 13 14#include <xnnpack/dwconv.h> 15 16 17$ISA = {0: "avx", 3: "fma3"}[FMA] 18void xnn_f32_dwconv_minmax_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__${ISA}${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}( 19 size_t channels, 20 size_t output_width, 21 const float** input, 22 const float* weights, 23 float* output, 24 size_t input_stride, 25 size_t output_increment, 26 size_t input_offset, 27 const float* zero, 28 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS 29{ 30 assert(channels != 0); 31 assert(output_width != 0); 32 33 const __m256 vmax = _mm256_load_ps(params->avx.max); 34 const __m256 vmin = _mm256_load_ps(params->avx.min); 35 do { 36 $for K in range(KERNEL_TILE): 37 const float* i${K} = input[${K}]; 38 assert(i${K} != NULL); 39 if XNN_UNPREDICTABLE(i${K} != zero) { 40 i${K} = (const float*) ((uintptr_t) i${K} + input_offset); 41 } 42 input = (const float**) ((uintptr_t) input + input_stride); 43 44 size_t c = channels; 45 const float* w = weights; 46 for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) { 47 __m256 vacc${ABC[0:8]}p0 = _mm256_load_ps(w); 48 $for C in range(8, CHANNEL_TILE, 8): 49 __m256 vacc${ABC[C:C+8]}p0 = _mm256_load_ps(w + ${C}); 50 51 $for K in range(KERNEL_TILE): 52 53 const __m256 vi${K}x${ABC[0:8]} = _mm256_loadu_ps(i${K}); 54 $for C in range(8, CHANNEL_TILE, 8): 55 const __m256 vi${K}x${ABC[C:C+8]} = _mm256_loadu_ps(i${K} + ${C}); 56 i${K} += ${CHANNEL_TILE}; 57 58 $for C in range(0, CHANNEL_TILE, 8): 59 const __m256 vk${K}x${ABC[C:C+8]} = _mm256_load_ps(w + ${(K + 1) * CHANNEL_TILE + C}); 60 $for C in range(0, CHANNEL_TILE, 8): 61 $if 1 <= K < ACCUMULATORS: 62 __m256 vacc${ABC[C:C+8]}p${K} = _mm256_mul_ps(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]}); 63 $elif FMA == 3: 64 vacc${ABC[C:C+8]}p${K % ACCUMULATORS} = _mm256_fmadd_ps(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]}, vacc${ABC[C:C+8]}p${K % ACCUMULATORS}); 65 $else: 66 vacc${ABC[C:C+8]}p${K % ACCUMULATORS} = _mm256_add_ps(vacc${ABC[C:C+8]}p${K % ACCUMULATORS}, _mm256_mul_ps(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]})); 67 68 w += ${(KERNEL_TILE + 1) * CHANNEL_TILE}; 69 70 $if ACCUMULATORS > 1: 71 // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0 72 $ACC_SLICE = 1 73 $while ACC_SLICE < ACCUMULATORS: 74 $for A in range(0, ACCUMULATORS, ACC_SLICE * 2): 75 $if A + ACC_SLICE < ACCUMULATORS: 76 $for C in range(0, CHANNEL_TILE, 8): 77 vacc${ABC[C:C+8]}p${A} = _mm256_add_ps(vacc${ABC[C:C+8]}p${A}, vacc${ABC[C:C+8]}p${A + ACC_SLICE}); 78 $ACC_SLICE *= 2 79 80 $for C in range(0, CHANNEL_TILE, 8): 81 __m256 vacc${ABC[C:C+8]} = _mm256_max_ps(vacc${ABC[C:C+8]}p0, vmin); 82 $for C in range(0, CHANNEL_TILE, 8): 83 vacc${ABC[C:C+8]} = _mm256_min_ps(vacc${ABC[C:C+8]}, vmax); 84 85 _mm256_storeu_ps(output, vacc${ABC[0:8]}); 86 $for C in range(8, CHANNEL_TILE, 8): 87 _mm256_storeu_ps(output + ${C}, vacc${ABC[C:C+8]}); 88 output += ${CHANNEL_TILE}; 89 } 90 $if CHANNEL_TILE > 8: 91 for (; c >= 8; c -= 8) { 92 __m256 vacc01234567p0 = _mm256_load_ps(w); 93 $for K in range(KERNEL_TILE): 94 95 const __m256 vi${K}x01234567 = _mm256_loadu_ps(i${K}); 96 i${K} += 8; 97 98 const __m256 vk${K}x01234567 = _mm256_load_ps(w + ${(K + 1) * CHANNEL_TILE}); 99 $if 1 <= K < ACCUMULATORS: 100 __m256 vacc01234567p${K} = _mm256_mul_ps(vi${K}x01234567, vk${K}x01234567); 101 $elif FMA == 3: 102 vacc01234567p${K % ACCUMULATORS} = _mm256_fmadd_ps(vi${K}x01234567, vk${K}x01234567, vacc01234567p${K % ACCUMULATORS}); 103 $else: 104 vacc01234567p${K % ACCUMULATORS} = _mm256_add_ps(vacc01234567p${K % ACCUMULATORS}, _mm256_mul_ps(vi${K}x01234567, vk${K}x01234567)); 105 106 w += 8; 107 108 $if ACCUMULATORS > 1: 109 // Add up all accumulators to vacc${ABC[0:8]}p0 110 $ACC_SLICE = 1 111 $while ACC_SLICE < ACCUMULATORS: 112 $for A in range(0, ACCUMULATORS, ACC_SLICE * 2): 113 $if A + ACC_SLICE < ACCUMULATORS: 114 vacc01234567p${A} = _mm256_add_ps(vacc01234567p${A}, vacc01234567p${A + ACC_SLICE}); 115 $ACC_SLICE *= 2 116 117 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin); 118 vacc01234567 = _mm256_min_ps(vacc01234567, vmax); 119 120 _mm256_storeu_ps(output, vacc01234567); 121 output += 8; 122 } 123 if XNN_UNLIKELY(c != 0) { 124 assert(c >= 1); 125 assert(c <= 7); 126 const __m256i vmask = _mm256_loadu_si256((const __m256i*) ¶ms->avx.mask_table[7 - c]); 127 128 __m256 vacc01234567p0 = _mm256_load_ps(w); 129 $for K in range(KERNEL_TILE): 130 131 const __m256 vi${K}x01234567 = _mm256_maskload_ps(i${K}, vmask); 132 const __m256 vk${K}x01234567 = _mm256_load_ps(w + ${(K + 1) * CHANNEL_TILE}); 133 $if 1 <= K < ACCUMULATORS: 134 __m256 vacc01234567p${K} = _mm256_mul_ps(vi${K}x01234567, vk${K}x01234567); 135 $elif FMA == 3: 136 vacc01234567p${K % ACCUMULATORS} = _mm256_fmadd_ps(vi${K}x01234567, vk${K}x01234567, vacc01234567p${K % ACCUMULATORS}); 137 $else: 138 vacc01234567p${K % ACCUMULATORS} = _mm256_add_ps(vacc01234567p${K % ACCUMULATORS}, _mm256_mul_ps(vi${K}x01234567, vk${K}x01234567)); 139 140 $if ACCUMULATORS > 1: 141 // Add up all accumulators to vacc${ABC[0:8]}p0 142 $ACC_SLICE = 1 143 $while ACC_SLICE < ACCUMULATORS: 144 $for A in range(0, ACCUMULATORS, ACC_SLICE * 2): 145 $if A + ACC_SLICE < ACCUMULATORS: 146 vacc01234567p${A} = _mm256_add_ps(vacc01234567p${A}, vacc01234567p${A + ACC_SLICE}); 147 $ACC_SLICE *= 2 148 149 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin); 150 vacc01234567 = _mm256_min_ps(vacc01234567, vmax); 151 152 __m128 vacc0123 = _mm256_castps256_ps128(vacc01234567); 153 if (c & 4) { 154 _mm_storeu_ps(output, vacc0123); 155 vacc0123 = _mm256_extractf128_ps(vacc01234567, 1); 156 output += 4; 157 } 158 if (c & 2) { 159 _mm_storel_pi((__m64*) output, vacc0123); 160 vacc0123 = _mm_movehl_ps(vacc0123, vacc0123); 161 output += 2; 162 } 163 if (c & 1) { 164 _mm_store_ss(output, vacc0123); 165 output += 1; 166 } 167 } 168 169 output = (float*) ((uintptr_t) output + output_increment); 170 } while (--output_width != 0); 171} 172