xref: /aosp_15_r20/external/libdav1d/src/refmvs.c (revision c09093415860a1c2373dacd84c4fde00c507cdfd)
1 /*
2  * Copyright © 2020, VideoLAN and dav1d authors
3  * Copyright © 2020, Two Orioles, LLC
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  *    list of conditions and the following disclaimer.
11  *
12  * 2. Redistributions in binary form must reproduce the above copyright notice,
13  *    this list of conditions and the following disclaimer in the documentation
14  *    and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
27 
28 #include "config.h"
29 
30 #include <limits.h>
31 #include <stdlib.h>
32 
33 #include "dav1d/common.h"
34 
35 #include "common/intops.h"
36 
37 #include "src/env.h"
38 #include "src/mem.h"
39 #include "src/refmvs.h"
40 
/*
 * Try to merge the motion vector(s) of spatial neighbour b into mvstack.
 *
 * b contributes only if it is not intra (its first MV is not INVALID_MV) and
 * it references ref.ref[0] on either side (single reference, ref.ref[1] == -1)
 * or exactly the pair ref (compound). When bit 0 of b->mf is set and a
 * non-translational global MV is available (gmv[i] != INVALID_MV), that global
 * MV replaces b's own MV.
 *
 * A candidate equal to an existing entry only adds weight to that entry.
 * Otherwise it is appended while fewer than 8 entries are present. On any
 * reference match, *have_refmv_match is set to 1 and (b->mf >> 1) is OR-ed
 * into *have_newmv_match.
 */
