xref: /aosp_15_r20/external/libdav1d/src/arm/64/cdef_tmpl.S (revision c09093415860a1c2373dacd84c4fde00c507cdfd)
1/*
2 * Copyright © 2018, VideoLAN and dav1d authors
3 * Copyright © 2020, Martin Storsjo
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions are met:
8 *
9 * 1. Redistributions of source code must retain the above copyright notice, this
10 *    list of conditions and the following disclaimer.
11 *
12 * 2. Redistributions in binary form must reproduce the above copyright notice,
13 *    this list of conditions and the following disclaimer in the documentation
14 *    and/or other materials provided with the distribution.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28#include "src/arm/asm.S"
29#include "util.S"
30
// Emit the per-direction offset table for block width \w.  Each entry is
// a pair of byte offsets (off1 for tap k=0, off2 for tap k=1) into the
// 16-bit tmp buffer, encoded as y*\stride + x, where \stride is the tmp
// buffer row stride in elements (16 for w=8, 8 for w=4).
.macro dir_table w, stride
const directions\w
        .byte           -1 * \stride + 1, -2 * \stride + 2
        .byte            0 * \stride + 1, -1 * \stride + 2
        .byte            0 * \stride + 1,  0 * \stride + 2
        .byte            0 * \stride + 1,  1 * \stride + 2
        .byte            1 * \stride + 1,  2 * \stride + 2
        .byte            1 * \stride + 0,  2 * \stride + 1
        .byte            1 * \stride + 0,  2 * \stride + 0
        .byte            1 * \stride + 0,  2 * \stride - 1
// Repeated, to avoid & 7 when the filter reads the secondary-tap
// entries at dir+2 and dir+6 (see the x5 increments in filter_func).
        .byte           -1 * \stride + 1, -2 * \stride + 2
        .byte            0 * \stride + 1, -1 * \stride + 2
        .byte            0 * \stride + 1,  0 * \stride + 2
        .byte            0 * \stride + 1,  1 * \stride + 2
        .byte            1 * \stride + 1,  2 * \stride + 2
        .byte            1 * \stride + 0,  2 * \stride + 1
endconst
.endm
50
// Instantiate the direction tables for both block widths, plus the
// primary-tap weights: the {4, 2} pair or the {3, 3} pair, selected in
// filter_func by (pri_strength >> bitdepth_min_8) & 1.
.macro tables
dir_table 8, 16
dir_table 4, 8

const pri_taps
        .byte           4, 2, 3, 3
endconst
.endm
59
// Load the mirrored pixel pair p0/p1 at x2 +/- off into \d1/\d2, where
// the signed byte offset is in w9 and is scaled by 2 for the 16-bit tmp
// elements.  For \w == 4, two consecutive 4-pixel rows are packed into
// one vector each (tmp row stride is 8 elements = 2*8 bytes).
.macro load_px d1, d2, w
.if \w == 8
        add             x6,  x2,  w9, sxtb #1       // x + off
        sub             x9,  x2,  w9, sxtb #1       // x - off
        ld1             {\d1\().8h}, [x6]           // p0
        ld1             {\d2\().8h}, [x9]           // p1
.else
        add             x6,  x2,  w9, sxtb #1       // x + off
        sub             x9,  x2,  w9, sxtb #1       // x - off
        ld1             {\d1\().4h}, [x6]           // p0
        add             x6,  x6,  #2*8              // += stride
        ld1             {\d2\().4h}, [x9]           // p1
        add             x9,  x9,  #2*8              // += stride
        ld1             {\d1\().d}[1], [x6]         // p0, second row
        ld1             {\d2\().d}[1], [x9]         // p1, second row
.endif
.endm
// Accumulate one tap pair into the running sum in v1:
//     sum += taps[k] * constrain(p - px, threshold, shift)
// for both p0 (\s1) and p1 (\s2), with px in v0.  \thresh_vec holds the
// strength, \shift the negated shift amount (USHL with a negative count
// shifts right), and \tap the tap weight.  When \min is set, the loaded
// pixels are also folded into the running min (v2) and max (v3) used to
// clip the final output.
.macro handle_pixel s1, s2, thresh_vec, shift, tap, min
.if \min
        umin            v2.8h,   v2.8h,  \s1\().8h
        smax            v3.8h,   v3.8h,  \s1\().8h
        umin            v2.8h,   v2.8h,  \s2\().8h
        smax            v3.8h,   v3.8h,  \s2\().8h
.endif
        uabd            v16.8h, v0.8h,  \s1\().8h   // abs(diff)
        uabd            v20.8h, v0.8h,  \s2\().8h   // abs(diff)
        ushl            v17.8h, v16.8h, \shift      // abs(diff) >> shift
        ushl            v21.8h, v20.8h, \shift      // abs(diff) >> shift
        uqsub           v17.8h, \thresh_vec, v17.8h // clip = imax(0, threshold - (abs(diff) >> shift))
        uqsub           v21.8h, \thresh_vec, v21.8h // clip = imax(0, threshold - (abs(diff) >> shift))
        sub             v18.8h, \s1\().8h,  v0.8h   // diff = p0 - px
        sub             v22.8h, \s2\().8h,  v0.8h   // diff = p1 - px
        neg             v16.8h, v17.8h              // -clip
        neg             v20.8h, v21.8h              // -clip
        smin            v18.8h, v18.8h, v17.8h      // imin(diff, clip)
        smin            v22.8h, v22.8h, v21.8h      // imin(diff, clip)
        dup             v19.8h, \tap                // taps[k]
        smax            v18.8h, v18.8h, v16.8h      // constrain() = imax(imin(diff, clip), -clip)
        smax            v22.8h, v22.8h, v20.8h      // constrain() = imax(imin(diff, clip), -clip)
        mla             v1.8h,  v18.8h, v19.8h      // sum += taps[k] * constrain()
        mla             v1.8h,  v22.8h, v19.8h      // sum += taps[k] * constrain()
.endm
102
// void dav1d_cdef_filterX_Ybpc_neon(pixel *dst, ptrdiff_t dst_stride,
//                                   const uint16_t *tmp, int pri_strength,
//                                   int sec_strength, int dir, int damping,
//                                   int h, size_t edges);
//
// Template for one CDEF filter specialization: \w is the block width
// (4 or 8) and \bpc the bitdepth class (8 or 16).  \pri and \sec select
// which tap sets are applied; \min additionally tracks the min/max of
// the sampled pixels so the result can be clipped (used when both tap
// sets are active).  Register roles: x0/x1 dst/dst_stride, x2 tmp,
// w3 pri_strength, w4 sec_strength, w5 dir (becomes the directions
// pointer), w6 damping, w7 h; edges is read from the stack.
.macro filter_func w, bpc, pri, sec, min, suffix
function cdef_filter\w\suffix\()_\bpc\()bpc_neon
.if \bpc == 8
        ldr             w8,  [sp]                   // edges
        cmp             w8,  #0xf
        // All four edges present: take the specialized 8 bpc path
        // (defined elsewhere) that works on the original pixels.
        b.eq            cdef_filter\w\suffix\()_edged_8bpc_neon
.endif
.if \pri
.if \bpc == 16
        ldr             w9,  [sp, #8]               // bitdepth_max
        clz             w9,  w9
        sub             w9,  w9,  #24               // -bitdepth_min_8
        neg             w9,  w9                     // bitdepth_min_8
.endif
        movrel          x8,  pri_taps
.if \bpc == 16
        lsr             w9,  w3,  w9                // pri_strength >> bitdepth_min_8
        and             w9,  w9,  #1                // (pri_strength >> bitdepth_min_8) & 1
.else
        and             w9,  w3,  #1
.endif
        add             x8,  x8,  w9, uxtw #1       // select the {4,2} or {3,3} tap pair
.endif
        movrel          x9,  directions\w
        add             x5,  x9,  w5, uxtw #1       // x5 = &directions[dir]
        movi            v30.4h,   #15
        dup             v28.4h,   w6                // damping

.if \pri
        dup             v25.8h, w3                  // threshold
.endif
.if \sec
        dup             v27.8h, w4                  // threshold
.endif
        // Compute both shifts at once: lane 0 = pri, lane 1 = sec.
        trn1            v24.4h, v25.4h, v27.4h
        clz             v24.4h, v24.4h              // clz(threshold)
        sub             v24.4h, v30.4h, v24.4h      // ulog2(threshold)
        uqsub           v24.4h, v28.4h, v24.4h      // shift = imax(0, damping - ulog2(threshold))
        neg             v24.4h, v24.4h              // -shift, for USHL right shifts
.if \sec
        dup             v26.8h, v24.h[1]
.endif
.if \pri
        dup             v24.8h, v24.h[0]
.endif

1:      // Row loop: one 8-wide row (w=8) or two 4-wide rows (w=4) per pass.
.if \w == 8
        ld1             {v0.8h}, [x2]               // px
.else
        add             x12, x2,  #2*8
        ld1             {v0.4h},   [x2]             // px, first row
        ld1             {v0.d}[1], [x12]            // px, second row
.endif

        movi            v1.8h,  #0                  // sum
.if \min
        mov             v2.16b, v0.16b              // min
        mov             v3.16b, v0.16b              // max
.endif

        // Instead of loading sec_taps 2, 1 from memory, just set it
        // to 2 initially and decrease for the second round.
        // This is also used as loop counter.
        mov             w11, #2                     // sec_taps[0]

2:      // Tap loop, two iterations (k = 0, 1).
.if \pri
        ldrb            w9,  [x5]                   // off1

        load_px         v4,  v5, \w
.endif

.if \sec
        add             x5,  x5,  #4                // +2*2, entry dir+2
        ldrb            w9,  [x5]                   // off2
        load_px         v6,  v7,  \w
.endif

.if \pri
        ldrb            w10, [x8]                   // *pri_taps

        handle_pixel    v4,  v5,  v25.8h, v24.8h, w10, \min
.endif

.if \sec
        add             x5,  x5,  #8                // +2*4, entry dir+6
        ldrb            w9,  [x5]                   // off3
        load_px         v4,  v5,  \w

        handle_pixel    v6,  v7,  v27.8h, v26.8h, w11, \min

        handle_pixel    v4,  v5,  v27.8h, v26.8h, w11, \min

        sub             x5,  x5,  #11               // x5 -= 2*(2+4); x5 += 1;
.else
        add             x5,  x5,  #1                // x5 += 1
.endif
        subs            w11, w11, #1                // sec_tap-- (value)
.if \pri
        add             x8,  x8,  #1                // pri_taps++ (pointer)
.endif
        b.ne            2b

        // Round the sum toward zero: add the sign bit before the
        // rounding shift so negative sums don't round away from zero.
        cmlt            v4.8h,  v1.8h,  #0          // -(sum < 0)
        add             v1.8h,  v1.8h,  v4.8h       // sum - (sum < 0)
        srshr           v1.8h,  v1.8h,  #4          // (8 + sum - (sum < 0)) >> 4
        add             v0.8h,  v0.8h,  v1.8h       // px + (8 + sum ...) >> 4
.if \min
        smin            v0.8h,  v0.8h,  v3.8h
        smax            v0.8h,  v0.8h,  v2.8h       // iclip(px + .., min, max)
.endif
.if \bpc == 8
        xtn             v0.8b,  v0.8h               // narrow back to 8-bit pixels
.endif
.if \w == 8
        add             x2,  x2,  #2*16             // tmp += tmp_stride
        subs            w7,  w7,  #1                // h--
.if \bpc == 8
        st1             {v0.8b}, [x0], x1
.else
        st1             {v0.8h}, [x0], x1
.endif
.else
.if \bpc == 8
        st1             {v0.s}[0], [x0], x1
.else
        st1             {v0.d}[0], [x0], x1
.endif
        add             x2,  x2,  #2*16             // tmp += 2*tmp_stride
        subs            w7,  w7,  #2                // h -= 2
.if \bpc == 8
        st1             {v0.s}[1], [x0], x1
.else
        st1             {v0.d}[1], [x0], x1
.endif
.endif

        // Reset pri_taps and directions back to the original point
        // (both advanced by 2 over the two tap-loop iterations).
        sub             x5,  x5,  #2
.if \pri
        sub             x8,  x8,  #2
.endif

        b.gt            1b
        ret
endfunc
.endm
255
// Instantiate the three filter specializations for width \w and emit
// the exported entry point, which dispatches on which of the two
// strengths are nonzero (at least one is, per the CDEF call contract
// implied by the three-way branch below).
.macro filter w, bpc
filter_func \w, \bpc, pri=1, sec=0, min=0, suffix=_pri
filter_func \w, \bpc, pri=0, sec=1, min=0, suffix=_sec
filter_func \w, \bpc, pri=1, sec=1, min=1, suffix=_pri_sec

function cdef_filter\w\()_\bpc\()bpc_neon, export=1
        cbnz            w3,  1f // pri_strength
        b               cdef_filter\w\()_sec_\bpc\()bpc_neon     // only sec
1:
        cbnz            w4,  1f // sec_strength
        b               cdef_filter\w\()_pri_\bpc\()bpc_neon     // only pri
1:
        b               cdef_filter\w\()_pri_sec_\bpc\()bpc_neon // both pri and sec
endfunc
.endm
271
// div_table[n] = 840 / (n + 1); weights for the diagonal direction
// costs (cost[0] and cost[4]) in find_dir.
const div_table
        .short         840, 420, 280, 210, 168, 140, 120, 105
endconst
275
// Weights for the "alt" direction costs: the mirrored div_table[2*m+1]
// values (420, 210, 140, 105, ...), padded with a trailing zero so they
// load as three .4h vectors (v29-v31 in find_dir).
const alt_fact
        .short         420, 210, 140, 105, 105, 105, 105, 105, 140, 210, 420, 0
endconst
279
// Compute two "alt" direction costs at once:
//     cost = sum(sum_alt[n]^2 * fact[n])
// for the accumulator pairs (\s1, \s2) and (\s3, \s4), with the per-lane
// factors preloaded in v29-v31 (see alt_fact).  The scalar results land
// in \d1 and \d2.
.macro cost_alt d1, d2, s1, s2, s3, s4
        smull           v22.4s,  \s1\().4h, \s1\().4h // sum_alt[n]*sum_alt[n]
        smull2          v23.4s,  \s1\().8h, \s1\().8h
        smull           v24.4s,  \s2\().4h, \s2\().4h
        smull           v25.4s,  \s3\().4h, \s3\().4h // sum_alt[n]*sum_alt[n]
        smull2          v26.4s,  \s3\().8h, \s3\().8h
        smull           v27.4s,  \s4\().4h, \s4\().4h
        mul             v22.4s,  v22.4s,  v29.4s      // sum_alt[n]^2*fact
        mla             v22.4s,  v23.4s,  v30.4s
        mla             v22.4s,  v24.4s,  v31.4s
        mul             v25.4s,  v25.4s,  v29.4s      // sum_alt[n]^2*fact
        mla             v25.4s,  v26.4s,  v30.4s
        mla             v25.4s,  v27.4s,  v31.4s
        addv            \d1, v22.4s                   // *cost_ptr
        addv            \d2, v25.4s                   // *cost_ptr
.endm
296
// One or two steps of the branchless best-direction scan.
// State: w0 = best_dir, w1 = best_cost, w3 = n.  On entry w4 must
// already hold the cost for entry n; \s1 only names that cost and is
// never read.  When \s2 is given, its cost is compared as entry n+1,
// and the cost for the NEXT invocation is preloaded from \s3 into w4.
.macro find_best s1, s2, s3
.ifnb \s2
        mov             w5,  \s2\().s[0]
.endif
        cmp             w4,  w1                       // cost[n] > best_cost
        csel            w0,  w3,  w0,  gt             // best_dir = n
        csel            w1,  w4,  w1,  gt             // best_cost = cost[n]
.ifnb \s2
        add             w3,  w3,  #1                  // n++
        cmp             w5,  w1                       // cost[n] > best_cost
        mov             w4,  \s3\().s[0]              // preload cost for the next call
        csel            w0,  w3,  w0,  gt             // best_dir = n
        csel            w1,  w5,  w1,  gt             // best_cost = cost[n]
        add             w3,  w3,  #1                  // n++
.endif
.endm
313
// Steps for loading and preparing each row.
// Step 1: load one 8-pixel row from the image pointer (x0), advancing
// it by the stride (x1).  8 bpc rows load as bytes, 16 bpc as halves.
.macro dir_load_step1 s1, bpc
.if \bpc == 8
        ld1             {\s1\().8b}, [x0], x1
.else
        ld1             {\s1\().8h}, [x0], x1
.endif
.endm
322
// Step 2: 8 bpc widens to s16 while subtracting 128 (v31); 16 bpc
// shifts right by bitdepth_min_8 (v8 holds -bitdepth_min_8, and USHL
// with a negative count shifts right), scaling down to the 8-bit range.
.macro dir_load_step2 s1, bpc
.if \bpc == 8
        usubl           \s1\().8h,  \s1\().8b, v31.8b
.else
        ushl            \s1\().8h,  \s1\().8h, v8.8h
.endif
.endm
330
// Step 3: 16 bpc only — subtract 128 (v31.8h) so both bitdepths end up
// with the same bias removed.
.macro dir_load_step3 s1, bpc
// Nothing for \bpc == 8
.if \bpc != 8
        sub             \s1\().8h,  \s1\().8h, v31.8h
.endif
.endm
337
// int dav1d_cdef_find_dir_Xbpc_neon(const pixel *img, const ptrdiff_t stride,
//                                   unsigned *const var)
// (For 16 bpc, w3 additionally carries bitdepth_max.)
//
// Determines the dominant edge direction of an 8x8 block: the fully
// unrolled .irpc loop accumulates the eight directional projections
// (sum_hv, sum_diag, sum_alt) one row at a time, each is scored with
// the div_table/alt_fact weights into cost[0..7] on the stack, and a
// branchless scan picks the largest.  Returns best_dir in w0 and writes
// *var = (best_cost - cost[best_dir ^ 4]) >> 10.
// d8 is callee-saved per AAPCS64, hence the save/restore on the 16 bpc
// path that uses v8.
.macro find_dir bpc
function cdef_find_dir_\bpc\()bpc_neon, export=1
.if \bpc == 16
        str             d8,  [sp, #-0x10]!
        clz             w3,  w3                       // clz(bitdepth_max)
        sub             w3,  w3,  #24                 // -bitdepth_min_8
        dup             v8.8h,   w3
.endif
        sub             sp,  sp,  #32 // cost[8] scratch array
        mov             w3,  #8
.if \bpc == 8
        movi            v31.16b, #128
.else
        movi            v31.8h,  #128
.endif
        movi            v30.16b, #0
        movi            v1.8h,   #0 // v0-v1 sum_diag[0]
        movi            v3.8h,   #0 // v2-v3 sum_diag[1]
        movi            v5.8h,   #0 // v4-v5 sum_hv[0-1]
        movi            v7.8h,   #0 // v6-v7 sum_alt[0]
        dir_load_step1  v26, \bpc       // Setup first row early
        movi            v17.8h,  #0 // v16-v17 sum_alt[1]
        movi            v18.8h,  #0 // v18-v19 sum_alt[2]
        dir_load_step2  v26, \bpc
        movi            v19.8h,  #0
        dir_load_step3  v26, \bpc
        movi            v21.8h,  #0 // v20-v21 sum_alt[3]

// Unrolled row loop; \i is the row index 0..7.  The next row's load is
// interleaved with the current row's accumulation to hide latency.
.irpc i, 01234567
        addv            h25,     v26.8h               // [y]
        rev64           v27.8h,  v26.8h
        addp            v28.8h,  v26.8h,  v30.8h      // [(x >> 1)]
        add             v5.8h,   v5.8h,   v26.8h      // sum_hv[1]
        ext             v27.16b, v27.16b, v27.16b, #8 // [-x]
        rev64           v29.4h,  v28.4h               // [-(x >> 1)]
        ins             v4.h[\i], v25.h[0]            // sum_hv[0]
.if \i < 6
        ext             v22.16b, v30.16b, v26.16b, #(16-2*(3-(\i/2)))
        ext             v23.16b, v26.16b, v30.16b, #(16-2*(3-(\i/2)))
        add             v18.8h,  v18.8h,  v22.8h      // sum_alt[2]
        add             v19.4h,  v19.4h,  v23.4h      // sum_alt[2]
.else
        add             v18.8h,  v18.8h,  v26.8h      // sum_alt[2]
.endif
.if \i == 0
        mov             v20.16b, v26.16b              // sum_alt[3]
.elseif \i == 1
        add             v20.8h,  v20.8h,  v26.8h      // sum_alt[3]
.else
        ext             v24.16b, v30.16b, v26.16b, #(16-2*(\i/2))
        ext             v25.16b, v26.16b, v30.16b, #(16-2*(\i/2))
        add             v20.8h,  v20.8h,  v24.8h      // sum_alt[3]
        add             v21.4h,  v21.4h,  v25.4h      // sum_alt[3]
.endif
.if \i == 0
        mov             v0.16b,  v26.16b              // sum_diag[0]
        dir_load_step1  v26, \bpc
        mov             v2.16b,  v27.16b              // sum_diag[1]
        dir_load_step2  v26, \bpc
        mov             v6.16b,  v28.16b              // sum_alt[0]
        dir_load_step3  v26, \bpc
        mov             v16.16b, v29.16b              // sum_alt[1]
.else
        ext             v22.16b, v30.16b, v26.16b, #(16-2*\i)
        ext             v23.16b, v26.16b, v30.16b, #(16-2*\i)
        ext             v24.16b, v30.16b, v27.16b, #(16-2*\i)
        ext             v25.16b, v27.16b, v30.16b, #(16-2*\i)
.if \i != 7 // Nothing to load for the final row
        dir_load_step1  v26, \bpc // Start setting up the next row early.
.endif
        add             v0.8h,   v0.8h,   v22.8h      // sum_diag[0]
        add             v1.8h,   v1.8h,   v23.8h      // sum_diag[0]
        add             v2.8h,   v2.8h,   v24.8h      // sum_diag[1]
        add             v3.8h,   v3.8h,   v25.8h      // sum_diag[1]
.if \i != 7
        dir_load_step2  v26, \bpc
.endif
        ext             v22.16b, v30.16b, v28.16b, #(16-2*\i)
        ext             v23.16b, v28.16b, v30.16b, #(16-2*\i)
        ext             v24.16b, v30.16b, v29.16b, #(16-2*\i)
        ext             v25.16b, v29.16b, v30.16b, #(16-2*\i)
.if \i != 7
        dir_load_step3  v26, \bpc
.endif
        add             v6.8h,   v6.8h,   v22.8h      // sum_alt[0]
        add             v7.4h,   v7.4h,   v23.4h      // sum_alt[0]
        add             v16.8h,  v16.8h,  v24.8h      // sum_alt[1]
        add             v17.4h,  v17.4h,  v25.4h      // sum_alt[1]
.endif
.endr

        movi            v31.4s,  #105

        smull           v26.4s,  v4.4h,   v4.4h       // sum_hv[0]*sum_hv[0]
        smlal2          v26.4s,  v4.8h,   v4.8h
        smull           v27.4s,  v5.4h,   v5.4h       // sum_hv[1]*sum_hv[1]
        smlal2          v27.4s,  v5.8h,   v5.8h
        mul             v26.4s,  v26.4s,  v31.4s      // cost[2] *= 105
        mul             v27.4s,  v27.4s,  v31.4s      // cost[6] *= 105
        addv            s4,  v26.4s                   // cost[2]
        addv            s5,  v27.4s                   // cost[6]

        rev64           v1.8h,   v1.8h
        rev64           v3.8h,   v3.8h
        ext             v1.16b,  v1.16b,  v1.16b, #10 // sum_diag[0][14-n]
        ext             v3.16b,  v3.16b,  v3.16b, #10 // sum_diag[1][14-n]

        str             s4,  [sp, #2*4]               // cost[2]
        str             s5,  [sp, #6*4]               // cost[6]

        movrel          x4,  div_table
        ld1             {v31.8h}, [x4]

        smull           v22.4s,  v0.4h,   v0.4h       // sum_diag[0]*sum_diag[0]
        smull2          v23.4s,  v0.8h,   v0.8h
        smlal           v22.4s,  v1.4h,   v1.4h
        smlal2          v23.4s,  v1.8h,   v1.8h
        smull           v24.4s,  v2.4h,   v2.4h       // sum_diag[1]*sum_diag[1]
        smull2          v25.4s,  v2.8h,   v2.8h
        smlal           v24.4s,  v3.4h,   v3.4h
        smlal2          v25.4s,  v3.8h,   v3.8h
        uxtl            v30.4s,  v31.4h               // div_table
        uxtl2           v31.4s,  v31.8h
        mul             v22.4s,  v22.4s,  v30.4s      // cost[0]
        mla             v22.4s,  v23.4s,  v31.4s      // cost[0]
        mul             v24.4s,  v24.4s,  v30.4s      // cost[4]
        mla             v24.4s,  v25.4s,  v31.4s      // cost[4]
        addv            s0,  v22.4s                   // cost[0]
        addv            s2,  v24.4s                   // cost[4]

        movrel          x5,  alt_fact
        ld1             {v29.4h, v30.4h, v31.4h}, [x5]// div_table[2*m+1] + 105

        str             s0,  [sp, #0*4]               // cost[0]
        str             s2,  [sp, #4*4]               // cost[4]

        uxtl            v29.4s,  v29.4h               // div_table[2*m+1] + 105
        uxtl            v30.4s,  v30.4h
        uxtl            v31.4s,  v31.4h

        cost_alt        s6,  s16, v6,  v7,  v16, v17  // cost[1], cost[3]
        cost_alt        s18, s20, v18, v19, v20, v21  // cost[5], cost[7]
        str             s6,  [sp, #1*4]               // cost[1]
        str             s16, [sp, #3*4]               // cost[3]

        mov             w0,  #0                       // best_dir
        mov             w1,  v0.s[0]                  // best_cost
        mov             w3,  #1                       // n

        str             s18, [sp, #5*4]               // cost[5]
        str             s20, [sp, #7*4]               // cost[7]

        mov             w4,  v6.s[0]                  // preload cost[1] for find_best

        find_best       v6,  v4, v16
        find_best       v16, v2, v18
        find_best       v18, v5, v20
        find_best       v20

        eor             w3,  w0,  #4                  // best_dir ^4
        ldr             w4,  [sp, w3, uxtw #2]        // cost[best_dir ^ 4]
        sub             w1,  w1,  w4                  // best_cost - cost[best_dir ^ 4]
        lsr             w1,  w1,  #10
        str             w1,  [x2]                     // *var
        // Return value: best_dir already in w0.
        add             sp,  sp,  #32
.if \bpc == 16
        ldr             d8,  [sp], 0x10
.endif
        ret
endfunc
.endm
512