xref: /aosp_15_r20/external/XNNPACK/src/f16-dwconv/up-fma3.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker// Copyright 2019 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker//
3*4bdc9457SAndroid Build Coastguard Worker// This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
6*4bdc9457SAndroid Build Coastguard Worker$assert CHANNEL_TILE % 8 == 0
7*4bdc9457SAndroid Build Coastguard Worker$assert KERNEL_TILE >= 2
8*4bdc9457SAndroid Build Coastguard Worker$assert ACCUMULATORS >= 1
9*4bdc9457SAndroid Build Coastguard Worker$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
10*4bdc9457SAndroid Build Coastguard Worker#include <assert.h>
11*4bdc9457SAndroid Build Coastguard Worker
12*4bdc9457SAndroid Build Coastguard Worker#include <immintrin.h>
13*4bdc9457SAndroid Build Coastguard Worker
14*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/dwconv.h>
15*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/intrinsics-polyfill.h>
16*4bdc9457SAndroid Build Coastguard Worker
17*4bdc9457SAndroid Build Coastguard Worker
// Template for a half-precision (f16) DWCONV microkernel with min/max clamping.
//
// For every output pixel the kernel reads KERNEL_TILE input row pointers and,
// for each channel, computes
//   clamp(w_seed[c] + sum over k of input_k[c] * w_k[c], min, max)
// where w_seed is the first CHANNEL_TILE-wide slice of a weight group
// (presumably the per-channel bias -- confirm against the weight packing code).
// Arithmetic is done in fp32 lanes, but every multiply/FMA/add result is
// converted to f16 and back, so intermediates carry f16 rounding.
//
// Parameters:
//   channels         - number of channels per output pixel; must be non-zero.
//   output_width     - number of output pixels to produce; must be non-zero.
//   input            - array of row pointers; KERNEL_TILE entries are read per
//                      output pixel, then the array pointer is advanced by
//                      input_stride bytes.
//   weights          - packed f16 weights, read through uint16_t pointers with
//                      aligned 128-bit loads (so 16-byte alignment is needed).
//   output           - destination for f16 results, written as uint16_t.
//   input_stride     - byte step applied to `input` after each output pixel.
//   output_increment - byte step added to the output pointer after the
//                      channels of one output pixel have been written.
//   input_offset     - byte offset added to every row pointer that is not
//                      equal to `zero`.
//   zero             - sentinel row pointer; rows equal to it are used as-is
//                      (presumably a shared zero-filled padding row -- confirm).
//   params           - clamp bounds; avx.min / avx.max are each read with an
//                      aligned 256-bit load (8 floats, 32-byte alignment).
//
// XNN_OOB_READS: the 1..7 channel tail below always loads a full 8 elements
// from each input row, i.e. it may read past the last valid channel.
18*4bdc9457SAndroid Build Coastguard Workervoid xnn_f16_dwconv_minmax_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__fma3${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}(
19*4bdc9457SAndroid Build Coastguard Worker    size_t channels,
20*4bdc9457SAndroid Build Coastguard Worker    size_t output_width,
21*4bdc9457SAndroid Build Coastguard Worker    const void** input,
22*4bdc9457SAndroid Build Coastguard Worker    const void* weights,
23*4bdc9457SAndroid Build Coastguard Worker    void* output,
24*4bdc9457SAndroid Build Coastguard Worker    size_t input_stride,
25*4bdc9457SAndroid Build Coastguard Worker    size_t output_increment,
26*4bdc9457SAndroid Build Coastguard Worker    size_t input_offset,
27*4bdc9457SAndroid Build Coastguard Worker    const void* zero,
28*4bdc9457SAndroid Build Coastguard Worker    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
29*4bdc9457SAndroid Build Coastguard Worker{
30*4bdc9457SAndroid Build Coastguard Worker  assert(channels != 0);
31*4bdc9457SAndroid Build Coastguard Worker  assert(output_width != 0);
32*4bdc9457SAndroid Build Coastguard Worker
  // Clamp bounds, already broadcast to 8 fp32 lanes in the params struct.
33*4bdc9457SAndroid Build Coastguard Worker  const __m256 vmax = _mm256_load_ps(params->avx.max);
34*4bdc9457SAndroid Build Coastguard Worker  const __m256 vmin = _mm256_load_ps(params->avx.min);
35*4bdc9457SAndroid Build Coastguard Worker
36*4bdc9457SAndroid Build Coastguard Worker  uint16_t* o = (uint16_t*) output;
  // One iteration of this loop produces all channels of one output pixel.
37*4bdc9457SAndroid Build Coastguard Worker  do {
    // Fetch the KERNEL_TILE row pointers for this pixel; every row except the
    // `zero` sentinel is shifted by input_offset bytes.
38*4bdc9457SAndroid Build Coastguard Worker    $for K in range(KERNEL_TILE):
39*4bdc9457SAndroid Build Coastguard Worker      const uint16_t* i${K} = input[${K}];
40*4bdc9457SAndroid Build Coastguard Worker      assert(i${K} != NULL);
41*4bdc9457SAndroid Build Coastguard Worker      if XNN_UNPREDICTABLE(i${K} != zero) {
42*4bdc9457SAndroid Build Coastguard Worker        i${K} = (const uint16_t*) ((uintptr_t) i${K} + input_offset);
43*4bdc9457SAndroid Build Coastguard Worker      }
44*4bdc9457SAndroid Build Coastguard Worker    input = (const void**) ((uintptr_t) input + input_stride);
45*4bdc9457SAndroid Build Coastguard Worker
    // The weight pointer restarts from the beginning for every output pixel.
46*4bdc9457SAndroid Build Coastguard Worker    size_t c = channels;
47*4bdc9457SAndroid Build Coastguard Worker    const uint16_t* w = weights;
    // Main loop: full tiles of CHANNEL_TILE channels. A weight group holds
    // (KERNEL_TILE + 1) slices of CHANNEL_TILE values: slice 0 seeds the
    // accumulators, slice K + 1 holds the weights of tap K.
48*4bdc9457SAndroid Build Coastguard Worker    for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
49*4bdc9457SAndroid Build Coastguard Worker      __m256 vacc${ABC[0:8]}p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
50*4bdc9457SAndroid Build Coastguard Worker      $for C in range(8, CHANNEL_TILE, 8):
51*4bdc9457SAndroid Build Coastguard Worker        __m256 vacc${ABC[C:C+8]}p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + ${C})));
52*4bdc9457SAndroid Build Coastguard Worker
      // Per tap: widen inputs and weights to fp32, then either start a new
      // partial accumulator (taps 1 .. ACCUMULATORS-1) or FMA into accumulator
      // K % ACCUMULATORS. Each result is rounded through f16.