static void add_spatial_candidate(refmvs_candidate *const mvstack, int *const cnt,
                                  const int weight, const refmvs_block *const b,
                                  const union refmvs_refpair ref, const mv gmv[2],
                                  int *const have_newmv_match,
                                  int *const have_refmv_match)
{
    // intra block (intrabc not handled here): nothing to contribute
    if (b->mv.mv[0].n == INVALID_MV)
        return;

    const int prefer_gmv = b->mf & 1;

    if (ref.ref[1] == -1) {
        // single reference: find the first side of b that uses ref.ref[0]
        int side = 0;
        while (side < 2 && b->ref.ref[side] != ref.ref[0])
            side++;
        if (side == 2)
            return;

        const mv cand = (prefer_gmv && gmv[0].n != INVALID_MV) ?
                        gmv[0] : b->mv.mv[side];

        *have_refmv_match = 1;
        *have_newmv_match |= b->mf >> 1;

        const int count = *cnt;
        for (int i = 0; i < count; i++) {
            if (mvstack[i].mv.mv[0].n == cand.n) {
                mvstack[i].weight += weight;
                return;
            }
        }
        if (count < 8) {
            mvstack[count].mv.mv[0] = cand;
            mvstack[count].weight = weight;
            *cnt = count + 1;
        }
        return;
    }

    // compound: both references must match exactly
    if (b->ref.pair != ref.pair)
        return;

    refmvs_mvpair cand;
    for (int side = 0; side < 2; side++)
        cand.mv[side] = (prefer_gmv && gmv[side].n != INVALID_MV) ?
                        gmv[side] : b->mv.mv[side];

    *have_refmv_match = 1;
    *have_newmv_match |= b->mf >> 1;

    const int count = *cnt;
    for (int i = 0; i < count; i++) {
        if (mvstack[i].mv.n == cand.n) {
            mvstack[i].weight += weight;
            return;
        }
    }
    if (count < 8) {
        mvstack[count].mv = cand;
        mvstack[count].weight = weight;
        *cnt = count + 1;
    }
}
96 
scan_row(refmvs_candidate * const mvstack,int * const cnt,const union refmvs_refpair ref,const mv gmv[2],const refmvs_block * b,const int bw4,const int w4,const int max_rows,const int step,int * const have_newmv_match,int * const have_refmv_match)97 static int scan_row(refmvs_candidate *const mvstack, int *const cnt,
98                     const union refmvs_refpair ref, const mv gmv[2],
99                     const refmvs_block *b, const int bw4, const int w4,
100                     const int max_rows, const int step,
101                     int *const have_newmv_match, int *const have_refmv_match)
102 {
103     const refmvs_block *cand_b = b;
104     const enum BlockSize first_cand_bs = cand_b->bs;
105     const uint8_t *const first_cand_b_dim = dav1d_block_dimensions[first_cand_bs];
106     int cand_bw4 = first_cand_b_dim[0];
107     int len = imax(step, imin(bw4, cand_bw4));
108 
109     if (bw4 <= cand_bw4) {
110         // FIXME weight can be higher for odd blocks (bx4 & 1), but then the
111         // position of the first block has to be odd already, i.e. not just
112         // for row_offset=-3/-5
113         // FIXME why can this not be cand_bw4?
114         const int weight = bw4 == 1 ? 2 :
115                            imax(2, imin(2 * max_rows, first_cand_b_dim[1]));
116         add_spatial_candidate(mvstack, cnt, len * weight, cand_b, ref, gmv,
117                               have_newmv_match, have_refmv_match);
118         return weight >> 1;
119     }
120 
121     for (int x = 0;;) {
122         // FIXME if we overhang above, we could fill a bitmask so we don't have
123         // to repeat the add_spatial_candidate() for the next row, but just increase
124         // the weight here
125         add_spatial_candidate(mvstack, cnt, len * 2, cand_b, ref, gmv,
126                               have_newmv_match, have_refmv_match);
127         x += len;
128         if (x >= w4) return 1;
129         cand_b = &b[x];
130         cand_bw4 = dav1d_block_dimensions[cand_b->bs][0];
131         assert(cand_bw4 < bw4);
132         len = imax(step, cand_bw4);
133     }
134 }
135 
scan_col(refmvs_candidate * const mvstack,int * const cnt,const union refmvs_refpair ref,const mv gmv[2],refmvs_block * const * b,const int bh4,const int h4,const int bx4,const int max_cols,const int step,int * const have_newmv_match,int * const have_refmv_match)136 static int scan_col(refmvs_candidate *const mvstack, int *const cnt,
137                     const union refmvs_refpair ref, const mv gmv[2],
138                     /*const*/ refmvs_block *const *b, const int bh4, const int h4,
139                     const int bx4, const int max_cols, const int step,
140                     int *const have_newmv_match, int *const have_refmv_match)
141 {
142     const refmvs_block *cand_b = &b[0][bx4];
143     const enum BlockSize first_cand_bs = cand_b->bs;
144     const uint8_t *const first_cand_b_dim = dav1d_block_dimensions[first_cand_bs];
145     int cand_bh4 = first_cand_b_dim[1];
146     int len = imax(step, imin(bh4, cand_bh4));
147 
148     if (bh4 <= cand_bh4) {
149         // FIXME weight can be higher for odd blocks (by4 & 1), but then the
150         // position of the first block has to be odd already, i.e. not just
151         // for col_offset=-3/-5
152         // FIXME why can this not be cand_bh4?
153         const int weight = bh4 == 1 ? 2 :
154                            imax(2, imin(2 * max_cols, first_cand_b_dim[0]));
155         add_spatial_candidate(mvstack, cnt, len * weight, cand_b, ref, gmv,
156                             have_newmv_match, have_refmv_match);
157         return weight >> 1;
158     }
159 
160     for (int y = 0;;) {
161         // FIXME if we overhang above, we could fill a bitmask so we don't have
162         // to repeat the add_spatial_candidate() for the next row, but just increase
163         // the weight here
164         add_spatial_candidate(mvstack, cnt, len * 2, cand_b, ref, gmv,
165                               have_newmv_match, have_refmv_match);
166         y += len;
167         if (y >= h4) return 1;
168         cand_b = &b[y][bx4];
169         cand_bh4 = dav1d_block_dimensions[cand_b->bs][1];
170         assert(cand_bh4 < bh4);
171         len = imax(step, cand_bh4);
172     }
173 }
174 
mv_projection(const union mv mv,const int num,const int den)175 static inline union mv mv_projection(const union mv mv, const int num, const int den) {
176     static const uint16_t div_mult[32] = {
177            0, 16384, 8192, 5461, 4096, 3276, 2730, 2340,
178         2048,  1820, 1638, 1489, 1365, 1260, 1170, 1092,
179         1024,   963,  910,  862,  819,  780,  744,  712,
180          682,   655,  630,  606,  585,  564,  546,  528
181     };
182     assert(den > 0 && den < 32);
183     assert(num > -32 && num < 32);
184     const int frac = num * div_mult[den];
185     const int y = mv.y * frac, x = mv.x * frac;
186     // Round and clip according to AV1 spec section 7.9.3
187     return (union mv) { // 0x3fff == (1 << 14) - 1
188         .y = iclip((y + 8192 + (y >> 31)) >> 14, -0x3fff, 0x3fff),
189         .x = iclip((x + 8192 + (x >> 31)) >> 14, -0x3fff, 0x3fff)
190     };
191 }
192 
add_temporal_candidate(const refmvs_frame * const rf,refmvs_candidate * const mvstack,int * const cnt,const refmvs_temporal_block * const rb,const union refmvs_refpair ref,int * const globalmv_ctx,const union mv gmv[])193 static void add_temporal_candidate(const refmvs_frame *const rf,
194                                    refmvs_candidate *const mvstack, int *const cnt,
195                                    const refmvs_temporal_block *const rb,
196                                    const union refmvs_refpair ref, int *const globalmv_ctx,
197                                    const union mv gmv[])
198 {
199     if (rb->mv.n == INVALID_MV) return;
200 
201     union mv mv = mv_projection(rb->mv, rf->pocdiff[ref.ref[0] - 1], rb->ref);
202     fix_mv_precision(rf->frm_hdr, &mv);
203 
204     const int last = *cnt;
205     if (ref.ref[1] == -1) {
206         if (globalmv_ctx)
207             *globalmv_ctx = (abs(mv.x - gmv[0].x) | abs(mv.y - gmv[0].y)) >= 16;
208 
209         for (int n = 0; n < last; n++)
210             if (mvstack[n].mv.mv[0].n == mv.n) {
211                 mvstack[n].weight += 2;
212                 return;
213             }
214         if (last < 8) {
215             mvstack[last].mv.mv[0] = mv;
216             mvstack[last].weight = 2;
217             *cnt = last + 1;
218         }
219     } else {
220         refmvs_mvpair mvp = { .mv = {
221             [0] = mv,
222             [1] = mv_projection(rb->mv, rf->pocdiff[ref.ref[1] - 1], rb->ref),
223         }};
224         fix_mv_precision(rf->frm_hdr, &mvp.mv[1]);
225 
226         for (int n = 0; n < last; n++)
227             if (mvstack[n].mv.n == mvp.n) {
228                 mvstack[n].weight += 2;
229                 return;
230             }
231         if (last < 8) {
232             mvstack[last].mv = mvp;
233             mvstack[last].weight = 2;
234             *cnt = last + 1;
235         }
236     }
237 }
238 
add_compound_extended_candidate(refmvs_candidate * const same,int * const same_count,const refmvs_block * const cand_b,const int sign0,const int sign1,const union refmvs_refpair ref,const uint8_t * const sign_bias)239 static void add_compound_extended_candidate(refmvs_candidate *const same,
240                                             int *const same_count,
241                                             const refmvs_block *const cand_b,
242                                             const int sign0, const int sign1,
243                                             const union refmvs_refpair ref,
244                                             const uint8_t *const sign_bias)
245 {
246     refmvs_candidate *const diff = &same[2];
247     int *const diff_count = &same_count[2];
248 
249     for (int n = 0; n < 2; n++) {
250         const int cand_ref = cand_b->ref.ref[n];
251 
252         if (cand_ref <= 0) break;
253 
254         mv cand_mv = cand_b->mv.mv[n];
255         if (cand_ref == ref.ref[0]) {
256             if (same_count[0] < 2)
257                 same[same_count[0]++].mv.mv[0] = cand_mv;
258             if (diff_count[1] < 2) {
259                 if (sign1 ^ sign_bias[cand_ref - 1]) {
260                     cand_mv.y = -cand_mv.y;
261                     cand_mv.x = -cand_mv.x;
262                 }
263                 diff[diff_count[1]++].mv.mv[1] = cand_mv;
264             }
265         } else if (cand_ref == ref.ref[1]) {
266             if (same_count[1] < 2)
267                 same[same_count[1]++].mv.mv[1] = cand_mv;
268             if (diff_count[0] < 2) {
269                 if (sign0 ^ sign_bias[cand_ref - 1]) {
270                     cand_mv.y = -cand_mv.y;
271                     cand_mv.x = -cand_mv.x;
272                 }
273                 diff[diff_count[0]++].mv.mv[0] = cand_mv;
274             }
275         } else {
276             mv i_cand_mv = (union mv) {
277                 .x = -cand_mv.x,
278                 .y = -cand_mv.y
279             };
280 
281             if (diff_count[0] < 2) {
282                 diff[diff_count[0]++].mv.mv[0] =
283                     sign0 ^ sign_bias[cand_ref - 1] ?
284                     i_cand_mv : cand_mv;
285             }
286 
287             if (diff_count[1] < 2) {
288                 diff[diff_count[1]++].mv.mv[1] =
289                     sign1 ^ sign_bias[cand_ref - 1] ?
290                     i_cand_mv : cand_mv;
291             }
292         }
293     }
294 }
295 
/*
 * Append sign-adjusted MVs of neighbour cand_b to a single-reference
 * candidate stack. Used as a fallback when too few candidates were found.
 *
 * Each side of cand_b with a reference > 0 contributes its MV. The MV is
 * negated when sign differs from sign_bias[cand_ref - 1]. MVs already present
 * (compared on mv[0]) are skipped. New ones are appended with the minimal
 * weight 2, but never beyond the stack's declared capacity of 8 entries,
 * matching the bound used by the other add_*_candidate() helpers.
 */
