xref: /aosp_15_r20/external/XNNPACK/src/qs8-vmulc/scalar.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1// Copyright 2021 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert DATATYPE in ["QS8", "QU8"]
7$assert REQUANTIZATION == "FP32"
8$assert BATCH_TILE >= 1
9#include <assert.h>
10
11#include <xnnpack/math.h>
12#include <xnnpack/vmul.h>
13
14
15$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
16void xnn_${DATATYPE.lower()}_vmulc_minmax_${REQUANTIZATION.lower()}_ukernel__scalar_x${BATCH_TILE}(
17    size_t n,
18    const ${XINT8_T}* input_a,
19    const ${XINT8_T}* input_b,
20    ${XINT8_T}* output,
21    const union xnn_${DATATYPE.lower()}_mul_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
22{
23  const int32_t va_zero_point = params->fp32_scalar.a_zero_point;
24  const float vscale = params->fp32_scalar.scale;
25  const float voutput_min_less_zero_point = params->fp32_scalar.output_min_less_zero_point;
26  const float voutput_max_less_zero_point = params->fp32_scalar.output_max_less_zero_point;
27  const float vmagic_bias = params->fp32_scalar.magic_bias;
28  const int32_t vmagic_bias_less_output_zero_point = params->fp32_scalar.magic_bias_less_output_zero_point;
29
30  const int32_t vb = (int32_t) *input_b - params->fp32_scalar.b_zero_point;
31  $if BATCH_TILE == 1:
32    do {
33      const int32_t va = (int32_t) *input_a++ - va_zero_point;
34      const int32_t vacc = va * vb;
35
36      float vfpacc = (float) vacc * vscale;
37      vfpacc = math_max_f32(vfpacc, voutput_min_less_zero_point);
38      vfpacc = math_min_f32(vfpacc, voutput_max_less_zero_point);
39      vfpacc += vmagic_bias;
40      const int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
41      *output++ = (${XINT8_T}) vout;
42
43      n -= sizeof(${XINT8_T});
44    } while (n != 0);
45  $else:
46    for (; n >= ${BATCH_TILE} * sizeof(${XINT8_T}); n -= ${BATCH_TILE} * sizeof(${XINT8_T})) {
47      $for N in range(BATCH_TILE):
48        const int32_t va${N} = input_a[${N}] - va_zero_point;
49      input_a += ${BATCH_TILE};
50
51      $for N in range(BATCH_TILE):
52        const int32_t vacc${N} = va${N} * vb;
53
54      $for N in range(BATCH_TILE):
55        float vfpacc${N} = (float) vacc${N} * vscale;
56
57      $for N in range(BATCH_TILE):
58        vfpacc${N} = math_max_f32(vfpacc${N}, voutput_min_less_zero_point);
59
60      $for N in range(BATCH_TILE):
61        vfpacc${N} = math_min_f32(vfpacc${N}, voutput_max_less_zero_point);
62
63      $for N in range(BATCH_TILE):
64        vfpacc${N} += vmagic_bias;
65
66      $for N in range(BATCH_TILE):
67        const int32_t vout${N} = (int32_t) float_as_uint32(vfpacc${N}) - vmagic_bias_less_output_zero_point;
68
69      $for N in range(BATCH_TILE):
70        output[${N}] = (${XINT8_T}) vout${N};
71      output += ${BATCH_TILE};
72    }
73    if XNN_UNLIKELY(n != 0) {
74      $if BATCH_TILE == 2:
75        const int32_t va = (int32_t) *input_a - va_zero_point;
76        const int32_t vacc = va * vb;
77
78        float vfpacc = (float) vacc * vscale;
79        vfpacc = math_max_f32(vfpacc, voutput_min_less_zero_point);
80        vfpacc = math_min_f32(vfpacc, voutput_max_less_zero_point);
81        vfpacc += vmagic_bias;
82        const int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
83        *output = (${XINT8_T}) vout;
84      $else:
85        do {
86          const int32_t va = (int32_t) *input_a++ - va_zero_point;
87          const int32_t vacc = va * vb;
88
89          float vfpacc = (float) vacc * vscale;
90          vfpacc = math_max_f32(vfpacc, voutput_min_less_zero_point);
91          vfpacc = math_min_f32(vfpacc, voutput_max_less_zero_point);
92          vfpacc += vmagic_bias;
93          const int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
94          *output++ = (${XINT8_T}) vout;
95
96          n -= sizeof(${XINT8_T});
97        } while (n != 0);
98    }
99}
100