53*4bdc9457SAndroid Build Coastguard Worker      $for K in range(KERNEL_TILE):
54*4bdc9457SAndroid Build Coastguard Worker
55*4bdc9457SAndroid Build Coastguard Worker        const __m256 vi${K}x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${K}));
56*4bdc9457SAndroid Build Coastguard Worker        $for C in range(8, CHANNEL_TILE, 8):
57*4bdc9457SAndroid Build Coastguard Worker          const __m256 vi${K}x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i${K} + ${C})));
58*4bdc9457SAndroid Build Coastguard Worker        i${K} += ${CHANNEL_TILE};
59*4bdc9457SAndroid Build Coastguard Worker
60*4bdc9457SAndroid Build Coastguard Worker        $for C in range(0, CHANNEL_TILE, 8):
61*4bdc9457SAndroid Build Coastguard Worker          const __m256 vk${K}x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + ${(K + 1) * CHANNEL_TILE + C})));
62*4bdc9457SAndroid Build Coastguard Worker        $for C in range(0, CHANNEL_TILE, 8):
63*4bdc9457SAndroid Build Coastguard Worker          $if 1 <= K < ACCUMULATORS:
64*4bdc9457SAndroid Build Coastguard Worker            __m256 vacc${ABC[C:C+8]}p${K} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]}), _MM_FROUND_NO_EXC));
65*4bdc9457SAndroid Build Coastguard Worker          $else:
66*4bdc9457SAndroid Build Coastguard Worker            vacc${ABC[C:C+8]}p${K % ACCUMULATORS} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]}, vacc${ABC[C:C+8]}p${K % ACCUMULATORS}), _MM_FROUND_NO_EXC));
67*4bdc9457SAndroid Build Coastguard Worker
      // Step over the whole weight group consumed by this channel tile.
