1*4bdc9457SAndroid Build Coastguard Worker// Copyright 2020 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker// 3*4bdc9457SAndroid Build Coastguard Worker// This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker// LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker$assert BATCH_TILE % 16 == 0 7*4bdc9457SAndroid Build Coastguard Worker$assert BATCH_TILE >= 16 8*4bdc9457SAndroid Build Coastguard Worker$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 9*4bdc9457SAndroid Build Coastguard Worker$assert OP in ["ABS", "NEG", "SQR"] 10*4bdc9457SAndroid Build Coastguard Worker#include <assert.h> 11*4bdc9457SAndroid Build Coastguard Worker 12*4bdc9457SAndroid Build Coastguard Worker#include <immintrin.h> 13*4bdc9457SAndroid Build Coastguard Worker 14*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/common.h> 15*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/intrinsics-polyfill.h> 16*4bdc9457SAndroid Build Coastguard Worker#include <xnnpack/vunary.h> 17*4bdc9457SAndroid Build Coastguard Worker 18*4bdc9457SAndroid Build Coastguard Worker 19*4bdc9457SAndroid Build Coastguard Worker$__M512 = { 20*4bdc9457SAndroid Build Coastguard Worker$ "ABS": "__m512i", 21*4bdc9457SAndroid Build Coastguard Worker$ "NEG": "__m512i", 22*4bdc9457SAndroid Build Coastguard Worker$ "SQR": "__m512", 23*4bdc9457SAndroid Build Coastguard Worker$}[OP] 24*4bdc9457SAndroid Build Coastguard Worker$_MM512_LOADU = { 25*4bdc9457SAndroid Build Coastguard Worker$ "ABS": "_mm512_loadu_si512", 26*4bdc9457SAndroid Build Coastguard Worker$ "NEG": "_mm512_loadu_si512", 27*4bdc9457SAndroid Build Coastguard Worker$ "SQR": "_mm512_loadu_ps", 28*4bdc9457SAndroid Build Coastguard Worker$}[OP] 29*4bdc9457SAndroid Build Coastguard Worker$_MM512_MASK_LOADU = { 30*4bdc9457SAndroid Build Coastguard Worker$ "ABS": "_mm512_maskz_loadu_epi32", 31*4bdc9457SAndroid Build Coastguard Worker$ "NEG": "_mm512_maskz_loadu_epi32", 32*4bdc9457SAndroid Build Coastguard Worker$ "SQR": "_mm512_maskz_loadu_ps", 33*4bdc9457SAndroid Build Coastguard Worker$}[OP] 34*4bdc9457SAndroid Build Coastguard Worker$_MM512_STOREU = { 35*4bdc9457SAndroid Build Coastguard Worker$ "ABS": "_mm512_storeu_si512", 36*4bdc9457SAndroid Build Coastguard Worker$ "NEG": "_mm512_storeu_si512", 37*4bdc9457SAndroid Build Coastguard Worker$ "SQR": "_mm512_storeu_ps", 38*4bdc9457SAndroid Build Coastguard Worker$}[OP] 39*4bdc9457SAndroid Build Coastguard Worker$_MM512_MASK_STOREU = { 40*4bdc9457SAndroid Build Coastguard Worker$ "ABS": "_mm512_mask_storeu_epi32", 41*4bdc9457SAndroid Build Coastguard Worker$ "NEG": "_mm512_mask_storeu_epi32", 42*4bdc9457SAndroid Build Coastguard Worker$ "SQR": "_mm512_mask_storeu_ps", 43*4bdc9457SAndroid Build Coastguard Worker$}[OP] 44*4bdc9457SAndroid Build Coastguard Worker$_MM512_OP = { 45*4bdc9457SAndroid Build Coastguard Worker$ "ABS": lambda x: "_mm512_and_epi32(%s, vnonsign_mask)" % x, 46*4bdc9457SAndroid Build Coastguard Worker$ "NEG": lambda x: "_mm512_xor_epi32(%s, vsign_mask)" % x, 47*4bdc9457SAndroid Build Coastguard Worker$ "SQR": lambda x: "_mm512_mul_ps(%s, %s)" % (x, x), 48*4bdc9457SAndroid Build Coastguard Worker$}[OP] 49*4bdc9457SAndroid Build Coastguard Worker$PARAMS = { 50*4bdc9457SAndroid Build Coastguard Worker$ "ABS": "xnn_f32_abs_params", 51*4bdc9457SAndroid Build Coastguard Worker$ "NEG": "xnn_f32_neg_params", 52*4bdc9457SAndroid Build Coastguard Worker$ "SQR": "xnn_f32_default_params", 53*4bdc9457SAndroid Build Coastguard Worker$}[OP] 54*4bdc9457SAndroid Build Coastguard Workervoid xnn_f32_v${OP.lower()}_ukernel__avx512f_x${BATCH_TILE}( 55*4bdc9457SAndroid Build Coastguard Worker size_t n, 56*4bdc9457SAndroid Build Coastguard Worker const float* x, 57*4bdc9457SAndroid Build Coastguard Worker float* y, 58*4bdc9457SAndroid Build Coastguard Worker const union ${PARAMS} params[restrict XNN_MIN_ELEMENTS(1)]) 59*4bdc9457SAndroid Build Coastguard Worker{ 60*4bdc9457SAndroid Build Coastguard Worker assert(n != 0); 61*4bdc9457SAndroid Build Coastguard Worker assert(n % sizeof(float) == 0); 62*4bdc9457SAndroid Build Coastguard Worker assert(x != NULL); 63*4bdc9457SAndroid Build Coastguard Worker assert(y != NULL); 64*4bdc9457SAndroid Build Coastguard Worker 65*4bdc9457SAndroid Build Coastguard Worker $if OP == "ABS": 66*4bdc9457SAndroid Build Coastguard Worker const __m512i vnonsign_mask = _mm512_set1_epi32((int) params->avx512.nonsign_mask); 67*4bdc9457SAndroid Build Coastguard Worker $elif OP == "NEG": 68*4bdc9457SAndroid Build Coastguard Worker const __m512i vsign_mask = _mm512_set1_epi32((int) params->avx512.sign_mask); 69*4bdc9457SAndroid Build Coastguard Worker for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) { 70*4bdc9457SAndroid Build Coastguard Worker const ${__M512} vx${ABC[0:16]} = ${_MM512_LOADU}(x); 71*4bdc9457SAndroid Build Coastguard Worker $for N in range(16, BATCH_TILE, 16): 72*4bdc9457SAndroid Build Coastguard Worker const ${__M512} vx${ABC[N:N+16]} = ${_MM512_LOADU}(x + ${N}); 73*4bdc9457SAndroid Build Coastguard Worker x += ${BATCH_TILE}; 74*4bdc9457SAndroid Build Coastguard Worker 75*4bdc9457SAndroid Build Coastguard Worker $for N in range(0, BATCH_TILE, 16): 76*4bdc9457SAndroid Build Coastguard Worker const ${__M512} vy${ABC[N:N+16]} = ${_MM512_OP("vx" + ABC[N:N+16])}; 77*4bdc9457SAndroid Build Coastguard Worker 78*4bdc9457SAndroid Build Coastguard Worker ${_MM512_STOREU}(y, vy${ABC[0:16]}); 79*4bdc9457SAndroid Build Coastguard Worker $for N in range(16, BATCH_TILE, 16): 80*4bdc9457SAndroid Build Coastguard Worker ${_MM512_STOREU}(y + ${N}, vy${ABC[N:N+16]}); 81*4bdc9457SAndroid Build Coastguard Worker y += ${BATCH_TILE}; 82*4bdc9457SAndroid Build Coastguard Worker } 83*4bdc9457SAndroid Build Coastguard Worker $if BATCH_TILE > 16: 84*4bdc9457SAndroid Build Coastguard Worker for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) { 85*4bdc9457SAndroid Build Coastguard Worker const ${__M512} vx = ${_MM512_LOADU}(x); 86*4bdc9457SAndroid Build Coastguard Worker x += 16; 87*4bdc9457SAndroid Build Coastguard Worker 88*4bdc9457SAndroid Build Coastguard Worker const ${__M512} vy = ${_MM512_OP("vx")}; 89*4bdc9457SAndroid Build Coastguard Worker 90*4bdc9457SAndroid Build Coastguard Worker ${_MM512_STOREU}(y, vy); 91*4bdc9457SAndroid Build Coastguard Worker y += 16; 92*4bdc9457SAndroid Build Coastguard Worker } 93*4bdc9457SAndroid Build Coastguard Worker if XNN_UNLIKELY(n != 0) { 94*4bdc9457SAndroid Build Coastguard Worker assert(n >= 1 * sizeof(float)); 95*4bdc9457SAndroid Build Coastguard Worker assert(n <= 15 * sizeof(float)); 96*4bdc9457SAndroid Build Coastguard Worker // Prepare mask for valid 32-bit elements (depends on n). 97*4bdc9457SAndroid Build Coastguard Worker n >>= 2 /* log2(sizeof(float)) */; 98*4bdc9457SAndroid Build Coastguard Worker const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1))); 99*4bdc9457SAndroid Build Coastguard Worker 100*4bdc9457SAndroid Build Coastguard Worker const ${__M512} vx = ${_MM512_MASK_LOADU}(vmask, x); 101*4bdc9457SAndroid Build Coastguard Worker const ${__M512} vy = ${_MM512_OP("vx")}; 102*4bdc9457SAndroid Build Coastguard Worker ${_MM512_MASK_STOREU}(y, vmask, vy); 103*4bdc9457SAndroid Build Coastguard Worker } 104*4bdc9457SAndroid Build Coastguard Worker} 105