xref: /aosp_15_r20/external/XNNPACK/src/f16-dwconv2d-chw/3x3p1-neonfp16arith.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert ROW_TILE >= 1
7$assert ACCUMULATORS >= 1
8#include <assert.h>
9
10#include <arm_neon.h>
11
12#include <xnnpack/dwconv.h>
13#include <xnnpack/math.h>
14
15
// Template for a 3x3 depthwise-convolution micro-kernel (one row/column of
// padding, "3x3p1") operating on half-precision data with NEON fp16 arithmetic.
//
// Each pass of the outer loop produces ROW_TILE output rows; each inner step
// produces 8 output pixels per row. ACCUMULATORS selects how many independent
// partial sums are kept per output vector before they are added together.
//
// Parameters (as used by the code below):
//   input_height - number of input rows; also used as the number of output rows.
//   input_width  - row width in BYTES (asserted to be a multiple of
//                  sizeof(__fp16)); also used as the byte stride between
//                  consecutive input rows and consecutive output rows.
//   input        - first input row.
//   weights      - 10 fp16 values: element 0 seeds every accumulator, elements
//                  1..3 multiply the left/center/right taps of the first row of
//                  the window, 4..6 the second row, 7..9 the third row.
//   zero         - row substituted for the row above the first input row and
//                  for rows below the last input row.
//                  NOTE(review): presumably a zero-filled buffer at least one
//                  (rounded-up) row long - confirm against the caller.
//   output       - first output row.
//   padding_top  - must be 1.
//   params       - supplies the tail mask (maskx8) and the min/max clamp bounds.
//
// NOTE(review): rows are always loaded as whole 8-element vectors, so the last
// load of a row can extend past the row end; this is presumably what the
// XNN_OOB_READS annotation declares - confirm against its definition.
16void xnn_f16_dwconv2d_chw_ukernel_3x3p1__neonfp16arith_${ROW_TILE}x8${"_acc%d" % ACCUMULATORS if ACCUMULATORS > 1 else ""}(
17    size_t input_height,
18    size_t input_width,
19    const void* input,
20    const void* weights,
21    const void* zero,
22    void* output,
23    uint32_t padding_top,
24    const union xnn_f16_chw_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
25{
26  assert(input_height != 0);
27  assert(input_width != 0);
  // input_width is expressed in bytes, so it must hold a whole number of fp16 elements.
28  assert(input_width % sizeof(__fp16) == 0);
29  assert(padding_top == 1);
30
  // vmask is ANDed with the last vector of every row (see the tail block below).
  // NOTE(review): presumably all-ones in the lanes that hold valid trailing
  // pixels and zero elsewhere - confirm against the params initializer.
31  const uint16x8_t vmask = vld1q_u16(params->neonfp16arith.maskx8);
  // Output clamp bounds, broadcast to all 8 lanes.
32  const float16x8_t vmax = vld1q_dup_f16(&params->neonfp16arith.max);
33  const float16x8_t vmin = vld1q_dup_f16(&params->neonfp16arith.min);
34
  // Weights 0..7 live in one 8-lane vector; weights 8 and 9 are loaded together
  // as a single 32-bit lane into the low half of vw89.
35  const __fp16* w0 = (const __fp16*)weights;
36  const float16x8_t vw01234567 = vld1q_f16(w0);
37  const float16x4_t vw89 = vreinterpret_f16_u32(vld1_lane_u32((const void*)(w0 + 8), vmov_n_u32(0), 0));
38
  // Bytes by which a row pointer advances during one pass over a row: the row
  // is consumed in whole 8-element vectors, i.e. the width rounded up to a
  // multiple of 8 elements. Used below to rewind pointers to the row start.
39  const size_t input_decrement = round_up_po2(input_width, 8 * sizeof(__fp16));
40
  // i0 is the row above the first output row (the top padding row), i1 the
  // first real input row, and each further pointer is one row stride below.
41  const __fp16* i0 = zero;
42  const __fp16* i1 = input;
43  $for M in range(2, 2 + ROW_TILE):
44    const __fp16* i${M} = (const __fp16*) ((uintptr_t) i${M-1} + input_width);
45
  // One output pointer per row of the tile, one row stride apart.
46  __fp16* o0 = output;
47  $for M in range(1, ROW_TILE):
48    __fp16* o${M} = (__fp16*) ((uintptr_t) o${M-1} + input_width);
49
50  size_t output_height = input_height;
51  do {
    // Fewer than M rows remain: input row i<M> lies below the image, so read
    // it from `zero` instead. The matching surplus output pointer is aliased
    // to the previous row's pointer; stores below are issued from the highest
    // row index to the lowest, so the lower-indexed row's data is written last
    // to the shared location.
52    $for M in range(2, 2 + ROW_TILE):
53      if XNN_UNPREDICTABLE(output_height < ${M}) {
54        i${M} = zero;
55        $if M <= ROW_TILE:
56          o${M-1} = o${M-2};
57      }
58
    // "Previous" 8 pixels of each row; zero at the start of a row, which
    // supplies the left padding column.
59    $for M in range(2 + ROW_TILE):
60      float16x8_t vi${M}x01234567 = vmovq_n_f16(0);
61
    // "Current" 8 pixels of each row.
62    $for M in range(2 + ROW_TILE):
63      float16x8_t vi${M}x89ABCDEF = vld1q_f16(i${M}); i${M} += 8;
64
    // Main loop: runs while more than 8 pixels (w is in bytes) remain, so a
    // following vector is always available for the right-hand neighbours.
65    size_t w = input_width;
66    for (; w > 8 * sizeof(__fp16); w -= 8 * sizeof(__fp16)) {
      // Seed the first accumulator of every output row with weight 0 broadcast to all lanes.
67      $for M in range(ROW_TILE):
68        float16x8_t vo${M}p0 = vdupq_lane_f16(vget_low_f16(vw01234567), 0);
69
      // "Next" 8 pixels of each row.
70      $for M in range(2 + ROW_TILE):
71        const float16x8_t vi${M}xGHIJKLMN = vld1q_f16(i${M}); i${M} += 8;
72
73      // Center column
      // Taps directly above/at/below the output pixel: weights 2, 5 and 8.
74      $for M in range(ROW_TILE):
75        vo${M}p0 = vfmaq_lane_f16(vo${M}p0, vi${M}x89ABCDEF, vget_low_f16(vw01234567), 2);
76
      // With 2+ accumulators the second partial sum starts here as a plain
      // product; otherwise the term is folded into p0.
77      $for M in range(ROW_TILE):
78        $if ACCUMULATORS >= 2:
79          float16x8_t vo${M}p1 = vmulq_lane_f16(vi${M+1}x89ABCDEF, vget_high_f16(vw01234567), 1);
80        $else:
81          vo${M}p0 = vfmaq_lane_f16(vo${M}p0, vi${M+1}x89ABCDEF, vget_high_f16(vw01234567), 1);
82
83      $for M in range(ROW_TILE):
84        $if ACCUMULATORS >= 3:
85          float16x8_t vo${M}p2 = vmulq_lane_f16(vi${M+2}x89ABCDEF, vw89, 0);
86        $else:
87          vo${M}p0 = vfmaq_lane_f16(vo${M}p0, vi${M+2}x89ABCDEF, vw89, 0);
88
89      // Left column
      // Last lane of the previous vector followed by the first 7 lanes of the
      // current one: the left-hand neighbour of every current pixel.
90      $for M in range(2 + ROW_TILE):
91        const float16x8_t vi${M}x789ABCDE = vextq_f16(vi${M}x01234567, vi${M}x89ABCDEF, 7);
92
      // Left taps use weights 1, 4 and 7. The remaining six terms are spread
      // round-robin over the accumulators (term index modulo ACCUMULATORS).
93      $for M in range(ROW_TILE):
94        $if ACCUMULATORS >= 4:
95          float16x8_t vo${M}p3 = vmulq_lane_f16(vi${M}x789ABCDE, vget_low_f16(vw01234567), 1);
96        $else:
97          vo${M}p${3 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${3 % ACCUMULATORS}, vi${M}x789ABCDE, vget_low_f16(vw01234567), 1);
98
99      $for M in range(ROW_TILE):
100        vo${M}p${4 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${4 % ACCUMULATORS}, vi${M+1}x789ABCDE, vget_high_f16(vw01234567), 0);
101
102      $for M in range(ROW_TILE):
103        vo${M}p${5 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${5 % ACCUMULATORS}, vi${M+2}x789ABCDE, vget_high_f16(vw01234567), 3);
104
      // Slide the window: the current vector becomes the previous one for the
      // next iteration.
105      $for M in range(2 + ROW_TILE):
106        vi${M}x01234567 = vi${M}x89ABCDEF;
107
108      // Right column
      // Lanes 1..7 of the current vector followed by the first lane of the
      // next one: the right-hand neighbour of every current pixel.
109      $for M in range(2 + ROW_TILE):
110        const float16x8_t vi${M}x9ABCDEFG = vextq_f16(vi${M}x89ABCDEF, vi${M}xGHIJKLMN, 1);
111
      // Right taps use weights 3, 6 and 9.
112      $for M in range(ROW_TILE):
113        vo${M}p${6 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${6 % ACCUMULATORS}, vi${M}x9ABCDEFG, vget_low_f16(vw01234567), 3);
114
115      $for M in range(ROW_TILE):
116        vo${M}p${7 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${7 % ACCUMULATORS}, vi${M+1}x9ABCDEFG, vget_high_f16(vw01234567), 2);
117
118      $for M in range(ROW_TILE):
119        vo${M}p${8 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${8 % ACCUMULATORS}, vi${M+2}x9ABCDEFG, vw89, 1);
120
      // Slide the window: the next vector becomes the current one.
121      $for M in range(2 + ROW_TILE):
122        vi${M}x89ABCDEF = vi${M}xGHIJKLMN;
123
      // Pairwise (tree) reduction of the partial sums into p0.
124      $if ACCUMULATORS > 1:
125        $ACC_SLICE = 1
126        $while ACC_SLICE < ACCUMULATORS:
127          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
128            $if A + ACC_SLICE < ACCUMULATORS:
129              $for M in range(ROW_TILE):
130                vo${M}p${A} = vaddq_f16(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
131          $ACC_SLICE *= 2
132
      // Clamp the result to [min, max].
133      $for M in range(ROW_TILE):
134        float16x8_t vo${M} = vmaxq_f16(vo${M}p0, vmin);
135
136      $for M in range(ROW_TILE):
137        vo${M} = vminq_f16(vo${M}, vmax);
138
      // Store from the highest row index down to row 0 (see the output
      // pointer aliasing at the top of the outer loop).
139      $for M in reversed(range(ROW_TILE)):
140        vst1q_f16(o${M}, vo${M}); o${M} += 8;
141    }
142    // Always process the last block of 1..8 pixels.
143    assert(w >= 1 * sizeof(__fp16));
144    assert(w <= 8 * sizeof(__fp16));
145    {
      // Same computation as the main loop, except that no further vector is
      // loaded: the right-hand neighbour beyond the block is a zero vector.
146      $for M in range(ROW_TILE):
147        float16x8_t vo${M}p0 = vdupq_lane_f16(vget_low_f16(vw01234567), 0);
148
      // Apply the tail mask to the final vector of every row (bitwise AND on
      // the raw 16-bit lanes).
149      $for M in range(2 + ROW_TILE):
150        vi${M}x89ABCDEF = vreinterpretq_f16_u16(vandq_u16(vmask, vreinterpretq_u16_f16(vi${M}x89ABCDEF)));
151
152      // Center column
153      $for M in range(ROW_TILE):
154        vo${M}p0 = vfmaq_lane_f16(vo${M}p0, vi${M}x89ABCDEF, vget_low_f16(vw01234567), 2);
155
156      $for M in range(ROW_TILE):
157        $if ACCUMULATORS >= 2:
158          float16x8_t vo${M}p1 = vmulq_lane_f16(vi${M+1}x89ABCDEF, vget_high_f16(vw01234567), 1);
159        $else:
160          vo${M}p0 = vfmaq_lane_f16(vo${M}p0, vi${M+1}x89ABCDEF, vget_high_f16(vw01234567), 1);
161
162      $for M in range(ROW_TILE):
163        $if ACCUMULATORS >= 3:
164          float16x8_t vo${M}p2 = vmulq_lane_f16(vi${M+2}x89ABCDEF, vw89, 0);
165        $else:
166          vo${M}p0 = vfmaq_lane_f16(vo${M}p0, vi${M+2}x89ABCDEF, vw89, 0);
167
168      // Left column
169      $for M in range(2 + ROW_TILE):
170        const float16x8_t vi${M}x789ABCDE = vextq_f16(vi${M}x01234567, vi${M}x89ABCDEF, 7);
171
172      $for M in range(ROW_TILE):
173        $if ACCUMULATORS >= 4:
174          float16x8_t vo${M}p3 = vmulq_lane_f16(vi${M}x789ABCDE, vget_low_f16(vw01234567), 1);
175        $else:
176          vo${M}p${3 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${3 % ACCUMULATORS}, vi${M}x789ABCDE, vget_low_f16(vw01234567), 1);
177
178      $for M in range(ROW_TILE):
179        vo${M}p${4 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${4 % ACCUMULATORS}, vi${M+1}x789ABCDE, vget_high_f16(vw01234567), 0);
180
181      $for M in range(ROW_TILE):
182        vo${M}p${5 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${5 % ACCUMULATORS}, vi${M+2}x789ABCDE, vget_high_f16(vw01234567), 3);
183
184      // Right column
      // A zero vector stands in for the pixels to the right of the last block.
185      const float16x8_t vzero = vmovq_n_f16(0);
186      $for M in range(2 + ROW_TILE):
187        const float16x8_t vi${M}x9ABCDEFG = vextq_f16(vi${M}x89ABCDEF, vzero, 1);
188
189      $for M in range(ROW_TILE):
190        vo${M}p${6 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${6 % ACCUMULATORS}, vi${M}x9ABCDEFG, vget_low_f16(vw01234567), 3);
191
192      $for M in range(ROW_TILE):
193        vo${M}p${7 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${7 % ACCUMULATORS}, vi${M+1}x9ABCDEFG, vget_high_f16(vw01234567), 2);
194
195      $for M in range(ROW_TILE):
196        vo${M}p${8 % ACCUMULATORS} = vfmaq_lane_f16(vo${M}p${8 % ACCUMULATORS}, vi${M+2}x9ABCDEFG, vw89, 1);
197
      // Pairwise (tree) reduction of the partial sums into p0.
198      $if ACCUMULATORS > 1:
199        $ACC_SLICE = 1
200        $while ACC_SLICE < ACCUMULATORS:
201          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
202            $if A + ACC_SLICE < ACCUMULATORS:
203              $for M in range(ROW_TILE):
204                vo${M}p${A} = vaddq_f16(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
205          $ACC_SLICE *= 2
206
      // Clamp the result to [min, max].
207      $for M in range(ROW_TILE):
208        float16x8_t vo${M} = vmaxq_f16(vo${M}p0, vmin);
209
210      $for M in range(ROW_TILE):
211        vo${M} = vminq_f16(vo${M}, vmax);
212
      // Exactly 8 pixels left: store the whole vector. Otherwise write the
      // remaining 1..7 pixels in pieces of 4, 2 and 1 elements, selected by
      // the corresponding bits of the remaining byte count w.
213      if XNN_LIKELY(w == 8 * sizeof(__fp16)) {
214        $for M in reversed(range(ROW_TILE)):
215          vst1q_f16(o${M}, vo${M}); o${M} += 8;
216      } else {
217        $for M in reversed(range(ROW_TILE)):
218          float16x4_t vo${M}_lo = vget_low_f16(vo${M});
219
220        if (w & (4 * sizeof(__fp16))) {
221         $for M in reversed(range(ROW_TILE)):
222            vst1_f16(o${M}, vo${M}_lo); o${M} += 4;
223
          // The low half has been written; continue with the high half.
224          $for M in reversed(range(ROW_TILE)):
225            vo${M}_lo = vget_high_f16(vo${M});
226        }
227        if (w & (2 * sizeof(__fp16))) {
          // Two fp16 values are stored at once as a single 32-bit lane.
228          $for M in reversed(range(ROW_TILE)):
229            vst1_lane_u32((void*) o${M}, vreinterpret_u32_f16(vo${M}_lo), 0); o${M} += 2;
230
          // Rotate by two lanes so the next unwritten value sits in lane 0.
231          $for M in range(ROW_TILE):
232            vo${M}_lo = vext_f16(vo${M}_lo, vo${M}_lo, 2);
233        }
234        if (w & (1 * sizeof(__fp16))) {
235          $for M in reversed(range(ROW_TILE)):
236            vst1_lane_f16(o${M}, vo${M}_lo, 0); o${M} += 1;
237        }
238      }
239    }
240
    // Set up the next tile: its first two rows are the last two rows read in
    // this pass, rewound to their starts by input_decrement; the remaining row
    // pointers follow at one row stride each.
241    i0 = (const __fp16*) ((uintptr_t) i${ROW_TILE} - input_decrement);
242    i1 = (const __fp16*) ((uintptr_t) i${ROW_TILE+1} - input_decrement);
243    $for M in range(2, 2 + ROW_TILE):
244      i${M} = (const __fp16*) ((uintptr_t) i${M-1} + input_width);
245
    // Output continues right after the last row written in this pass (the
    // pointers were already advanced past the row by the stores above).
246    $if ROW_TILE > 1:
247      o0 = o${ROW_TILE - 1};
248      $for M in range(1, ROW_TILE):
249        o${M} = (__fp16*) ((uintptr_t) o${M-1} + input_width);
250
    // For ROW_TILE > 1 the remaining row count is reduced via doz(); for
    // ROW_TILE == 1 it is decremented in the loop condition instead.
    // NOTE(review): doz() is presumably a "difference or zero" (saturating
    // subtraction) helper from xnnpack/math.h - confirm.
251    $if ROW_TILE > 1:
252      output_height = doz(output_height, ${ROW_TILE});
253  } while (${"--" if ROW_TILE == 1 else ""}output_height != 0);
254}
255