xref: /aosp_15_r20/external/libvpx/tools/3D-Reconstruction/MotionEST/Exhaust.py (revision fb1b10ab9aebc7c7068eedab379b749d7e3900be)
1*fb1b10abSAndroid Build Coastguard Worker##  Copyright (c) 2020 The WebM project authors. All Rights Reserved.
2*fb1b10abSAndroid Build Coastguard Worker##
3*fb1b10abSAndroid Build Coastguard Worker##  Use of this source code is governed by a BSD-style license
4*fb1b10abSAndroid Build Coastguard Worker##  that can be found in the LICENSE file in the root of the source
5*fb1b10abSAndroid Build Coastguard Worker##  tree. An additional intellectual property rights grant can be found
6*fb1b10abSAndroid Build Coastguard Worker##  in the file PATENTS.  All contributing project authors may
7*fb1b10abSAndroid Build Coastguard Worker##  be found in the AUTHORS file in the root of the source tree.
8*fb1b10abSAndroid Build Coastguard Worker##
9*fb1b10abSAndroid Build Coastguard Worker
10*fb1b10abSAndroid Build Coastguard Worker# coding: utf-8
11*fb1b10abSAndroid Build Coastguard Workerimport numpy as np
12*fb1b10abSAndroid Build Coastguard Workerimport numpy.linalg as LA
13*fb1b10abSAndroid Build Coastguard Workerfrom Util import MSE
14*fb1b10abSAndroid Build Coastguard Workerfrom MotionEST import MotionEST
"""Exhaust Search"""


class Exhaust(MotionEST):
  """Exhaustive-search block matching.

  For every block of the current frame, each integer displacement inside a
  square window is scored with `metric` (through block_dist) and the
  displacement with the lowest distortion is kept.

  Constructor:
      cur_f: current frame
      ref_f: reference frame
      blk_size: block size in pixels
      wnd_size: search radius in pixels; displacements from -wnd_size to
          +wnd_size (inclusive) are tried along each axis
      metric: metric used to compare the blocks' distortion
  """

  def __init__(self, cur_f, ref_f, blk_size, wnd_size, metric=MSE):
    self.name = 'exhaust'
    self.wnd_sz = wnd_size
    self.metric = metric
    super(Exhaust, self).__init__(cur_f, ref_f, blk_size)

  def search(self, cur_r, cur_c):
    """Find the best-matching reference position for one block.

    Args:
      cur_r: block row index
      cur_c: block column index

    Returns:
      (ref_x, ref_y): top-left pixel of the lowest-distortion reference block.
      The zero displacement is scored first and only a strictly lower loss
      replaces it, so it wins ties.
    """
    cur_x = cur_c * self.blk_sz
    cur_y = cur_r * self.blk_sz
    ref_x, ref_y = cur_x, cur_y
    min_loss = self.block_dist(cur_r, cur_c, [0, 0], self.metric)
    # Clamp the window to positions where the whole block fits in the frame.
    # The last valid top-left coordinate is (size - blk_sz), inclusive, and
    # the window reaches wnd_sz pixels to both sides of the block.
    y_lo = max(cur_y - self.wnd_sz, 0)
    y_hi = min(cur_y + self.wnd_sz, self.height - self.blk_sz)
    x_lo = max(cur_x - self.wnd_sz, 0)
    x_hi = min(cur_x + self.wnd_sz, self.width - self.blk_sz)
    for y in range(y_lo, y_hi + 1):
      for x in range(x_lo, x_hi + 1):
        # block_dist takes the displacement as [dy, dx].
        loss = self.block_dist(cur_r, cur_c, [y - cur_y, x - cur_x],
                               self.metric)
        if loss < min_loss:
          min_loss = loss
          ref_x, ref_y = x, y
    return ref_x, ref_y

  def motion_field_estimation(self):
    """Store the [dy, dx] motion vector of every block in self.mf[row, col]."""
    for i in range(self.num_row):
      for j in range(self.num_col):
        ref_x, ref_y = self.search(i, j)
        self.mf[i, j] = np.array(
            [ref_y - i * self.blk_sz, ref_x - j * self.blk_sz])
64*fb1b10abSAndroid Build Coastguard Worker
65*fb1b10abSAndroid Build Coastguard Worker
"""Exhaust with Neighbor Constraint"""