68*4bdc9457SAndroid Build Coastguard Worker      w += ${(KERNEL_TILE + 1) * CHANNEL_TILE};
69*4bdc9457SAndroid Build Coastguard Worker
70*4bdc9457SAndroid Build Coastguard Worker      $if ACCUMULATORS > 1:
71*4bdc9457SAndroid Build Coastguard Worker        // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0
72*4bdc9457SAndroid Build Coastguard Worker        $ACC_SLICE = 1
73*4bdc9457SAndroid Build Coastguard Worker        $while ACC_SLICE < ACCUMULATORS:
74*4bdc9457SAndroid Build Coastguard Worker          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
75*4bdc9457SAndroid Build Coastguard Worker            $if A + ACC_SLICE < ACCUMULATORS:
76*4bdc9457SAndroid Build Coastguard Worker              $for C in range(0, CHANNEL_TILE, 8):
77*4bdc9457SAndroid Build Coastguard Worker                vacc${ABC[C:C+8]}p${A} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vacc${ABC[C:C+8]}p${A}, vacc${ABC[C:C+8]}p${A + ACC_SLICE}), _MM_FROUND_NO_EXC));
78*4bdc9457SAndroid Build Coastguard Worker          $ACC_SLICE *= 2
79*4bdc9457SAndroid Build Coastguard Worker
      // Clamp: lower bound first, then upper bound.
80*4bdc9457SAndroid Build Coastguard Worker      $for C in range(0, CHANNEL_TILE, 8):
81*4bdc9457SAndroid Build Coastguard Worker        __m256 vacc${ABC[C:C+8]} = _mm256_max_ps(vacc${ABC[C:C+8]}p0, vmin);
82*4bdc9457SAndroid Build Coastguard Worker      $for C in range(0, CHANNEL_TILE, 8):
83*4bdc9457SAndroid Build Coastguard Worker        vacc${ABC[C:C+8]} = _mm256_min_ps(vacc${ABC[C:C+8]}, vmax);
84*4bdc9457SAndroid Build Coastguard Worker
      // Narrow to f16 and store with unaligned stores.
85*4bdc9457SAndroid Build Coastguard Worker      _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc${ABC[0:8]}, _MM_FROUND_NO_EXC));
86*4bdc9457SAndroid Build Coastguard Worker      $for C in range(8, CHANNEL_TILE, 8):
87*4bdc9457SAndroid Build Coastguard Worker        _mm_storeu_si128((__m128i*) (o + ${C}), _mm256_cvtps_ph(vacc${ABC[C:C+8]}, _MM_FROUND_NO_EXC));
88*4bdc9457SAndroid Build Coastguard Worker      o += ${CHANNEL_TILE};
89*4bdc9457SAndroid Build Coastguard Worker    }
90*4bdc9457SAndroid Build Coastguard Worker    $if CHANNEL_TILE > 8:
      // Leftover groups of 8 channels inside the last, partial weight group.
      // w only moves by 8 per iteration while tap K is still read at offset
      // (K + 1) * CHANNEL_TILE, i.e. the slice stride stays CHANNEL_TILE.
