xref: /aosp_15_r20/external/libaom/aom_dsp/flow_estimation/corner_match.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <stdlib.h>
13 #include <memory.h>
14 #include <math.h>
15 
16 #include "config/aom_dsp_rtcd.h"
17 
18 #include "aom_dsp/flow_estimation/corner_detect.h"
19 #include "aom_dsp/flow_estimation/corner_match.h"
20 #include "aom_dsp/flow_estimation/disflow.h"
21 #include "aom_dsp/flow_estimation/flow_estimation.h"
22 #include "aom_dsp/flow_estimation/ransac.h"
23 #include "aom_dsp/pyramid.h"
24 #include "aom_scale/yv12config.h"
25 
// Minimum normalized cross-correlation that a pair of corners must exceed
// before either corner records the other as its best match.
#define THRESHOLD_NCC (0.75)
27 
28 /* Compute mean and standard deviation of pixels in a window of size
29    MATCH_SZ by MATCH_SZ centered at (x, y).
30    Store results into *mean and *one_over_stddev
31 
32    Note: The output of this function is scaled by MATCH_SZ, as in
33    *mean = MATCH_SZ * <true mean> and
34    *one_over_stddev = 1 / (MATCH_SZ * <true stddev>)
35 
36    Combined with the fact that we return 1/stddev rather than the standard
37    deviation itself, this allows us to completely avoid divisions in
38    aom_compute_correlation, which is much hotter than this function is.
39 
40    Returns true if this feature point is usable, false otherwise.
41 */
aom_compute_mean_stddev_c(const unsigned char * frame,int stride,int x,int y,double * mean,double * one_over_stddev)42 bool aom_compute_mean_stddev_c(const unsigned char *frame, int stride, int x,
43                                int y, double *mean, double *one_over_stddev) {
44   int sum = 0;
45   int sumsq = 0;
46   for (int i = 0; i < MATCH_SZ; ++i) {
47     for (int j = 0; j < MATCH_SZ; ++j) {
48       sum += frame[(i + y - MATCH_SZ_BY2) * stride + (j + x - MATCH_SZ_BY2)];
49       sumsq += frame[(i + y - MATCH_SZ_BY2) * stride + (j + x - MATCH_SZ_BY2)] *
50                frame[(i + y - MATCH_SZ_BY2) * stride + (j + x - MATCH_SZ_BY2)];
51     }
52   }
53   *mean = (double)sum / MATCH_SZ;
54   const double variance = sumsq - (*mean) * (*mean);
55   if (variance < MIN_FEATURE_VARIANCE) {
56     *one_over_stddev = 0.0;
57     return false;
58   }
59   *one_over_stddev = 1.0 / sqrt(variance);
60   return true;
61 }
62 
63 /* Compute corr(frame1, frame2) over a window of size MATCH_SZ by MATCH_SZ.
64    To save on computation, the mean and (1 divided by the) standard deviation
65    of the window in each frame are precomputed and passed into this function
66    as arguments.
67 */
aom_compute_correlation_c(const unsigned char * frame1,int stride1,int x1,int y1,double mean1,double one_over_stddev1,const unsigned char * frame2,int stride2,int x2,int y2,double mean2,double one_over_stddev2)68 double aom_compute_correlation_c(const unsigned char *frame1, int stride1,
69                                  int x1, int y1, double mean1,
70                                  double one_over_stddev1,
71                                  const unsigned char *frame2, int stride2,
72                                  int x2, int y2, double mean2,
73                                  double one_over_stddev2) {
74   int v1, v2;
75   int cross = 0;
76   for (int i = 0; i < MATCH_SZ; ++i) {
77     for (int j = 0; j < MATCH_SZ; ++j) {
78       v1 = frame1[(i + y1 - MATCH_SZ_BY2) * stride1 + (j + x1 - MATCH_SZ_BY2)];
79       v2 = frame2[(i + y2 - MATCH_SZ_BY2) * stride2 + (j + x2 - MATCH_SZ_BY2)];
80       cross += v1 * v2;
81     }
82   }
83 
84   // Note: In theory, the calculations here "should" be
85   //   covariance = cross / N^2 - mean1 * mean2
86   //   correlation = covariance / (stddev1 * stddev2).
87   //
88   // However, because of the scaling in aom_compute_mean_stddev, the
89   // lines below actually calculate
90   //   covariance * N^2 = cross - (mean1 * N) * (mean2 * N)
91   //   correlation = (covariance * N^2) / ((stddev1 * N) * (stddev2 * N))
92   //
93   // ie. we have removed the need for a division, and still end up with the
94   // correct unscaled correlation (ie, in the range [-1, +1])
95   double covariance = cross - mean1 * mean2;
96   double correlation = covariance * (one_over_stddev1 * one_over_stddev2);
97   return correlation;
98 }
99 
is_eligible_point(int pointx,int pointy,int width,int height)100 static int is_eligible_point(int pointx, int pointy, int width, int height) {
101   return (pointx >= MATCH_SZ_BY2 && pointy >= MATCH_SZ_BY2 &&
102           pointx + MATCH_SZ_BY2 < width && pointy + MATCH_SZ_BY2 < height);
103 }
104 
// Returns nonzero when the Euclidean distance between (point1x, point1y) and
// (point2x, point2y) is at most thresh = max(width, height) >> 4, i.e. when
// the two corners are close enough to be compared as a potential match.
static int is_eligible_distance(int point1x, int point1y, int point2x,
                                int point2y, int width, int height) {
  const int thresh = (width < height ? height : width) >> 4;
  const int dx = point1x - point2x;
  const int dy = point1y - point2y;
  // Reject per axis first. This gives the same answer as the squared-distance
  // test (|dx| > thresh already implies dx^2 > thresh^2), but it also
  // guarantees the sum below is at most 2 * thresh^2. Squaring a raw
  // coordinate difference above ~46340, which is possible on very wide or tall
  // frames, would overflow int, and signed overflow is undefined behavior.
  if (dx > thresh || dx < -thresh || dy > thresh || dy < -thresh) return 0;
  return dx * dx + dy * dy <= thresh * thresh;
}
111 
// Per-corner bookkeeping used while matching corners between two frames.
typedef struct {
  // Corner position in the frame, in pixels.
  int x, y;
  // Window statistics from aom_compute_mean_stddev, both scaled by MATCH_SZ.
  // one_over_stddev is the reciprocal of the scaled standard deviation.
  double mean;
  double one_over_stddev;
  // Index, within the other frame's point list, of the best-correlated corner
  // seen so far, and the correlation achieved with it.
  int best_match_idx;
  double best_match_corr;
} PointInfo;
120 
determine_correspondence(const unsigned char * src,const int * src_corners,int num_src_corners,const unsigned char * ref,const int * ref_corners,int num_ref_corners,int width,int height,int src_stride,int ref_stride,Correspondence * correspondences)121 static int determine_correspondence(const unsigned char *src,
122                                     const int *src_corners, int num_src_corners,
123                                     const unsigned char *ref,
124                                     const int *ref_corners, int num_ref_corners,
125                                     int width, int height, int src_stride,
126                                     int ref_stride,
127                                     Correspondence *correspondences) {
128   PointInfo *src_point_info = NULL;
129   PointInfo *ref_point_info = NULL;
130   int num_correspondences = 0;
131 
132   src_point_info =
133       (PointInfo *)aom_calloc(num_src_corners, sizeof(*src_point_info));
134   if (!src_point_info) {
135     goto finished;
136   }
137 
138   ref_point_info =
139       (PointInfo *)aom_calloc(num_ref_corners, sizeof(*ref_point_info));
140   if (!ref_point_info) {
141     goto finished;
142   }
143 
144   // First pass (linear):
145   // Filter corner lists and compute per-patch means and standard deviations,
146   // for the src and ref frames independently
147   int src_point_count = 0;
148   for (int i = 0; i < num_src_corners; i++) {
149     int src_x = src_corners[2 * i];
150     int src_y = src_corners[2 * i + 1];
151     if (!is_eligible_point(src_x, src_y, width, height)) continue;
152 
153     PointInfo *point = &src_point_info[src_point_count];
154     point->x = src_x;
155     point->y = src_y;
156     point->best_match_corr = THRESHOLD_NCC;
157     if (!aom_compute_mean_stddev(src, src_stride, src_x, src_y, &point->mean,
158                                  &point->one_over_stddev))
159       continue;
160     src_point_count++;
161   }
162   if (src_point_count == 0) {
163     goto finished;
164   }
165 
166   int ref_point_count = 0;
167   for (int j = 0; j < num_ref_corners; j++) {
168     int ref_x = ref_corners[2 * j];
169     int ref_y = ref_corners[2 * j + 1];
170     if (!is_eligible_point(ref_x, ref_y, width, height)) continue;
171 
172     PointInfo *point = &ref_point_info[ref_point_count];
173     point->x = ref_x;
174     point->y = ref_y;
175     point->best_match_corr = THRESHOLD_NCC;
176     if (!aom_compute_mean_stddev(ref, ref_stride, ref_x, ref_y, &point->mean,
177                                  &point->one_over_stddev))
178       continue;
179     ref_point_count++;
180   }
181   if (ref_point_count == 0) {
182     goto finished;
183   }
184 
185   // Second pass (quadratic):
186   // For each pair of points, compute correlation, and use this to determine
187   // the best match of each corner, in both directions
188   for (int i = 0; i < src_point_count; ++i) {
189     PointInfo *src_point = &src_point_info[i];
190     for (int j = 0; j < ref_point_count; ++j) {
191       PointInfo *ref_point = &ref_point_info[j];
192       if (!is_eligible_distance(src_point->x, src_point->y, ref_point->x,
193                                 ref_point->y, width, height))
194         continue;
195 
196       double corr = aom_compute_correlation(
197           src, src_stride, src_point->x, src_point->y, src_point->mean,
198           src_point->one_over_stddev, ref, ref_stride, ref_point->x,
199           ref_point->y, ref_point->mean, ref_point->one_over_stddev);
200 
201       if (corr > src_point->best_match_corr) {
202         src_point->best_match_idx = j;
203         src_point->best_match_corr = corr;
204       }
205       if (corr > ref_point->best_match_corr) {
206         ref_point->best_match_idx = i;
207         ref_point->best_match_corr = corr;
208       }
209     }
210   }
211 
212   // Third pass (linear):
213   // Scan through source corners, generating a correspondence for each corner
214   // iff ref_best_match[src_best_match[i]] == i
215   // Then refine the generated correspondences using optical flow
216   for (int i = 0; i < src_point_count; i++) {
217     PointInfo *point = &src_point_info[i];
218 
219     // Skip corners which were not matched, or which didn't find
220     // a good enough match
221     if (point->best_match_corr < THRESHOLD_NCC) continue;
222 
223     PointInfo *match_point = &ref_point_info[point->best_match_idx];
224     if (match_point->best_match_idx == i) {
225       // Refine match using optical flow and store
226       const int sx = point->x;
227       const int sy = point->y;
228       const int rx = match_point->x;
229       const int ry = match_point->y;
230       double u = (double)(rx - sx);
231       double v = (double)(ry - sy);
232 
233       const int patch_tl_x = sx - DISFLOW_PATCH_CENTER;
234       const int patch_tl_y = sy - DISFLOW_PATCH_CENTER;
235 
236       aom_compute_flow_at_point(src, ref, patch_tl_x, patch_tl_y, width, height,
237                                 src_stride, &u, &v);
238 
239       Correspondence *correspondence = &correspondences[num_correspondences];
240       correspondence->x = (double)sx;
241       correspondence->y = (double)sy;
242       correspondence->rx = (double)sx + u;
243       correspondence->ry = (double)sy + v;
244       num_correspondences++;
245     }
246   }
247 
248 finished:
249   aom_free(src_point_info);
250   aom_free(ref_point_info);
251   return num_correspondences;
252 }
253 
// Estimates global motion between `src` and `ref` by matching corner features.
// Each frame's image pyramid and corner list are computed into that frame's
// own y_pyramid / corners storage. Corners are then paired on pyramid layer 0
// by determine_correspondence, and the correspondences are passed to ransac,
// which fills motion_models.
// Returns ransac's result. Returns false with *mem_alloc_failed set to true
// if pyramid construction, corner detection, or allocation of the
// correspondence buffer fails.
bool av1_compute_global_motion_feature_match(
    TransformationType type, YV12_BUFFER_CONFIG *src, YV12_BUFFER_CONFIG *ref,
    int bit_depth, int downsample_level, MotionModel *motion_models,
    int num_motion_models, bool *mem_alloc_failed) {
  ImagePyramid *src_pyramid = src->y_pyramid;
  CornerList *src_corners = src->corners;
  ImagePyramid *ref_pyramid = ref->y_pyramid;
  CornerList *ref_corners = ref->corners;

  // Precompute information we will need about each frame
  if (aom_compute_pyramid(src, bit_depth, 1, src_pyramid) < 0) {
    *mem_alloc_failed = true;
    return false;
  }
  if (!av1_compute_corner_list(src, bit_depth, downsample_level, src_corners)) {
    *mem_alloc_failed = true;
    return false;
  }
  if (aom_compute_pyramid(ref, bit_depth, 1, ref_pyramid) < 0) {
    *mem_alloc_failed = true;
    return false;
  }
  // The reference corners must be detected on the reference frame itself.
  // They are stored in ref->corners and matched against ref_buffer below.
  if (!av1_compute_corner_list(ref, bit_depth, downsample_level, ref_corners)) {
    *mem_alloc_failed = true;
    return false;
  }

  const uint8_t *src_buffer = src_pyramid->layers[0].buffer;
  const int src_width = src_pyramid->layers[0].width;
  const int src_height = src_pyramid->layers[0].height;
  const int src_stride = src_pyramid->layers[0].stride;

  const uint8_t *ref_buffer = ref_pyramid->layers[0].buffer;
  assert(ref_pyramid->layers[0].width == src_width);
  assert(ref_pyramid->layers[0].height == src_height);
  const int ref_stride = ref_pyramid->layers[0].stride;

  // Find correspondences between the two images. There can be at most one
  // per source corner.
  Correspondence *correspondences = (Correspondence *)aom_malloc(
      src_corners->num_corners * sizeof(*correspondences));
  if (!correspondences) {
    *mem_alloc_failed = true;
    return false;
  }
  const int num_correspondences = determine_correspondence(
      src_buffer, src_corners->corners, src_corners->num_corners, ref_buffer,
      ref_corners->corners, ref_corners->num_corners, src_width, src_height,
      src_stride, ref_stride, correspondences);

  const bool result = ransac(correspondences, num_correspondences, type,
                             motion_models, num_motion_models, mem_alloc_failed);

  aom_free(correspondences);
  return result;
}
311