xref: /aosp_15_r20/external/XNNPACK/src/qs8-gavgpool/multipass-scalar.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

6$assert DATATYPE in ["QS8", "QU8"]
7$assert CHANNEL_TILE >= 1
8$assert CHANNEL_TILE <= 16
9$assert ROW_TILE >= 3
10$assert ROW_SUBTILE >= 3
11$assert ROW_SUBTILE <= ROW_TILE
12$assert REQUANTIZATION == "FP32"
13#include <assert.h>
14$if VARIANT == "LRINTF":
15  #include <math.h>
16
17#include <xnnpack/gavgpool.h>
18#include <xnnpack/math.h>
19
20
21$PARAMS_STRUCT = "fp32_scalar_" + VARIANT.lower()
22$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
23$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
24$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
25void xnn_${DATATYPE.lower()}_gavgpool_minmax_fp32_ukernel_${ROW_TILE}p${ROW_SUBTILE}x__scalar_${VARIANT.lower()}_c${CHANNEL_TILE}(
26    size_t rows,
27    size_t channels,
28    const ${XINT8_T}* input,
29    size_t input_stride,
30    const ${XINT8_T}* zero,
31    int32_t* buffer,
32    ${XINT8_T}* output,
33    const union xnn_${DATATYPE.lower()}_avgpool_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
34{
35  assert(rows > ${ROW_TILE});
36  assert(channels != 0);
37
38  const ${XINT8_T}* i0 = input;
39  $for M in range(1, ROW_TILE):
40    const ${XINT8_T}* i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M-1} + input_stride);
41  const size_t input_increment = ${ROW_TILE} * input_stride - round_up_po2(channels, ${CHANNEL_TILE}) * sizeof(${XINT8_T});
42
43  const int32_t vinit_bias = params->${PARAMS_STRUCT}.init_bias;
44  int32_t* b = buffer;
45  $if CHANNEL_TILE == 1:
46    size_t c = channels;
47    do {
48      int32_t vacc = vinit_bias;
49      $for M in range(2):
50        const int32_t vi${M} = (int32_t) *i${M}++;
51
52      $for M in range(2, ROW_TILE):
53        vacc += vi${M-2};
54        const int32_t vi${M} = (int32_t) *i${M}++;
55
56      $for M in range(ROW_TILE - 2, ROW_TILE):
57        vacc += vi${M};
58
59      *b++ = vacc;
60    } while (--c != 0);
61  $else:
62    for (ptrdiff_t c = (ptrdiff_t) channels; c > 0; c -= ${CHANNEL_TILE}) {
63      $for C in range(CHANNEL_TILE):
64        const int32_t vi0x${C} = (int32_t) i0[${C}];
65      i0 += ${CHANNEL_TILE};
66
67      $for C in range(CHANNEL_TILE):
68        int32_t vacc${C} = vi0x${C} + vinit_bias;
69        const int32_t vi1x${C} = (int32_t) i1[${C}];
70      i1 += ${CHANNEL_TILE};
71
72      $for M in range(2, ROW_TILE):
73        $for C in range(CHANNEL_TILE):
74          vacc${C} += vi${M-1}x${C};
75          const int32_t vi${M}x${C} = (int32_t) i${M}[${C}];
76        i${M} += ${CHANNEL_TILE};
77
78      $for C in range(CHANNEL_TILE):
79        vacc${C} += vi${ROW_TILE-1}x${C};
80
81      $for C in range(CHANNEL_TILE):
82        b[${C}] = vacc${C};
83      b += ${CHANNEL_TILE};
84    }
85
86  for (rows -= ${ROW_TILE}; rows > ${ROW_SUBTILE}; rows -= ${ROW_SUBTILE}) {
87    $for M in range(ROW_SUBTILE):
88      i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M + ROW_TILE - ROW_SUBTILE} + input_increment);
89
90    int32_t* b = buffer;
91    $if CHANNEL_TILE == 1:
92      size_t c = channels;
93      do {
94        int32_t vacc = *b;
95        $for M in range(2):
96          const int32_t vi${M} = (int32_t) *i${M}++;
97
98        $for M in range(2, ROW_SUBTILE):
99          vacc += vi${M-2};
100          const int32_t vi${M} = (int32_t) *i${M}++;
101
102        $for M in range(ROW_SUBTILE - 2, ROW_SUBTILE):
103          vacc += vi${M};
104
105        *b++ = vacc;
106      } while (--c != 0);
107    $else:
108      for (ptrdiff_t c = (ptrdiff_t) channels; c > 0; c -= ${CHANNEL_TILE}) {
109        $for C in range(CHANNEL_TILE):
110          int32_t vacc${C} = b[${C}];
111          const int32_t vi0x${C} = (int32_t) i0[${C}];
112        i0 += ${CHANNEL_TILE};
113
114        $for M in range(1, ROW_SUBTILE):
115          $for C in range(CHANNEL_TILE):
116            vacc${C} += vi${M-1}x${C};
117            const int32_t vi${M}x${C} = (int32_t) i${M}[${C}];
118          i${M} += ${CHANNEL_TILE};
119
120        $for C in range(CHANNEL_TILE):
121          vacc${C} += vi${ROW_SUBTILE-1}x${C};
122
123        $for C in range(CHANNEL_TILE):
124          b[${C}] = vacc${C};
125        b += ${CHANNEL_TILE};
126      }
127  }
128
129  i0 = (const ${XINT8_T}*) ((uintptr_t) i${ROW_TILE - ROW_SUBTILE} + input_increment);
130  $for M in range(1, ROW_SUBTILE):
131    i${M} = (const ${XINT8_T}*) ((uintptr_t) i${M + ROW_TILE - ROW_SUBTILE} + input_increment);
132    $if M % 2 == 1:
133      if XNN_UNPREDICTABLE(rows < ${M+1}) {
134        i${M} = zero;
135      }
136    $else:
137      if XNN_UNPREDICTABLE(rows <= ${M}) {
138        i${M} = zero;
139      }
140
141  const float vscale = params->${PARAMS_STRUCT}.scale;
142  $if VARIANT == "FMAGIC":
143    const float voutput_min_less_zero_point = params->fp32_scalar_fmagic.output_min_less_zero_point;
144    const float voutput_max_less_zero_point = params->fp32_scalar_fmagic.output_max_less_zero_point;
145    const float vmagic_bias = params->fp32_scalar_fmagic.magic_bias;
146    const int32_t vmagic_bias_less_output_zero_point = params->fp32_scalar_fmagic.magic_bias_less_output_zero_point;
147  $elif VARIANT == "IMAGIC":
148    const float vmagic_bias = params->fp32_scalar_imagic.magic_bias;
149    const int32_t vmagic_min = params->fp32_scalar_imagic.magic_min;
150    const int32_t vmagic_max = params->fp32_scalar_imagic.magic_max;
151    const int32_t vmagic_bias_less_zero_point = params->fp32_scalar_imagic.magic_bias_less_zero_point;
152  $elif VARIANT == "LRINTF":
153    const float voutput_min_less_zero_point = params->fp32_scalar_lrintf.output_min_less_zero_point;
154    const float voutput_max_less_zero_point = params->fp32_scalar_lrintf.output_max_less_zero_point;
155    const int32_t voutput_zero_point = params->fp32_scalar_lrintf.output_zero_point;
156  $if CHANNEL_TILE == 1:
157    do {
158      int32_t vacc = *buffer++;
159      $for M in range(2):
160        const int32_t vi${M} = (int32_t) *i${M}++;
161
162      $for M in range(2, ROW_SUBTILE):
163        vacc += vi${M-2};
164        const int32_t vi${M} = (int32_t) *i${M}++;
165
166      $for M in range(ROW_SUBTILE - 2, ROW_SUBTILE):
167        vacc += vi${M};
168
169      float vfpacc = (float) vacc * vscale;
170      $if VARIANT == "FMAGIC":
171        vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
172        vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
173        vfpacc += vmagic_bias;
174        int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
175      $elif VARIANT == "IMAGIC":
176        vfpacc += vmagic_bias;
177        int32_t vout = (int32_t) float_as_uint32(vfpacc);
178        vout = math_max_s32(vout, vmagic_min);
179        vout = math_min_s32(vout, vmagic_max);
180        vout -= vmagic_bias_less_zero_point;
181      $elif VARIANT == "LRINTF":
182        vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
183        vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
184        const int32_t vrndacc = (int32_t) lrintf(vfpacc);
185        int32_t vout = vrndacc + voutput_zero_point;
186
187      *output++ = (${XINT8_T}) vout;
188    } while (--channels != 0);
189  $else:
190    for (; channels >= ${CHANNEL_TILE}; channels -= ${CHANNEL_TILE}) {
191      $for C in range(CHANNEL_TILE):
192        int32_t vacc${C} = buffer[${C}];
193        const int32_t vi0x${C} = (int32_t) i0[${C}];
194      buffer += ${CHANNEL_TILE};
195      i0 += ${CHANNEL_TILE};
196
197      $for M in range(1, ROW_SUBTILE):
198        $for C in range(CHANNEL_TILE):
199          vacc${C} += vi${M-1}x${C};
200          const int32_t vi${M}x${C} = (int32_t) i${M}[${C}];
201        i${M} += ${CHANNEL_TILE};
202
203      $for C in range(CHANNEL_TILE):
204        vacc${C} += vi${ROW_SUBTILE-1}x${C};
205
206      $for C in range(CHANNEL_TILE):
207        float vfpacc${C} = (float) vacc${C} * vscale;
208
209      $if VARIANT == "FMAGIC":
210        $for C in range(CHANNEL_TILE):
211          vfpacc${C} = ${MAX_F32}(vfpacc${C}, voutput_min_less_zero_point);
212
213        $for C in range(CHANNEL_TILE):
214          vfpacc${C} = ${MIN_F32}(vfpacc${C}, voutput_max_less_zero_point);
215
216        $for C in range(CHANNEL_TILE):
217          vfpacc${C} += vmagic_bias;
218
219        $for C in range(CHANNEL_TILE):
220          int32_t vout${C} = (int32_t) float_as_uint32(vfpacc${C}) - vmagic_bias_less_output_zero_point;
221      $elif VARIANT == "IMAGIC":
222        $for C in range(CHANNEL_TILE):
223          vfpacc${C} += vmagic_bias;
224
225        $for C in range(CHANNEL_TILE):
226          int32_t vout${C} = (int32_t) float_as_uint32(vfpacc${C});
227
228        $for C in range(CHANNEL_TILE):
229          vout${C} = math_max_s32(vout${C}, vmagic_min);
230
231        $for C in range(CHANNEL_TILE):
232          vout${C} = math_min_s32(vout${C}, vmagic_max);
233
234        $for C in range(CHANNEL_TILE):
235          vout${C} -= vmagic_bias_less_zero_point;
236      $elif VARIANT == "LRINTF":
237        $for C in range(CHANNEL_TILE):
238          vfpacc${C} = ${MAX_F32}(vfpacc${C}, voutput_min_less_zero_point);
239
240        $for C in range(CHANNEL_TILE):
241          vfpacc${C} = ${MIN_F32}(vfpacc${C}, voutput_max_less_zero_point);
242
243        $for C in range(CHANNEL_TILE):
244          const int32_t vrndacc${C} = (int32_t) lrintf(vfpacc${C});
245
246        $for C in range(CHANNEL_TILE):
247          int32_t vout${C} = vrndacc${C} + voutput_zero_point;
248
249      $for C in range(CHANNEL_TILE):
250        output[${C}] = (${XINT8_T}) vout${C};
251      output += ${CHANNEL_TILE};
252    }
253    if XNN_UNLIKELY(channels != 0) {
254      $if CHANNEL_TILE == 2:
255        int32_t vacc = *buffer;
256        $for M in range(2):
257          const int32_t vi${M} = (int32_t) *i${M};
258
259        $for M in range(2, ROW_SUBTILE):
260          vacc += vi${M-2};
261          const int32_t vi${M} = (int32_t) *i${M};
262
263        $for M in range(ROW_SUBTILE - 2, ROW_SUBTILE):
264          vacc += vi${M};
265
266        float vfpacc = (float) vacc * vscale;
267        $if VARIANT == "FMAGIC":
268          vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
269          vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
270          vfpacc += vmagic_bias;
271          int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
272        $elif VARIANT == "IMAGIC":
273          vfpacc += vmagic_bias;
274          int32_t vout = (int32_t) float_as_uint32(vfpacc);
275          vout = math_max_s32(vout, vmagic_min);
276          vout = math_min_s32(vout, vmagic_max);
277          vout -= vmagic_bias_less_zero_point;
278        $elif VARIANT == "LRINTF":
279          vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
280          vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
281          const int32_t vrndacc = (int32_t) lrintf(vfpacc);
282          int32_t vout = vrndacc + voutput_zero_point;
283
284        *output = (${XINT8_T}) vout;
285      $else:
286        do {
287          int32_t vacc = *buffer++;
288          $for M in range(2):
289            const int32_t vi${M} = (int32_t) *i${M}++;
290
291          $for M in range(2, ROW_SUBTILE):
292            vacc += vi${M-2};
293            const int32_t vi${M} = (int32_t) *i${M}++;
294
295          $for M in range(ROW_SUBTILE - 2, ROW_SUBTILE):
296            vacc += vi${M};
297
298          float vfpacc = (float) vacc * vscale;
299          $if VARIANT == "FMAGIC":
300            vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
301            vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
302            vfpacc += vmagic_bias;
303            int32_t vout = (int32_t) float_as_uint32(vfpacc) - vmagic_bias_less_output_zero_point;
304          $elif VARIANT == "IMAGIC":
305            vfpacc += vmagic_bias;
306            int32_t vout = (int32_t) float_as_uint32(vfpacc);
307            vout = math_max_s32(vout, vmagic_min);
308            vout = math_min_s32(vout, vmagic_max);
309            vout -= vmagic_bias_less_zero_point;
310          $elif VARIANT == "LRINTF":
311            vfpacc = ${MAX_F32}(vfpacc, voutput_min_less_zero_point);
312            vfpacc = ${MIN_F32}(vfpacc, voutput_max_less_zero_point);
313            const int32_t vrndacc = (int32_t) lrintf(vfpacc);
314            int32_t vout = vrndacc + voutput_zero_point;
315
316          *output++ = (${XINT8_T}) vout;
317        } while (--channels != 0);
318    }
319}
320