91*4bdc9457SAndroid Build Coastguard Worker      for (; c >= 8; c -= 8) {
92*4bdc9457SAndroid Build Coastguard Worker        __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
93*4bdc9457SAndroid Build Coastguard Worker        $for K in range(KERNEL_TILE):
94*4bdc9457SAndroid Build Coastguard Worker
95*4bdc9457SAndroid Build Coastguard Worker          const __m256 vi${K}x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${K}));
96*4bdc9457SAndroid Build Coastguard Worker          i${K} += 8;
97*4bdc9457SAndroid Build Coastguard Worker
98*4bdc9457SAndroid Build Coastguard Worker          const __m256 vk${K}x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + ${(K + 1) * CHANNEL_TILE})));
99*4bdc9457SAndroid Build Coastguard Worker          $if 1 <= K < ACCUMULATORS:
100*4bdc9457SAndroid Build Coastguard Worker            __m256 vacc01234567p${K} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vi${K}x01234567, vk${K}x01234567), _MM_FROUND_NO_EXC));
101*4bdc9457SAndroid Build Coastguard Worker          $else:
102*4bdc9457SAndroid Build Coastguard Worker            vacc01234567p${K % ACCUMULATORS} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi${K}x01234567, vk${K}x01234567, vacc01234567p${K % ACCUMULATORS}), _MM_FROUND_NO_EXC));
103*4bdc9457SAndroid Build Coastguard Worker
104*4bdc9457SAndroid Build Coastguard Worker        w += 8;
105*4bdc9457SAndroid Build Coastguard Worker
106*4bdc9457SAndroid Build Coastguard Worker        $if ACCUMULATORS > 1:
107*4bdc9457SAndroid Build Coastguard Worker          // Add up all accumulators to vacc${ABC[0:8]}p0
108*4bdc9457SAndroid Build Coastguard Worker          $ACC_SLICE = 1
109*4bdc9457SAndroid Build Coastguard Worker          $while ACC_SLICE < ACCUMULATORS:
110*4bdc9457SAndroid Build Coastguard Worker            $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
111*4bdc9457SAndroid Build Coastguard Worker              $if A + ACC_SLICE < ACCUMULATORS:
112*4bdc9457SAndroid Build Coastguard Worker                vacc01234567p${A} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vacc01234567p${A}, vacc01234567p${A + ACC_SLICE}), _MM_FROUND_NO_EXC));
113*4bdc9457SAndroid Build Coastguard Worker            $ACC_SLICE *= 2
114*4bdc9457SAndroid Build Coastguard Worker
115*4bdc9457SAndroid Build Coastguard Worker        __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
116*4bdc9457SAndroid Build Coastguard Worker        vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
117*4bdc9457SAndroid Build Coastguard Worker
118*4bdc9457SAndroid Build Coastguard Worker        _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC));
119*4bdc9457SAndroid Build Coastguard Worker        o += 8;
120*4bdc9457SAndroid Build Coastguard Worker      }
    // Tail: 1..7 channels remain. A full 8-lane vector is still loaded and
    // computed (this is where the out-of-bounds input reads happen); only the
    // valid lanes are stored. Input and weight pointers are not advanced
    // because nothing follows for this pixel.
