1// Copyright 2022 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert DATATYPE in ["QS8", "QU8"] 7$assert BATCH_TILE >= 1 8$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 9#include <assert.h> 10 11#include <xnnpack/math.h> 12#include <xnnpack/vlrelu.h> 13 14 15$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE] 16$OUTPUT_MIN = {"QS8": -128, "QU8": 0}[DATATYPE] 17$OUTPUT_MAX = {"QS8": 127, "QU8": 255}[DATATYPE] 18void xnn_${DATATYPE.lower()}_vlrelu_ukernel__scalar_andxor_x${BATCH_TILE}( 19 size_t n, 20 const ${XINT8_T}* x, 21 ${XINT8_T}* y, 22 const union xnn_${DATATYPE.lower()}_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)]) 23{ 24 const int32_t vinput_zero_point = params->scalar_andxor.input_zero_point; 25 const int32_t vmultiplier_diff = params->scalar_andxor.multiplier_diff; 26 const int32_t vmultiplier_base = params->scalar_andxor.multiplier_base; 27 const int32_t vbias = params->scalar_andxor.bias; 28 $if BATCH_TILE == 1: 29 do { 30 int32_t vacc = (int32_t) *x++ - vinput_zero_point; 31 const int32_t vmultiplier = vmultiplier_base ^ (vmultiplier_diff & math_asr_s32(vacc, 31)); 32 vacc = vbias + vacc * vmultiplier; 33 34 int32_t vout = math_asr_s32(vacc, 8); 35 vout = math_max_s32(vout, ${OUTPUT_MIN}); 36 vout = math_min_s32(vout, ${OUTPUT_MAX}); 37 *y++ = (${XINT8_T}) vout; 38 39 n -= sizeof(${XINT8_T}); 40 } while (n != 0); 41 $else: 42 for (; n >= ${BATCH_TILE} * sizeof(${XINT8_T}); n -= ${BATCH_TILE} * sizeof(${XINT8_T})) { 43 $for N in range(BATCH_TILE): 44 int32_t vacc${ABC[N]} = (int32_t) x[${N}]; 45 x += ${BATCH_TILE}; 46 47 $for N in range(BATCH_TILE): 48 vacc${ABC[N]} -= vinput_zero_point; 49 50 $for N in range(BATCH_TILE): 51 int32_t vmultiplier${ABC[N]} = math_asr_s32(vacc${ABC[N]}, 31); 52 53 $for N in range(BATCH_TILE): 54 vmultiplier${ABC[N]} &= vmultiplier_diff; 55 56 $for N in range(BATCH_TILE): 57 vmultiplier${ABC[N]} ^= vmultiplier_base; 58 59 $for N in range(BATCH_TILE): 60 vacc${ABC[N]} = vbias + vacc${ABC[N]} * vmultiplier${ABC[N]}; 61 62 $for N in range(BATCH_TILE): 63 int32_t vout${ABC[N]} = math_asr_s32(vacc${ABC[N]}, 8); 64 65 $for N in range(BATCH_TILE): 66 vout${ABC[N]} = math_max_s32(vout${ABC[N]}, ${OUTPUT_MIN}); 67 68 $for N in range(BATCH_TILE): 69 vout${ABC[N]} = math_min_s32(vout${ABC[N]}, ${OUTPUT_MAX}); 70 71 $for N in range(BATCH_TILE): 72 y[${N}] = (${XINT8_T}) vout${ABC[N]}; 73 y += ${BATCH_TILE}; 74 } 75 if XNN_UNLIKELY(n != 0) { 76 $if BATCH_TILE == 2: 77 int32_t vacc = (int32_t) *x++ - vinput_zero_point; 78 const int32_t vmultiplier = vmultiplier_base ^ (vmultiplier_diff & math_asr_s32(vacc, 31)); 79 vacc = vbias + vacc * vmultiplier; 80 81 int32_t vout = math_asr_s32(vacc, 8); 82 vout = math_max_s32(vout, ${OUTPUT_MIN}); 83 vout = math_min_s32(vout, ${OUTPUT_MAX}); 84 *y = (${XINT8_T}) vout; 85 $else: 86 do { 87 int32_t vacc = (int32_t) *x++ - vinput_zero_point; 88 const int32_t vmultiplier = vmultiplier_base ^ (vmultiplier_diff & math_asr_s32(vacc, 31)); 89 vacc = vbias + vacc * vmultiplier; 90 91 int32_t vout = math_asr_s32(vacc, 8); 92 vout = math_max_s32(vout, ${OUTPUT_MIN}); 93 vout = math_min_s32(vout, ${OUTPUT_MAX}); 94 *y++ = (${XINT8_T}) vout; 95 96 n -= sizeof(${XINT8_T}); 97 } while (n != 0); 98 } 99} 100