static void add_single_extended_candidate(refmvs_candidate mvstack[8], int *const cnt,
                                          const refmvs_block *const cand_b,
                                          const int sign, const uint8_t *const sign_bias)
{
    for (int n = 0; n < 2; n++) {
        const int cand_ref = cand_b->ref.ref[n];

        if (cand_ref <= 0) break;
        // we need to continue even if cand_ref == ref.ref[0], since
        // the candidate could have been added as a globalmv variant,
        // which changes the value
        // FIXME if scan_{row,col}() returned a mask for the nearest
        // edge, we could skip the appropriate ones here

        mv cand_mv = cand_b->mv.mv[n];
        if (sign ^ sign_bias[cand_ref - 1]) {
            cand_mv.y = -cand_mv.y;
            cand_mv.x = -cand_mv.x;
        }

        const int last = *cnt;
        int m = 0;
        while (m < last && mvstack[m].mv.mv[0].n != cand_mv.n)
            m++;

        // new MV, and room left in the 8-entry stack
        if (m == last && last < 8) {
            mvstack[last].mv.mv[0] = cand_mv;
            mvstack[last].weight = 2; // "minimal"
            *cnt = last + 1;
        }
    }
}
328 
/*
 * refmvs_frame allocates memory for one sbrow (32 blocks high, whole frame
 * wide) of 4x4-resolution refmvs_block entries for spatial MV referencing,
 * plus 3 extra rows for the rows above the sbrow (35 rows per tile row).
 * refmvs_tile.r[] keeps a list of 37 (5 above + 32) row pointers into this
 * memory. Of the 5 "above" slots, only rows -5/-3/-1 are backed by storage;
 * -4/-2 are NULL. On every other (odd) sbrow, the bottom entries
 * (y=27/29/31) are exchanged with the top (-5/-3/-1) pointers. This is done
 * by calling dav1d_refmvs_tile_sbrow_init() at the start of each tile/sbrow.
 *
 * For temporal MV referencing, we call dav1d_refmvs_save_tmvs() at the end of
 * each tile/sbrow (when tile column threading is enabled), or at the start of
 * each interleaved sbrow (i.e. once for all tile columns together, when tile
 * column threading is disabled). This will copy the 4x4-resolution spatial MVs
 * into 8x8-resolution refmvs_temporal_block structures. Then, for subsequent
 * frames, at the start of each tile/sbrow (when tile column threading is
 * enabled) or at the start of each interleaved sbrow (when tile column
 * threading is disabled), we call load_tmvs(), which will project the MVs to
 * their respective position in the current frame.
 */