121*4bdc9457SAndroid Build Coastguard Worker    if XNN_UNLIKELY(c != 0) {
122*4bdc9457SAndroid Build Coastguard Worker      assert(c >= 1);
123*4bdc9457SAndroid Build Coastguard Worker      assert(c <= 7);
124*4bdc9457SAndroid Build Coastguard Worker
125*4bdc9457SAndroid Build Coastguard Worker      __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
126*4bdc9457SAndroid Build Coastguard Worker      $for K in range(KERNEL_TILE):
127*4bdc9457SAndroid Build Coastguard Worker
128*4bdc9457SAndroid Build Coastguard Worker        const __m256 vi${K}x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${K}));
129*4bdc9457SAndroid Build Coastguard Worker
130*4bdc9457SAndroid Build Coastguard Worker        const __m256 vk${K}x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + ${(K + 1) * CHANNEL_TILE})));
131*4bdc9457SAndroid Build Coastguard Worker        $if 1 <= K < ACCUMULATORS:
132*4bdc9457SAndroid Build Coastguard Worker          __m256 vacc01234567p${K} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vi${K}x01234567, vk${K}x01234567), _MM_FROUND_NO_EXC));
133*4bdc9457SAndroid Build Coastguard Worker        $else:
134*4bdc9457SAndroid Build Coastguard Worker          vacc01234567p${K % ACCUMULATORS} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi${K}x01234567, vk${K}x01234567, vacc01234567p${K % ACCUMULATORS}), _MM_FROUND_NO_EXC));
135*4bdc9457SAndroid Build Coastguard Worker
136*4bdc9457SAndroid Build Coastguard Worker      $if ACCUMULATORS > 1:
137*4bdc9457SAndroid Build Coastguard Worker        // Add up all accumulators to vacc${ABC[0:8]}p0
138*4bdc9457SAndroid Build Coastguard Worker        $ACC_SLICE = 1
139*4bdc9457SAndroid Build Coastguard Worker        $while ACC_SLICE < ACCUMULATORS:
140*4bdc9457SAndroid Build Coastguard Worker          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
141*4bdc9457SAndroid Build Coastguard Worker            $if A + ACC_SLICE < ACCUMULATORS:
142*4bdc9457SAndroid Build Coastguard Worker              vacc01234567p${A} = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vacc01234567p${A}, vacc01234567p${A + ACC_SLICE}), _MM_FROUND_NO_EXC));
143*4bdc9457SAndroid Build Coastguard Worker          $ACC_SLICE *= 2
144*4bdc9457SAndroid Build Coastguard Worker
145*4bdc9457SAndroid Build Coastguard Worker      __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
146*4bdc9457SAndroid Build Coastguard Worker      vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
147*4bdc9457SAndroid Build Coastguard Worker
      // Store c halfwords by decomposing c into 4 + 2 + 1; after each partial
      // store the not-yet-written lanes are shifted down to lane 0.
148*4bdc9457SAndroid Build Coastguard Worker      __m128i vh01234567 = _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC);
149*4bdc9457SAndroid Build Coastguard Worker      if (c & 4) {
150*4bdc9457SAndroid Build Coastguard Worker        _mm_storel_epi64((__m128i*) o, vh01234567);
151*4bdc9457SAndroid Build Coastguard Worker        vh01234567 = _mm_unpackhi_epi64(vh01234567, vh01234567);
152*4bdc9457SAndroid Build Coastguard Worker        o += 4;
153*4bdc9457SAndroid Build Coastguard Worker      }
154*4bdc9457SAndroid Build Coastguard Worker      if (c & 2) {
155*4bdc9457SAndroid Build Coastguard Worker        _mm_storeu_si32(o, vh01234567);
156*4bdc9457SAndroid Build Coastguard Worker        vh01234567 = _mm_srli_epi64(vh01234567, 32);
157*4bdc9457SAndroid Build Coastguard Worker        o += 2;
158*4bdc9457SAndroid Build Coastguard Worker      }
159*4bdc9457SAndroid Build Coastguard Worker      if (c & 1) {
160*4bdc9457SAndroid Build Coastguard Worker        *o = (uint16_t) _mm_extract_epi16(vh01234567, 0);
161*4bdc9457SAndroid Build Coastguard Worker        o += 1;
162*4bdc9457SAndroid Build Coastguard Worker      }
163*4bdc9457SAndroid Build Coastguard Worker    }
164*4bdc9457SAndroid Build Coastguard Worker
    // Skip output_increment bytes to reach the next output pixel.
165*4bdc9457SAndroid Build Coastguard Worker    o = (uint16_t*) ((uintptr_t) o + output_increment);
166*4bdc9457SAndroid Build Coastguard Worker  } while (--output_width != 0);
167*4bdc9457SAndroid Build Coastguard Worker}
168