// xref: /aosp_15_r20/external/XNNPACK/src/qs8-vmul/scalar.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

6$assert DATATYPE in ["QS8", "QU8"]
7$assert REQUANTIZATION == "FP32"
8$assert BATCH_TILE >= 1
9#include <assert.h>
10
11#include <xnnpack/math.h>
12#include <xnnpack/vmul.h>
13
14
15$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
16void xnn_${DATATYPE.lower()}_vmul_minmax_${REQUANTIZATION.lower()}_ukernel__scalar_x${BATCH_TILE}(
17    size_t n,
18    const ${XINT8_T}* input_a,
19    const ${XINT8_T}* input_b,
20    ${XINT8_T}* output,
21    const union xnn_${DATATYPE.lower()}_mul_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
22{
23  const int32_t va_zero_point = params->fp32_scalar.a_zero_point;
24  const int32_t vb_zero_point = params->fp32_scalar.b_zero_point;
25  const float vscale = params->fp32_scalar.scale;
26  const float voutput_min_less_zero_point = params->fp32_scalar.output_min_less_zero_point;
27  const float voutput_max_less_zero_point = params->fp32_scalar.output_max_less_zero_point;
28  const float vmagic_bias = params->fp32_scalar.magic_bias;
29  const int32_t vmagic_bias_less_output_zero_point = params->fp32_scalar.magic_bias_less_output_zero_point;
30
31  $if BATCH_TILE == 1:
32    do {
33      const int32_t va = (int32_t) *input_a++ - va_zero_point;
34      const int32_t vb = (int32_t) *input_b++ - vb_zero_point;
35      const int32_t vacc = va * vb;
36
37      float vfpacc = (float) vacc * vscale;
38      vfpacc = math_max_f32(vfpacc, voutput_min_less_zero_point);
39      vfpacc = math_min_f32(vfpacc, voutput_max_less_zero_point);
40      vfpacc += vmagic_bias;
41      const int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
42      *output++ = (${XINT8_T}) vout;
43
44      n -= sizeof(${XINT8_T});
45    } while (n != 0);
46  $else:
47    for (; n >= ${BATCH_TILE} * sizeof(${XINT8_T}); n -= ${BATCH_TILE} * sizeof(${XINT8_T})) {
48      $for N in range(BATCH_TILE):
49        const int32_t va${N} = input_a[${N}] - va_zero_point;
50      input_a += ${BATCH_TILE};
51
52      $for N in range(BATCH_TILE):
53        const int32_t vb${N} = input_b[${N}] - vb_zero_point;
54      input_b += ${BATCH_TILE};
55
56      $for N in range(BATCH_TILE):
57        const int32_t vacc${N} = va${N} * vb${N};
58
59      $for N in range(BATCH_TILE):
60        float vfpacc${N} = (float) vacc${N} * vscale;
61
62      $for N in range(BATCH_TILE):
63        vfpacc${N} = math_max_f32(vfpacc${N}, voutput_min_less_zero_point);
64
65      $for N in range(BATCH_TILE):
66        vfpacc${N} = math_min_f32(vfpacc${N}, voutput_max_less_zero_point);
67
68      $for N in range(BATCH_TILE):
69        vfpacc${N} += vmagic_bias;
70
71      $for N in range(BATCH_TILE):
72        const int32_t vout${N} = (int32_t) float_as_uint32(vfpacc${N}) - vmagic_bias_less_output_zero_point;
73
74      $for N in range(BATCH_TILE):
75        output[${N}] = (${XINT8_T}) vout${N};
76      output += ${BATCH_TILE};
77    }
78    if XNN_UNLIKELY(n != 0) {
79      $if BATCH_TILE == 2:
80        const int32_t va = (int32_t) *input_a - va_zero_point;
81        const int32_t vb = (int32_t) *input_b - vb_zero_point;
82        const int32_t vacc = va * vb;
83
84        float vfpacc = (float) vacc * vscale;
85        vfpacc = math_max_f32(vfpacc, voutput_min_less_zero_point);
86        vfpacc = math_min_f32(vfpacc, voutput_max_less_zero_point);
87        vfpacc += vmagic_bias;
88        const int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
89        *output = (${XINT8_T}) vout;
90      $else:
91        do {
92          const int32_t va = (int32_t) *input_a++ - va_zero_point;
93          const int32_t vb = (int32_t) *input_b++ - vb_zero_point;
94          const int32_t vacc = va * vb;
95
96          float vfpacc = (float) vacc * vscale;
97          vfpacc = math_max_f32(vfpacc, voutput_min_less_zero_point);
98          vfpacc = math_min_f32(vfpacc, voutput_max_less_zero_point);
99          vfpacc += vmagic_bias;
100          const int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
101          *output++ = (${XINT8_T}) vout;
102
103          n -= sizeof(${XINT8_T});
104        } while (n != 0);
105    }
106}
107