347 
dav1d_refmvs_find(const refmvs_tile * const rt,refmvs_candidate mvstack[8],int * const cnt,int * const ctx,const union refmvs_refpair ref,const enum BlockSize bs,const enum EdgeFlags edge_flags,const int by4,const int bx4)348 void dav1d_refmvs_find(const refmvs_tile *const rt,
349                        refmvs_candidate mvstack[8], int *const cnt,
350                        int *const ctx,
351                        const union refmvs_refpair ref, const enum BlockSize bs,
352                        const enum EdgeFlags edge_flags,
353                        const int by4, const int bx4)
354 {
355     const refmvs_frame *const rf = rt->rf;
356     const uint8_t *const b_dim = dav1d_block_dimensions[bs];
357     const int bw4 = b_dim[0], w4 = imin(imin(bw4, 16), rt->tile_col.end - bx4);
358     const int bh4 = b_dim[1], h4 = imin(imin(bh4, 16), rt->tile_row.end - by4);
359     mv gmv[2], tgmv[2];
360 
361     *cnt = 0;
362     assert(ref.ref[0] >=  0 && ref.ref[0] <= 8 &&
363            ref.ref[1] >= -1 && ref.ref[1] <= 8);
364     if (ref.ref[0] > 0) {
365         tgmv[0] = get_gmv_2d(&rf->frm_hdr->gmv[ref.ref[0] - 1],
366                              bx4, by4, bw4, bh4, rf->frm_hdr);
367         gmv[0] = rf->frm_hdr->gmv[ref.ref[0] - 1].type > DAV1D_WM_TYPE_TRANSLATION ?
368                  tgmv[0] : (mv) { .n = INVALID_MV };
369     } else {
370         tgmv[0] = (mv) { .n = 0 };
371         gmv[0] = (mv) { .n = INVALID_MV };
372     }
373     if (ref.ref[1] > 0) {
374         tgmv[1] = get_gmv_2d(&rf->frm_hdr->gmv[ref.ref[1] - 1],
375                              bx4, by4, bw4, bh4, rf->frm_hdr);
376         gmv[1] = rf->frm_hdr->gmv[ref.ref[1] - 1].type > DAV1D_WM_TYPE_TRANSLATION ?
377                  tgmv[1] : (mv) { .n = INVALID_MV };
378     }
379 
380     // top
381     int have_newmv = 0, have_col_mvs = 0, have_row_mvs = 0;
382     unsigned max_rows = 0, n_rows = ~0;
383     const refmvs_block *b_top;
384     if (by4 > rt->tile_row.start) {
385         max_rows = imin((by4 - rt->tile_row.start + 1) >> 1, 2 + (bh4 > 1));
386         b_top = &rt->r[(by4 & 31) - 1 + 5][bx4];
387         n_rows = scan_row(mvstack, cnt, ref, gmv, b_top,
388                           bw4, w4, max_rows, bw4 >= 16 ? 4 : 1,
389                           &have_newmv, &have_row_mvs);
390     }
391 
392     // left
393     unsigned max_cols = 0, n_cols = ~0U;
394     refmvs_block *const *b_left;
395     if (bx4 > rt->tile_col.start) {
396         max_cols = imin((bx4 - rt->tile_col.start + 1) >> 1, 2 + (bw4 > 1));
397         b_left = &rt->r[(by4 & 31) + 5];
398         n_cols = scan_col(mvstack, cnt, ref, gmv, b_left,
399                           bh4, h4, bx4 - 1, max_cols, bh4 >= 16 ? 4 : 1,
400                           &have_newmv, &have_col_mvs);
401     }
402 
403     // top/right
404     if (n_rows != ~0U && edge_flags & EDGE_I444_TOP_HAS_RIGHT &&
405         imax(bw4, bh4) <= 16 && bw4 + bx4 < rt->tile_col.end)
406     {
407         add_spatial_candidate(mvstack, cnt, 4, &b_top[bw4], ref, gmv,
408                               &have_newmv, &have_row_mvs);
409     }
410 
411     const int nearest_match = have_col_mvs + have_row_mvs;
412     const int nearest_cnt = *cnt;
413     for (int n = 0; n < nearest_cnt; n++)
414         mvstack[n].weight += 640;
415 
416     // temporal
417     int globalmv_ctx = rf->frm_hdr->use_ref_frame_mvs;
418     if (rf->use_ref_frame_mvs) {
419         const ptrdiff_t stride = rf->rp_stride;
420         const int by8 = by4 >> 1, bx8 = bx4 >> 1;
421         const refmvs_temporal_block *const rbi = &rt->rp_proj[(by8 & 15) * stride + bx8];
422         const refmvs_temporal_block *rb = rbi;
423         const int step_h = bw4 >= 16 ? 2 : 1, step_v = bh4 >= 16 ? 2 : 1;
424         const int w8 = imin((w4 + 1) >> 1, 8), h8 = imin((h4 + 1) >> 1, 8);
425         for (int y = 0; y < h8; y += step_v) {
426             for (int x = 0; x < w8; x+= step_h) {
427                 add_temporal_candidate(rf, mvstack, cnt, &rb[x], ref,
428                                        !(x | y) ? &globalmv_ctx : NULL, tgmv);
429             }
430             rb += stride * step_v;
431         }
432         if (imin(bw4, bh4) >= 2 && imax(bw4, bh4) < 16) {
433             const int bh8 = bh4 >> 1, bw8 = bw4 >> 1;
434             rb = &rbi[bh8 * stride];
435             const int has_bottom = by8 + bh8 < imin(rt->tile_row.end >> 1,
436                                                     (by8 & ~7) + 8);
437             if (has_bottom && bx8 - 1 >= imax(rt->tile_col.start >> 1, bx8 & ~7)) {
438                 add_temporal_candidate(rf, mvstack, cnt, &rb[-1], ref,
439                                        NULL, NULL);
440             }
441             if (bx8 + bw8 < imin(rt->tile_col.end >> 1, (bx8 & ~7) + 8)) {
442                 if (has_bottom) {
443                     add_temporal_candidate(rf, mvstack, cnt, &rb[bw8], ref,
444                                            NULL, NULL);
445                 }
446                 if (by8 + bh8 - 1 < imin(rt->tile_row.end >> 1, (by8 & ~7) + 8)) {
447                     add_temporal_candidate(rf, mvstack, cnt, &rb[bw8 - stride],
448                                            ref, NULL, NULL);
449                 }
450             }
451         }
452     }
453     assert(*cnt <= 8);
454 
455     // top/left (which, confusingly, is part of "secondary" references)
456     int have_dummy_newmv_match;
457     if ((n_rows | n_cols) != ~0U) {
458         add_spatial_candidate(mvstack, cnt, 4, &b_top[-1], ref, gmv,
459                               &have_dummy_newmv_match, &have_row_mvs);
460     }
461 
462     // "secondary" (non-direct neighbour) top & left edges
463     // what is different about secondary is that everything is now in 8x8 resolution
464     for (int n = 2; n <= 3; n++) {
465         if ((unsigned) n > n_rows && (unsigned) n <= max_rows) {
466             n_rows += scan_row(mvstack, cnt, ref, gmv,
467                                &rt->r[(((by4 & 31) - 2 * n + 1) | 1) + 5][bx4 | 1],
468                                bw4, w4, 1 + max_rows - n, bw4 >= 16 ? 4 : 2,
469                                &have_dummy_newmv_match, &have_row_mvs);
470         }
471 
472         if ((unsigned) n > n_cols && (unsigned) n <= max_cols) {
473             n_cols += scan_col(mvstack, cnt, ref, gmv, &rt->r[((by4 & 31) | 1) + 5],
474                                bh4, h4, (bx4 - n * 2 + 1) | 1,
475                                1 + max_cols - n, bh4 >= 16 ? 4 : 2,
476                                &have_dummy_newmv_match, &have_col_mvs);
477         }
478     }
479     assert(*cnt <= 8);
480 
481     const int ref_match_count = have_col_mvs + have_row_mvs;
482 
483     // context build-up
484     int refmv_ctx, newmv_ctx;
485     switch (nearest_match) {
486     case 0:
487         refmv_ctx = imin(2, ref_match_count);
488         newmv_ctx = ref_match_count > 0;
489         break;
490     case 1:
491         refmv_ctx = imin(ref_match_count * 3, 4);
492         newmv_ctx = 3 - have_newmv;
493         break;
494     case 2:
495         refmv_ctx = 5;
496         newmv_ctx = 5 - have_newmv;
497         break;
498     }
499 
500     // sorting (nearest, then "secondary")
501     int len = nearest_cnt;
502     while (len) {
503         int last = 0;
504         for (int n = 1; n < len; n++) {
505             if (mvstack[n - 1].weight < mvstack[n].weight) {
506 #define EXCHANGE(a, b) do { refmvs_candidate tmp = a; a = b; b = tmp; } while (0)
507                 EXCHANGE(mvstack[n - 1], mvstack[n]);
508                 last = n;
509             }
510         }
511         len = last;
512     }
513     len = *cnt;
514     while (len > nearest_cnt) {
515         int last = nearest_cnt;
516         for (int n = nearest_cnt + 1; n < len; n++) {
517             if (mvstack[n - 1].weight < mvstack[n].weight) {
518                 EXCHANGE(mvstack[n - 1], mvstack[n]);
519 #undef EXCHANGE
520                 last = n;
521             }
522         }
523         len = last;
524     }
525 
526     if (ref.ref[1] > 0) {
527         if (*cnt < 2) {
528             const int sign0 = rf->sign_bias[ref.ref[0] - 1];
529             const int sign1 = rf->sign_bias[ref.ref[1] - 1];
530             const int sz4 = imin(w4, h4);
531             refmvs_candidate *const same = &mvstack[*cnt];
532             int same_count[4] = { 0 };
533 
534             // non-self references in top
535             if (n_rows != ~0U) for (int x = 0; x < sz4;) {
536                 const refmvs_block *const cand_b = &b_top[x];
537                 add_compound_extended_candidate(same, same_count, cand_b,
538                                                 sign0, sign1, ref, rf->sign_bias);
539                 x += dav1d_block_dimensions[cand_b->bs][0];
540             }
541 
542             // non-self references in left
543             if (n_cols != ~0U) for (int y = 0; y < sz4;) {
544                 const refmvs_block *const cand_b = &b_left[y][bx4 - 1];
545                 add_compound_extended_candidate(same, same_count, cand_b,
546                                                 sign0, sign1, ref, rf->sign_bias);
547                 y += dav1d_block_dimensions[cand_b->bs][1];
548             }
549 
550             refmvs_candidate *const diff = &same[2];
551             const int *const diff_count = &same_count[2];
552 
553             // merge together
554             for (int n = 0; n < 2; n++) {
555                 int m = same_count[n];
556 
557                 if (m >= 2) continue;
558 
559                 const int l = diff_count[n];
560                 if (l) {
561                     same[m].mv.mv[n] = diff[0].mv.mv[n];
562                     if (++m == 2) continue;
563                     if (l == 2) {
564                         same[1].mv.mv[n] = diff[1].mv.mv[n];
565                         continue;
566                     }
567                 }
568                 do {
569                     same[m].mv.mv[n] = tgmv[n];
570                 } while (++m < 2);
571             }
572 
573             // if the first extended was the same as the non-extended one,
574             // then replace it with the second extended one
575             int n = *cnt;
576             if (n == 1 && mvstack[0].mv.n == same[0].mv.n)
577                 mvstack[1].mv = mvstack[2].mv;
578             do {
579                 mvstack[n].weight = 2;
580             } while (++n < 2);
581             *cnt = 2;
582         }
583 
584         // clamping
585         const int left = -(bx4 + bw4 + 4) * 4 * 8;
586         const int right = (rf->iw4 - bx4 + 4) * 4 * 8;
587         const int top = -(by4 + bh4 + 4) * 4 * 8;
588         const int bottom = (rf->ih4 - by4 + 4) * 4 * 8;
589 
590         const int n_refmvs = *cnt;
591         int n = 0;
592         do {
593             mvstack[n].mv.mv[0].x = iclip(mvstack[n].mv.mv[0].x, left, right);
594             mvstack[n].mv.mv[0].y = iclip(mvstack[n].mv.mv[0].y, top, bottom);
595             mvstack[n].mv.mv[1].x = iclip(mvstack[n].mv.mv[1].x, left, right);
596             mvstack[n].mv.mv[1].y = iclip(mvstack[n].mv.mv[1].y, top, bottom);
597         } while (++n < n_refmvs);
598 
599         switch (refmv_ctx >> 1) {
600         case 0:
601             *ctx = imin(newmv_ctx, 1);
602             break;
603         case 1:
604             *ctx = 1 + imin(newmv_ctx, 3);
605             break;
606         case 2:
607             *ctx = iclip(3 + newmv_ctx, 4, 7);
608             break;
609         }
610 
611         return;
612     } else if (*cnt < 2 && ref.ref[0] > 0) {
613         const int sign = rf->sign_bias[ref.ref[0] - 1];
614         const int sz4 = imin(w4, h4);
615 
616         // non-self references in top
617         if (n_rows != ~0U) for (int x = 0; x < sz4 && *cnt < 2;) {
618             const refmvs_block *const cand_b = &b_top[x];
619             add_single_extended_candidate(mvstack, cnt, cand_b, sign, rf->sign_bias);
620             x += dav1d_block_dimensions[cand_b->bs][0];
621         }
622 
623         // non-self references in left
624         if (n_cols != ~0U) for (int y = 0; y < sz4 && *cnt < 2;) {
625             const refmvs_block *const cand_b = &b_left[y][bx4 - 1];
626             add_single_extended_candidate(mvstack, cnt, cand_b, sign, rf->sign_bias);
627             y += dav1d_block_dimensions[cand_b->bs][1];
628         }
629     }
630     assert(*cnt <= 8);
631 
632     // clamping
633     int n_refmvs = *cnt;
634     if (n_refmvs) {
635         const int left = -(bx4 + bw4 + 4) * 4 * 8;
636         const int right = (rf->iw4 - bx4 + 4) * 4 * 8;
637         const int top = -(by4 + bh4 + 4) * 4 * 8;
638         const int bottom = (rf->ih4 - by4 + 4) * 4 * 8;
639 
640         int n = 0;
641         do {
642             mvstack[n].mv.mv[0].x = iclip(mvstack[n].mv.mv[0].x, left, right);
643             mvstack[n].mv.mv[0].y = iclip(mvstack[n].mv.mv[0].y, top, bottom);
644         } while (++n < n_refmvs);
645     }
646 
647     for (int n = *cnt; n < 2; n++)
648         mvstack[n].mv.mv[0] = tgmv[0];
649 
650     *ctx = (refmv_ctx << 4) | (globalmv_ctx << 3) | newmv_ctx;
651 }
652 
dav1d_refmvs_tile_sbrow_init(refmvs_tile * const rt,const refmvs_frame * const rf,const int tile_col_start4,const int tile_col_end4,const int tile_row_start4,const int tile_row_end4,const int sby,int tile_row_idx,const int pass)653 void dav1d_refmvs_tile_sbrow_init(refmvs_tile *const rt, const refmvs_frame *const rf,
654                                   const int tile_col_start4, const int tile_col_end4,
655                                   const int tile_row_start4, const int tile_row_end4,
656                                   const int sby, int tile_row_idx, const int pass)
657 {
658     if (rf->n_tile_threads == 1) tile_row_idx = 0;
659     rt->rp_proj = &rf->rp_proj[16 * rf->rp_stride * tile_row_idx];
660     const ptrdiff_t r_stride = rf->rp_stride * 2;
661     const ptrdiff_t pass_off = (rf->n_frame_threads > 1 && pass == 2) ?
662         35 * 2 * rf->n_blocks : 0;
663     refmvs_block *r = &rf->r[35 * r_stride * tile_row_idx + pass_off];
664     const int sbsz = rf->sbsz;
665     const int off = (sbsz * sby) & 16;
666     for (int i = 0; i < sbsz; i++, r += r_stride)
667         rt->r[off + 5 + i] = r;
668     rt->r[off + 0] = r;
669     r += r_stride;
670     rt->r[off + 1] = NULL;
671     rt->r[off + 2] = r;
672     r += r_stride;
673     rt->r[off + 3] = NULL;
674     rt->r[off + 4] = r;
675     if (sby & 1) {
676 #define EXCHANGE(a, b) do { void *const tmp = a; a = b; b = tmp; } while (0)
677         EXCHANGE(rt->r[off + 0], rt->r[off + sbsz + 0]);
678         EXCHANGE(rt->r[off + 2], rt->r[off + sbsz + 2]);
679         EXCHANGE(rt->r[off + 4], rt->r[off + sbsz + 4]);
680 #undef EXCHANGE
681     }
682 
683     rt->rf = rf;
684     rt->tile_row.start = tile_row_start4;
685     rt->tile_row.end = imin(tile_row_end4, rf->ih4);
686     rt->tile_col.start = tile_col_start4;
687     rt->tile_col.end = imin(tile_col_end4, rf->iw4);
688 }
689 
load_tmvs_c(const refmvs_frame * const rf,int tile_row_idx,const int col_start8,const int col_end8,const int row_start8,int row_end8)690 static void load_tmvs_c(const refmvs_frame *const rf, int tile_row_idx,
691                         const int col_start8, const int col_end8,
692                         const int row_start8, int row_end8)
693 {
694     if (rf->n_tile_threads == 1) tile_row_idx = 0;
695     assert(row_start8 >= 0);
696     assert((unsigned) (row_end8 - row_start8) <= 16U);
697     row_end8 = imin(row_end8, rf->ih8);
698     const int col_start8i = imax(col_start8 - 8, 0);
699     const int col_end8i = imin(col_end8 + 8, rf->iw8);
700 
701     const ptrdiff_t stride = rf->rp_stride;
702     refmvs_temporal_block *rp_proj =
703         &rf->rp_proj[16 * stride * tile_row_idx + (row_start8 & 15) * stride];
704     for (int y = row_start8; y < row_end8; y++) {
705         for (int x = col_start8; x < col_end8; x++)
706             rp_proj[x].mv.n = INVALID_MV;
707         rp_proj += stride;
708     }
709 
710     rp_proj = &rf->rp_proj[16 * stride * tile_row_idx];
711     for (int n = 0; n < rf->n_mfmvs; n++) {
712         const int ref2cur = rf->mfmv_ref2cur[n];
713         if (ref2cur == INT_MIN) continue;
714 
715         const int ref = rf->mfmv_ref[n];
716         const int ref_sign = ref - 4;
717         const refmvs_temporal_block *r = &rf->rp_ref[ref][row_start8 * stride];
718         for (int y = row_start8; y < row_end8; y++) {
719             const int y_sb_align = y & ~7;
720             const int y_proj_start = imax(y_sb_align, row_start8);
721             const int y_proj_end = imin(y_sb_align + 8, row_end8);
722             for (int x = col_start8i; x < col_end8i; x++) {
723                 const refmvs_temporal_block *rb = &r[x];
724                 const int b_ref = rb->ref;
725                 if (!b_ref) continue;
726                 const int ref2ref = rf->mfmv_ref2ref[n][b_ref - 1];
727                 if (!ref2ref) continue;
728                 const mv b_mv = rb->mv;
729                 const mv offset = mv_projection(b_mv, ref2cur, ref2ref);
730                 int pos_x = x + apply_sign(abs(offset.x) >> 6,
731                                            offset.x ^ ref_sign);
732                 const int pos_y = y + apply_sign(abs(offset.y) >> 6,
733                                                  offset.y ^ ref_sign);
734                 if (pos_y >= y_proj_start && pos_y < y_proj_end) {
735                     const ptrdiff_t pos = (pos_y & 15) * stride;
736                     for (;;) {
737                         const int x_sb_align = x & ~7;
738                         if (pos_x >= imax(x_sb_align - 8, col_start8) &&
739                             pos_x < imin(x_sb_align + 16, col_end8))
740                         {
741                             rp_proj[pos + pos_x].mv = rb->mv;
742                             rp_proj[pos + pos_x].ref = ref2ref;
743                         }
744                         if (++x >= col_end8i) break;
745                         rb++;
746                         if (rb->ref != b_ref || rb->mv.n != b_mv.n) break;
747                         pos_x++;
748                     }
749                 } else {
750                     for (;;) {
751                         if (++x >= col_end8i) break;
752                         rb++;
753                         if (rb->ref != b_ref || rb->mv.n != b_mv.n) break;
754                     }
755                 }
756                 x--;
757             }
758             r += stride;
759         }
760     }
761 }
762 
save_tmvs_c(refmvs_temporal_block * rp,const ptrdiff_t stride,refmvs_block * const * const rr,const uint8_t * const ref_sign,const int col_end8,const int row_end8,const int col_start8,const int row_start8)763 static void save_tmvs_c(refmvs_temporal_block *rp, const ptrdiff_t stride,
764                         refmvs_block *const *const rr,
765                         const uint8_t *const ref_sign,
766                         const int col_end8, const int row_end8,
767                         const int col_start8, const int row_start8)
768 {
769     for (int y = row_start8; y < row_end8; y++) {
770         const refmvs_block *const b = rr[(y & 15) * 2];
771 
772         for (int x = col_start8; x < col_end8;) {
773             const refmvs_block *const cand_b = &b[x * 2 + 1];
774             const int bw8 = (dav1d_block_dimensions[cand_b->bs][0] + 1) >> 1;
775 
776             if (cand_b->ref.ref[1] > 0 && ref_sign[cand_b->ref.ref[1] - 1] &&
777                 (abs(cand_b->mv.mv[1].y) | abs(cand_b->mv.mv[1].x)) < 4096)
778             {
779                 for (int n = 0; n < bw8; n++, x++)
780                     rp[x] = (refmvs_temporal_block) { .mv = cand_b->mv.mv[1],
781                                                       .ref = cand_b->ref.ref[1] };
782             } else if (cand_b->ref.ref[0] > 0 && ref_sign[cand_b->ref.ref[0] - 1] &&
783                        (abs(cand_b->mv.mv[0].y) | abs(cand_b->mv.mv[0].x)) < 4096)
784             {
785                 for (int n = 0; n < bw8; n++, x++)
786                     rp[x] = (refmvs_temporal_block) { .mv = cand_b->mv.mv[0],
787                                                       .ref = cand_b->ref.ref[0] };
788             } else {
789                 for (int n = 0; n < bw8; n++, x++) {
790                     rp[x].mv.n = 0;
791                     rp[x].ref = 0; // "invalid"
792                 }
793             }
794         }
795         rp += stride;
796     }
797 }
798 
dav1d_refmvs_init_frame(refmvs_frame * const rf,const Dav1dSequenceHeader * const seq_hdr,const Dav1dFrameHeader * const frm_hdr,const unsigned ref_poc[7],refmvs_temporal_block * const rp,const unsigned ref_ref_poc[7][7],refmvs_temporal_block * const rp_ref[7],const int n_tile_threads,const int n_frame_threads)799 int dav1d_refmvs_init_frame(refmvs_frame *const rf,
800                             const Dav1dSequenceHeader *const seq_hdr,
801                             const Dav1dFrameHeader *const frm_hdr,
802                             const unsigned ref_poc[7],
803                             refmvs_temporal_block *const rp,
804                             const unsigned ref_ref_poc[7][7],
805                             /*const*/ refmvs_temporal_block *const rp_ref[7],
806                             const int n_tile_threads, const int n_frame_threads)
807 {
808     const int rp_stride = ((frm_hdr->width[0] + 127) & ~127) >> 3;
809     const int n_tile_rows = n_tile_threads > 1 ? frm_hdr->tiling.rows : 1;
810     const int n_blocks = rp_stride * n_tile_rows;
811 
812     rf->sbsz = 16 << seq_hdr->sb128;
813     rf->frm_hdr = frm_hdr;
814     rf->iw8 = (frm_hdr->width[0] + 7) >> 3;
815     rf->ih8 = (frm_hdr->height + 7) >> 3;
816     rf->iw4 = rf->iw8 << 1;
817     rf->ih4 = rf->ih8 << 1;
818     rf->rp = rp;
819     rf->rp_stride = rp_stride;
820     rf->n_tile_threads = n_tile_threads;
821     rf->n_frame_threads = n_frame_threads;
822 
823     if (n_blocks != rf->n_blocks) {
824         const size_t r_sz = sizeof(*rf->r) * 35 * 2 * n_blocks * (1 + (n_frame_threads > 1));
825         const size_t rp_proj_sz = sizeof(*rf->rp_proj) * 16 * n_blocks;
826         /* Note that sizeof(*rf->r) == 12, but it's accessed using 16-byte unaligned
827          * loads in save_tmvs() asm which can overread 4 bytes into rp_proj. */
828         dav1d_free_aligned(rf->r);
829         rf->r = dav1d_alloc_aligned(ALLOC_REFMVS, r_sz + rp_proj_sz, 64);
830         if (!rf->r) {
831             rf->n_blocks = 0;
832             return DAV1D_ERR(ENOMEM);
833         }
834 
835         rf->rp_proj = (refmvs_temporal_block*)((uintptr_t)rf->r + r_sz);
836         rf->n_blocks = n_blocks;
837     }
838 
839     const unsigned poc = frm_hdr->frame_offset;
840     for (int i = 0; i < 7; i++) {
841         const int poc_diff = get_poc_diff(seq_hdr->order_hint_n_bits,
842                                           ref_poc[i], poc);
843         rf->sign_bias[i] = poc_diff > 0;
844         rf->mfmv_sign[i] = poc_diff < 0;
845         rf->pocdiff[i] = iclip(get_poc_diff(seq_hdr->order_hint_n_bits,
846                                             poc, ref_poc[i]), -31, 31);
847     }
848 
849     // temporal MV setup
850     rf->n_mfmvs = 0;
851     rf->rp_ref = rp_ref;
852     if (frm_hdr->use_ref_frame_mvs && seq_hdr->order_hint_n_bits) {
853         int total = 2;
854         if (rp_ref[0] && ref_ref_poc[0][6] != ref_poc[3] /* alt-of-last != gold */) {
855             rf->mfmv_ref[rf->n_mfmvs++] = 0; // last
856             total = 3;
857         }
858         if (rp_ref[4] && get_poc_diff(seq_hdr->order_hint_n_bits, ref_poc[4],
859                                       frm_hdr->frame_offset) > 0)
860         {
861             rf->mfmv_ref[rf->n_mfmvs++] = 4; // bwd
862         }
863         if (rp_ref[5] && get_poc_diff(seq_hdr->order_hint_n_bits, ref_poc[5],
864                                       frm_hdr->frame_offset) > 0)
865         {
866             rf->mfmv_ref[rf->n_mfmvs++] = 5; // altref2
867         }
868         if (rf->n_mfmvs < total && rp_ref[6] &&
869             get_poc_diff(seq_hdr->order_hint_n_bits, ref_poc[6],
870                          frm_hdr->frame_offset) > 0)
871         {
872             rf->mfmv_ref[rf->n_mfmvs++] = 6; // altref
873         }
874         if (rf->n_mfmvs < total && rp_ref[1])
875             rf->mfmv_ref[rf->n_mfmvs++] = 1; // last2
876 
877         for (int n = 0; n < rf->n_mfmvs; n++) {
878             const unsigned rpoc = ref_poc[rf->mfmv_ref[n]];
879             const int diff1 = get_poc_diff(seq_hdr->order_hint_n_bits,
880                                            rpoc, frm_hdr->frame_offset);
881             if (abs(diff1) > 31) {
882                 rf->mfmv_ref2cur[n] = INT_MIN;
883             } else {
884                 rf->mfmv_ref2cur[n] = rf->mfmv_ref[n] < 4 ? -diff1 : diff1;
885                 for (int m = 0; m < 7; m++) {
886                     const unsigned rrpoc = ref_ref_poc[rf->mfmv_ref[n]][m];
887                     const int diff2 = get_poc_diff(seq_hdr->order_hint_n_bits,
888                                                    rpoc, rrpoc);
889                     // unsigned comparison also catches the < 0 case
890                     rf->mfmv_ref2ref[n][m] = (unsigned) diff2 > 31U ? 0 : diff2;
891                 }
892             }
893         }
894     }
895     rf->use_ref_frame_mvs = rf->n_mfmvs > 0;
896 
897     return 0;
898 }
899 
splat_mv_c(refmvs_block ** rr,const refmvs_block * const rmv,const int bx4,const int bw4,int bh4)900 static void splat_mv_c(refmvs_block **rr, const refmvs_block *const rmv,
901                        const int bx4, const int bw4, int bh4)
902 {
903     do {
904         refmvs_block *const r = *rr++ + bx4;
905         for (int x = 0; x < bw4; x++)
906             r[x] = *rmv;
907     } while (--bh4);
908 }
909 
910 #if HAVE_ASM
911 #if ARCH_AARCH64 || ARCH_ARM
912 #include "src/arm/refmvs.h"
913 #elif ARCH_LOONGARCH64
914 #include "src/loongarch/refmvs.h"
915 #elif ARCH_X86
916 #include "src/x86/refmvs.h"
917 #endif
918 #endif
919 
dav1d_refmvs_dsp_init(Dav1dRefmvsDSPContext * const c)920 COLD void dav1d_refmvs_dsp_init(Dav1dRefmvsDSPContext *const c)
921 {
922     c->load_tmvs = load_tmvs_c;
923     c->save_tmvs = save_tmvs_c;
924     c->splat_mv = splat_mv_c;
925 
926 #if HAVE_ASM
927 #if ARCH_AARCH64 || ARCH_ARM
928     refmvs_dsp_init_arm(c);
929 #elif ARCH_LOONGARCH64
930     refmvs_dsp_init_loongarch(c);
931 #elif ARCH_X86
932     refmvs_dsp_init_x86(c);
933 #endif
934 #endif
935 }
936