class ExhaustNeighbor(MotionEST):
  """Exhaustive-search block matching with a neighbor-consistency penalty.

  Each candidate displacement is scored as
      block distortion + beta * neighborLoss
  where neighborLoss measures how far the candidate vector is from the
  vectors already chosen for adjacent blocks. Blocks are processed in raster
  order, so only blocks that have already been assigned take part.

  Constructor:
      cur_f: current frame
      ref_f: reference frame
      blk_size: block size in pixels
      wnd_size: search radius in pixels; displacements from -wnd_size to
          +wnd_size (inclusive) are tried along each axis
      beta: weight of the neighbor loss
      metric: metric used to compare the blocks' distortion
  """

  def __init__(self, cur_f, ref_f, blk_size, wnd_size, beta, metric=MSE):
    self.name = 'exhaust + neighbor'
    self.wnd_sz = wnd_size
    self.beta = beta
    self.metric = metric
    super(ExhaustNeighbor, self).__init__(cur_f, ref_f, blk_size)
    # assign[r, c] becomes True once self.mf[r, c] holds a vector from the
    # current estimation pass.
    self.assign = np.zeros((self.num_row, self.num_col), dtype=bool)

  def neighborLoss(self, cur_r, cur_c, mv):
    """Sum of Euclidean distances from mv to the assigned 4-neighbors' vectors.

    Args:
      cur_r: block row index
      cur_c: block column index
      mv: candidate motion vector [dy, dx] (list or ndarray)

    Returns:
      The accumulated distance; 0 when no neighbor has been assigned yet.
    """
    mv = np.asarray(mv)
    loss = 0
    for d_r, d_c in ((-1, 0), (1, 0), (0, 1), (0, -1)):
      nb_r = cur_r + d_r
      nb_c = cur_c + d_c
      if (0 <= nb_r < self.num_row and 0 <= nb_c < self.num_col and
          self.assign[nb_r, nb_c]):
        loss += LA.norm(mv - self.mf[nb_r, nb_c])
    return loss

  def _loss(self, cur_r, cur_c, d_y, d_x):
    """Combined distortion + beta * neighbor loss for displacement (d_y, d_x)."""
    dist_loss = self.block_dist(cur_r, cur_c, [d_y, d_x], self.metric)
    nb_loss = self.neighborLoss(cur_r, cur_c, np.array([d_y, d_x]))
    return dist_loss + self.beta * nb_loss

  def search(self, cur_r, cur_c):
    """Find the reference position minimizing the combined loss for one block.

    Args:
      cur_r: block row index
      cur_c: block column index

    Returns:
      (ref_x, ref_y): top-left pixel of the chosen reference block. The zero
      displacement is scored first and only a strictly lower loss replaces
      it, so it wins ties.
    """
    cur_x = cur_c * self.blk_sz
    cur_y = cur_r * self.blk_sz
    ref_x, ref_y = cur_x, cur_y
    min_loss = self._loss(cur_r, cur_c, 0, 0)
    # Clamp the window to positions where the whole block fits in the frame;
    # the last valid top-left coordinate is (size - blk_sz), inclusive.
    y_lo = max(cur_y - self.wnd_sz, 0)
    y_hi = min(cur_y + self.wnd_sz, self.height - self.blk_sz)
    x_lo = max(cur_x - self.wnd_sz, 0)
    x_hi = min(cur_x + self.wnd_sz, self.width - self.blk_sz)
    for y in range(y_lo, y_hi + 1):
      for x in range(x_lo, x_hi + 1):
        loss = self._loss(cur_r, cur_c, y - cur_y, x - cur_x)
        if loss < min_loss:
          min_loss = loss
          ref_x, ref_y = x, y
    return ref_x, ref_y

  def motion_field_estimation(self):
    """Estimate every block's [dy, dx] vector in raster order into self.mf."""
    # Reset so that a repeated call does not pick up vectors left over from a
    # previous pass for blocks that have not been visited in this pass yet.
    self.assign[:] = False
    for i in range(self.num_row):
      for j in range(self.num_col):
        ref_x, ref_y = self.search(i, j)
        self.mf[i, j] = np.array(
            [ref_y - i * self.blk_sz, ref_x - j * self.blk_sz])
        self.assign[i, j] = True
142*fb1b10abSAndroid Build Coastguard Worker
143*fb1b10abSAndroid Build Coastguard Worker
"""Exhaust with Neighbor Constraint and Feature Score"""


