xref: /aosp_15_r20/external/XNNPACK/src/f16-ibilinear-chw/neonfp16arith.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1// Copyright 2022 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert PIXEL_TILE >= 1
7$assert PIXEL_TILE % 4 == 0
$assert PIXEL_TILE == 4 or PIXEL_TILE % 8 == 0, "tiles wider than 4 pixels combine pairs of 4-pixel groups into 8-lane vectors, so they must be a multiple of 8"
8$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
9#include <assert.h>
10
11#include <arm_neon.h>
12
13#include <xnnpack/ibilinear.h>
14
15
// Indirect bilinear interpolation for fp16 data in CHW layout, using NEON fp16
// arithmetic. The _p suffix of the generated name is the number of pixels the
// main loop processes per iteration.
//
// Parameters, as used below:
//   output_pixels   - number of results produced per channel (must be nonzero).
//   channels        - number of channels to process (must be nonzero).
//   input           - 2 * output_pixels row pointers: for output pixel k,
//                     input[2*k] is the top row and input[2*k+1] the bottom row.
//                     input_offset bytes are added to each pointer, and two
//                     adjacent fp16 values (left, right) are read from it with a
//                     single 32-bit lane load.
//   input_offset    - byte offset added to every input pointer (first channel).
//   weights         - 2 * output_pixels fp16 values: weights[2*k] is alpha_h
//                     (applied to right - left) and weights[2*k+1] is alpha_v
//                     (applied to bottom - top) for pixel k.
//   output          - receives output_pixels fp16 results per channel; the
//                     results of consecutive channels are stored back to back.
//   input_increment - bytes added to input_offset after each channel; must be a
//                     multiple of sizeof(__fp16).
//
// The same input pointers and weights are reused for every channel.
// NOTE(review): XNN_OOB_READS is defined elsewhere; presumably it marks that the
// kernel may read past the end of its buffers. Confirm against its definition.
16void xnn_f16_ibilinear_chw_ukernel__neonfp16arith_p${PIXEL_TILE}(
17    size_t output_pixels,
18    size_t channels,
19    const void**restrict input,
20    size_t input_offset,
21    const void*restrict weights,
22    void*restrict output,
23    size_t input_increment) XNN_OOB_READS
24{
25  assert(output_pixels != 0);
26  assert(channels != 0);
27  assert(input_increment % sizeof(__fp16) == 0);
28
29  __fp16* o = (__fp16*) output;
30  do {
31    const __fp16** i = (const __fp16**)input;
32    const __fp16* w = weights;
33    size_t p = output_pixels;
34
    // Main loop (only emitted when the tile is wider than 4 pixels): one full
    // tile of pixels per iteration, computed with 8-lane vectors.
35    $if PIXEL_TILE > 4:
36      for (; p >= ${PIXEL_TILE}; p -= ${PIXEL_TILE}) {
37        $for P in range(PIXEL_TILE):
38          const __fp16* itl${ABC[P]} = (const __fp16*) ((uintptr_t) i[${2 * P}] + input_offset);
39          const __fp16* ibl${ABC[P]} = (const __fp16*) ((uintptr_t) i[${2 * P + 1}] + input_offset);
40        i += 2 * ${PIXEL_TILE};
41
        // vld2 de-interleaves the [alpha_h, alpha_v] pairs of 4 pixels:
        // val[0] holds the 4 alpha_h values, val[1] the 4 alpha_v values.
42        $for P in range(0, PIXEL_TILE, 4):
43          const float16x4x2_t vw${ABC[P:P+4]} = vld2_f16(w + ${2 * P});
44        w += 2 * ${PIXEL_TILE};
45
        // Gather: each 32-bit lane load reads the adjacent (left, right) fp16
        // pair at one row pointer, so each 4-pixel vtltr vector holds
        // [tl, tr] pairs and each vblbr vector holds [bl, br] pairs.
46        $for P in range(0, PIXEL_TILE, 4):
47          float16x8_t vtltr${ABC[P:P+4]} = vmovq_n_f16(0);  // vmov for uninitialized var warning
48          float16x8_t vblbr${ABC[P:P+4]} = vmovq_n_f16(0);
49          $for L in range(0, 4):
50            vtltr${ABC[P:P+4]} = vreinterpretq_f16_u32(vld1q_lane_u32((const void*) itl${ABC[P+L]}, vreinterpretq_u32_f16(vtltr${ABC[P:P+4]}), ${L}));
51            vblbr${ABC[P:P+4]} = vreinterpretq_f16_u32(vld1q_lane_u32((const void*) ibl${ABC[P+L]}, vreinterpretq_u32_f16(vblbr${ABC[P:P+4]}), ${L}));
52
        // Combine the weights of two 4-pixel groups into 8-lane alpha vectors.
53        $for P in range(0, PIXEL_TILE, 8):
54          const float16x8_t valphah${ABC[P:P+8]} = vcombine_f16(vw${ABC[P:P+4]}.val[0], vw${ABC[P+4:P+8]}.val[0]);
55          const float16x8_t valphav${ABC[P:P+8]} = vcombine_f16(vw${ABC[P:P+4]}.val[1], vw${ABC[P+4:P+8]}.val[1]);
56
        // Vertical differences (bottom - top), still interleaved as left/right pairs.
57        $for P in range(0, PIXEL_TILE, 4):
58          const float16x8_t vldrd${ABC[P:P+4]} = vsubq_f16(vblbr${ABC[P:P+4]}, vtltr${ABC[P:P+4]});
59
        // De-interleave two 4-pixel groups (here and for vtltr below): even
        // lanes are the left column, odd lanes the right column.
60        $for P in range(0, PIXEL_TILE, 8):
61          const float16x8x2_t vld_t${ABC[P:P+8]} = vuzpq_f16(vldrd${ABC[P:P+4]}, vldrd${ABC[P+4:P+8]});
62          const float16x8_t vld${ABC[P:P+8]} = vld_t${ABC[P:P+8]}.val[0];
63          const float16x8_t vrd${ABC[P:P+8]} = vld_t${ABC[P:P+8]}.val[1];
64
65        $for P in range(0, PIXEL_TILE, 8):
66          const float16x8x2_t vtl_t${ABC[P:P+8]} = vuzpq_f16(vtltr${ABC[P:P+4]}, vtltr${ABC[P+4:P+8]});
67          const float16x8_t vtl${ABC[P:P+8]} = vtl_t${ABC[P:P+8]}.val[0];
68          const float16x8_t vtr${ABC[P:P+8]} = vtl_t${ABC[P:P+8]}.val[1];
69
        // left = tl + alpha_v * (bl - tl); right = tr + alpha_v * (br - tr).
70        $for P in range(0, PIXEL_TILE, 8):
71          const float16x8_t vl${ABC[P:P+8]} = vfmaq_f16(vtl${ABC[P:P+8]}, vld${ABC[P:P+8]}, valphav${ABC[P:P+8]});
72          const float16x8_t vr${ABC[P:P+8]} = vfmaq_f16(vtr${ABC[P:P+8]}, vrd${ABC[P:P+8]}, valphav${ABC[P:P+8]});
73
        // result = left + alpha_h * (right - left).
74        $for P in range(0, PIXEL_TILE, 8):
75          const float16x8_t vd${ABC[P:P+8]} = vsubq_f16(vr${ABC[P:P+8]}, vl${ABC[P:P+8]});
76        $for P in range(0, PIXEL_TILE, 8):
77          const float16x8_t vo${ABC[P:P+8]} = vfmaq_f16(vl${ABC[P:P+8]}, vd${ABC[P:P+8]}, valphah${ABC[P:P+8]});
78
79        $for P in range(0, PIXEL_TILE, 8):
80          vst1q_f16(o + ${P}, vo${ABC[P:P+8]});
81        o += ${PIXEL_TILE};
82      }
83
    // Process pixels 4 at a time (this is the only full-width loop when the
    // tile is 4 pixels; otherwise it handles the leftover group of 4).
84    for (; p >= 4; p -= 4) {
85      $for P in range(4):
86        const __fp16* itl${ABC[P]} = (const __fp16*) ((uintptr_t) i[${2 * P}] + input_offset);
87        const __fp16* ibl${ABC[P]} = (const __fp16*) ((uintptr_t) i[${2 * P + 1}] + input_offset);
88      i += 8;
89
90      const float16x4x2_t vw = vld2_f16(w);
91      w += 8;
92
      // Gather the (left, right) fp16 pairs of 4 pixels, one 32-bit lane per pixel.
93      float16x8_t vtltr = vmovq_n_f16(0);  // vmov for uninitialized var warning
94      float16x8_t vblbr = vmovq_n_f16(0);
95      $for P in range(0, 4):
96        vtltr = vreinterpretq_f16_u32(vld1q_lane_u32((const void*) itl${ABC[P]}, vreinterpretq_u32_f16(vtltr), ${P}));
97        vblbr = vreinterpretq_f16_u32(vld1q_lane_u32((const void*) ibl${ABC[P]}, vreinterpretq_u32_f16(vblbr), ${P}));
98
99      const float16x4_t valphah = vw.val[0];
100      const float16x4_t valphav = vw.val[1];
101
102      const float16x8_t vldrd = vsubq_f16(vblbr, vtltr);
103
104      const float16x4x2_t vld_t = vuzp_f16(vget_low_f16(vldrd), vget_high_f16(vldrd));
105      const float16x4_t vld = vld_t.val[0];
106      const float16x4_t vrd = vld_t.val[1];
107
108      const float16x4x2_t vtl_t = vuzp_f16(vget_low_f16(vtltr), vget_high_f16(vtltr));
109      const float16x4_t vtl = vtl_t.val[0];
110      const float16x4_t vtr = vtl_t.val[1];
111
112      const float16x4_t vl = vfma_f16(vtl, vld, valphav);
113      const float16x4_t vr = vfma_f16(vtr, vrd, valphav);
114
115      const float16x4_t vd = vsub_f16(vr, vl);
116      const float16x4_t vo = vfma_f16(vl, vd, valphah);
117
118      vst1_f16(o, vo);
119      o += 4;
120    }
121
    // Tail: 1 to 3 pixels remain; handle 2 pixels, then 1.
122    if XNN_UNLIKELY(p != 0) {
123      if (p & 2) {
124        $for P in range(2):
125          const __fp16* itl${ABC[P]} = (const __fp16*) ((uintptr_t) i[${2 * P}] + input_offset);
126          const __fp16* ibl${ABC[P]} = (const __fp16*) ((uintptr_t) i[${2 * P + 1}] + input_offset);
127        i += 4;
128
        // 4 weights = [alpha_h, alpha_v] for 2 pixels; vuzp separates alpha_h
        // (val[0]) from alpha_v (val[1]).
129        const float16x4_t vw = vld1_f16(w);
130        w += 4;
131
132        const float16x4x2_t vwhv = vuzp_f16(vw, vw);
133        const float16x4_t valphah = vwhv.val[0];
134        const float16x4_t valphav = vwhv.val[1];
135
136        float16x4_t vtltr = vmov_n_f16(0);  // vmov for uninitialized var warning
137        float16x4_t vblbr = vmov_n_f16(0);
138
139        $for P in range(0, 2):
140          vtltr = vreinterpret_f16_u32(vld1_lane_u32((const void*) itl${ABC[P]}, vreinterpret_u32_f16(vtltr), ${P}));
141          vblbr = vreinterpret_f16_u32(vld1_lane_u32((const void*) ibl${ABC[P]}, vreinterpret_u32_f16(vblbr), ${P}));
142
143        const float16x4_t vldrd = vsub_f16(vblbr, vtltr);
144
145        const float16x4x2_t vld_t = vuzp_f16(vldrd, vldrd);
146        const float16x4_t vld = vld_t.val[0];
147        const float16x4_t vrd = vld_t.val[1];
148
149        const float16x4x2_t vtl_t = vuzp_f16(vtltr, vtltr);
150        const float16x4_t vtl = vtl_t.val[0];
151        const float16x4_t vtr = vtl_t.val[1];
152
153        const float16x4_t vl = vfma_f16(vtl, vld, valphav);
154        const float16x4_t vr = vfma_f16(vtr, vrd, valphav);
155
156        const float16x4_t vd = vsub_f16(vr, vl);
157        const float16x4_t vo = vfma_f16(vl, vd, valphah);
158
        // Store the 2 fp16 results (lanes 0 and 1) with one 32-bit store.
159        vst1_lane_u32((void*) o, vreinterpret_u32_f16(vo), 0);
160        o += 2;
161      }
162
163      if (p & 1) {
164        // We are computing the following formula:
165        //   result = (1 - alpha_h) * (1 - alpha_v) * top_left +
166        //                 alpha_h  * (1 - alpha_v) * top_right +
167        //            (1 - alpha_h) *      alpha_v  * bottom_left +
168        //                 alpha_h  *      alpha_v  * bottom_right.
169        //
170        // Rearranging gives
171        //   result =    left + alpha_h * (right        - left),
172        // where
173        //   left =  top_left + alpha_v * (bottom_left  - top_left),
174        //  right = top_right + alpha_v * (bottom_right - top_right).
175
176        const __fp16* itl = (const __fp16*) ((uintptr_t) i[0] + input_offset);
177        const __fp16* ibl = (const __fp16*) ((uintptr_t) i[1] + input_offset);
178        i += 2;
179
        // Load exactly the last 2 fp16 weights ([alpha_h, alpha_v]) as one
        // 32-bit lane rather than a full 4-element vector load.
180        float16x4_t vw = vmov_n_f16(0);
181        vw = vreinterpret_f16_u32(vld1_lane_u32((const void*) w, vreinterpret_u32_f16(vw), 0));
182        w += 2;
183
184        const float16x4x2_t vwhv = vuzp_f16(vw, vw);
185        const float16x4_t valphah = vwhv.val[0];
186        const float16x4_t valphav = vwhv.val[1];
187
188        float16x4_t vtltr = vmov_n_f16(0);  // vmov for uninitialized var warning
189        float16x4_t vblbr = vmov_n_f16(0);
190
191        vtltr = vreinterpret_f16_u32(vld1_lane_u32((const void*) itl, vreinterpret_u32_f16(vtltr), 0));
192        vblbr = vreinterpret_f16_u32(vld1_lane_u32((const void*) ibl, vreinterpret_u32_f16(vblbr), 0));
193
194        const float16x4_t vldrd = vsub_f16(vblbr, vtltr);
195
196        const float16x4x2_t vld_t = vuzp_f16(vldrd, vldrd);
197        const float16x4_t vld = vld_t.val[0];
198        const float16x4_t vrd = vld_t.val[1];
199
200        const float16x4x2_t vtl_t = vuzp_f16(vtltr, vtltr);
201        const float16x4_t vtl = vtl_t.val[0];
202        const float16x4_t vtr = vtl_t.val[1];
203
204        const float16x4_t vl = vfma_f16(vtl, vld, valphav);
205        const float16x4_t vr = vfma_f16(vtr, vrd, valphav);
206
207        const float16x4_t vd = vsub_f16(vr, vl);
208        const float16x4_t vo = vfma_f16(vl, vd, valphah);
209
210        vst1_lane_f16(o, vo, 0);
211        o += 1;
212      }
213    }
214
    // Next channel: shift the input byte offset. i and w restart from the
    // beginning of input/weights, while o keeps advancing through output.
215    input_offset += input_increment;
216  } while (--channels != 0);
217}
218