class ExhaustNeighborFeatureScore(MotionEST):
  """Exhaustive-search block matching followed by feature-weighted smoothing.

  The field is first estimated by plain exhaustive search. It is then
  relaxed for max_iter iterations. Each iteration blends a block's matched
  vector, weighted by its feature score, with a neighbor average of the
  current smoothed field, weighted by beta.

  Constructor:
      cur_f: current frame
      ref_f: reference frame
      blk_size: block size in pixels
      wnd_size: search radius in pixels; displacements from -wnd_size to
          +wnd_size (inclusive) are tried along each axis
      beta: weight of the smoothness term relative to the feature score
      max_iter: number of smoothing iterations
      metric: metric used to compare the blocks' distortion
  """

  def __init__(self,
               cur_f,
               ref_f,
               blk_size,
               wnd_size,
               beta=1,
               max_iter=100,
               metric=MSE):
    self.name = 'exhaust + neighbor+feature score'
    self.wnd_sz = wnd_size
    self.beta = beta
    self.metric = metric
    self.max_iter = max_iter
    super(ExhaustNeighborFeatureScore, self).__init__(cur_f, ref_f, blk_size)
    self.fs = self.getFeatureScore()

  def getFeatureScore(self):
    """Compute a corner-strength score for every block of the current frame.

    Forward differences of channel 0 of cur_yuv (presumably luma; TODO
    confirm) are taken over the top-left (blk_sz - 1) x (blk_sz - 1) pixels
    of each block. They form the 2x2 structure tensor
    [[sum Ix*Ix, sum Ix*Iy], [sum Ix*Iy, sum Iy*Iy]]. The score is
    lambda_max * lambda_min / (lambda_max + lambda_min + 1e-6), clamped at 0.
    It is near zero whenever the smaller eigenvalue is near zero.

    Returns:
      ndarray of shape (num_row, num_col).
    """
    bs = self.blk_sz
    fs = np.zeros((self.num_row, self.num_col))
    for r in range(self.num_row):
      for c in range(self.num_col):
        # Cast before differencing so unsigned pixel types cannot wrap around.
        blk = self.cur_yuv[r * bs:(r + 1) * bs, c * bs:(c + 1) * bs,
                           0].astype(np.float64)
        # Ix[y, x] = I(y, x + 1) - I(y, x); Iy[y, x] = I(y + 1, x) - I(y, x)
        # for 0 <= x, y < bs - 1, i.e. both stay inside the block.
        Ix = blk[:-1, 1:] - blk[:-1, :-1]
        Iy = blk[1:, :-1] - blk[:-1, :-1]
        IxIx = np.sum(Ix * Ix)
        IyIy = np.sum(Iy * Iy)
        IxIy = np.sum(Ix * Iy)
        # Closed-form eigenvalues of the symmetric 2x2 structure tensor.
        root = np.sqrt(4 * IxIy * IxIy + (IxIx - IyIy)**2)
        lambda_max = 0.5 * ((IxIx + IyIy) + root)
        lambda_min = 0.5 * ((IxIx + IyIy) - root)
        score = lambda_max * lambda_min / (1e-6 + lambda_max + lambda_min)
        fs[r, c] = max(score, 0.0)
    return fs

  def search(self, cur_r, cur_c):
    """Find the lowest-distortion reference position for one block.

    Args:
      cur_r: block row index
      cur_c: block column index

    Returns:
      (ref_x, ref_y): top-left pixel of the chosen reference block. The zero
      displacement is scored first and only a strictly lower loss replaces
      it, so it wins ties.
    """
    cur_x = cur_c * self.blk_sz
    cur_y = cur_r * self.blk_sz
    ref_x, ref_y = cur_x, cur_y
    min_loss = self.block_dist(cur_r, cur_c, [0, 0], self.metric)
    # Clamp the window to positions where the whole block fits in the frame;
    # the last valid top-left coordinate is (size - blk_sz), inclusive.
    y_lo = max(cur_y - self.wnd_sz, 0)
    y_hi = min(cur_y + self.wnd_sz, self.height - self.blk_sz)
    x_lo = max(cur_x - self.wnd_sz, 0)
    x_hi = min(cur_x + self.wnd_sz, self.width - self.blk_sz)
    for y in range(y_lo, y_hi + 1):
      for x in range(x_lo, x_hi + 1):
        loss = self.block_dist(cur_r, cur_c, [y - cur_y, x - cur_x],
                               self.metric)
        if loss < min_loss:
          min_loss = loss
          ref_x, ref_y = x, y
    return ref_x, ref_y

  def smooth(self, uvs, mvs):
    """Run one relaxation step of the smoothness constraint.

    For block (r, c) the new vector is
        (fs[r, c] * mvs[r, c] + beta * avg) / (beta + fs[r, c])
    where avg sums the 4-connected neighbors of uvs with weight 1/6 and the
    diagonal neighbors with weight 1/12. Neighbors outside the block grid
    contribute zero; the average is not renormalized at the borders.

    Args:
      uvs: current smoothed field, presumably (num_row, num_col, 2)
      mvs: block-matching field with the same shape as uvs

    Returns:
      A new array shaped like uvs. Only its first num_row x num_col blocks
      are filled.
    """
    nr, nc = self.num_row, self.num_col
    # Zero-pad one block on every side so out-of-grid neighbors add nothing.
    p = np.pad(uvs[:nr, :nc], ((1, 1), (1, 1), (0, 0)), mode='constant')
    edge = p[:-2, 1:-1] + p[2:, 1:-1] + p[1:-1, :-2] + p[1:-1, 2:]
    corner = p[:-2, :-2] + p[:-2, 2:] + p[2:, :-2] + p[2:, 2:]
    avg_uv = edge / 6.0 + corner / 12.0
    fs = self.fs[:, :, np.newaxis]
    sm_uvs = np.zeros(uvs.shape)
    sm_uvs[:nr, :nc] = (fs * mvs[:nr, :nc] + self.beta * avg_uv) / (
        self.beta + fs)
    return sm_uvs

  def motion_field_estimation(self):
    """Match every block, then smooth the field max_iter times into self.mf."""
    # Get block-matching results.
    mvs = np.zeros(self.mf.shape)
    for r in range(self.num_row):
      for c in range(self.num_col):
        ref_x, ref_y = self.search(r, c)
        mvs[r, c] = np.array([ref_y - r * self.blk_sz, ref_x - c * self.blk_sz])
    # Apply the smoothness constraint, starting from a zero field.
    uvs = np.zeros(self.mf.shape)
    for _ in range(self.max_iter):
      uvs = self.smooth(uvs, mvs)
    self.mf = uvs
260