xref: /aosp_15_r20/external/libvpx/tools/3D-Reconstruction/MotionEST/SearchSmooth.py (revision fb1b10ab9aebc7c7068eedab379b749d7e3900be)
1*fb1b10abSAndroid Build Coastguard Worker##  Copyright (c) 2020 The WebM project authors. All Rights Reserved.
2*fb1b10abSAndroid Build Coastguard Worker##
3*fb1b10abSAndroid Build Coastguard Worker##  Use of this source code is governed by a BSD-style license
4*fb1b10abSAndroid Build Coastguard Worker##  that can be found in the LICENSE file in the root of the source
5*fb1b10abSAndroid Build Coastguard Worker##  tree. An additional intellectual property rights grant can be found
6*fb1b10abSAndroid Build Coastguard Worker##  in the file PATENTS.  All contributing project authors may
7*fb1b10abSAndroid Build Coastguard Worker##  be found in the AUTHORS file in the root of the source tree.
8*fb1b10abSAndroid Build Coastguard Worker##
9*fb1b10abSAndroid Build Coastguard Worker
10*fb1b10abSAndroid Build Coastguard Worker# coding: utf-8
11*fb1b10abSAndroid Build Coastguard Workerimport numpy as np
12*fb1b10abSAndroid Build Coastguard Workerimport numpy.linalg as LA
13*fb1b10abSAndroid Build Coastguard Workerfrom Util import MSE
14*fb1b10abSAndroid Build Coastguard Workerfrom MotionEST import MotionEST
# Search & Smooth model with adaptive weights.


class SearchSmoothAdapt(MotionEST):
  """Search & Smooth motion model whose data weight adapts per block.

  The motion field produced by ``search`` is refined by repeatedly replacing
  each block's vector with a blend of its neighbours' weighted average and the
  searched vector. The blend uses the local difference structure of the
  reference frame (see getRefLocalDiff). That structure is scaled per block by
  the relative increase of ``self.block_dist`` at the neighbour-average vector
  over ``self.block_dist`` at the searched vector.

    Constructor:
        cur_f: current frame
        ref_f: reference frame
        blk_size: block size in pixels (forwarded to MotionEST)
        search: motion estimator; block_matching() calls its
            motion_field_estimation(), and its ``mf`` array is the field
            that gets smoothed
        max_iter: number of smoothing iterations (default 100)
  """

  def __init__(self, cur_f, ref_f, blk_size, search, max_iter=100):
    self.search = search
    self.max_iter = max_iter
    super(SearchSmoothAdapt, self).__init__(cur_f, ref_f, blk_size)

  def getRefLocalDiff(self, mvs):
    """Return the local difference tensor of the reference frame per block.

    For block (r, c), the matched reference block starts at the block's own
    position displaced by (mvs[r, c, 0] rows, mvs[r, c, 1] columns). That
    position is clamped so the block lies inside the frame.

    I_row is the per-pixel mean, over the in-frame reference blocks directly
    above and below the match, of how much larger their SAD against the
    current block is than the match's SAD. I_col is the same quantity for the
    blocks to the left and right. A direction with no in-frame neighbour
    contributes 0.

    Args:
      mvs: motion field in pixels, indexed [r, c, 0/1].

    Returns:
      A list of num_row lists, each holding num_col 2x2 arrays
      [[I_row*I_row, I_row*I_col], [I_col*I_row, I_col*I_col]].
    """
    blk_sz = self.blk_sz
    max_y = self.height - blk_sz
    max_x = self.width - blk_sz
    localDiff = []
    for r in range(self.num_row):
      row_diff = []
      for c in range(self.num_col):
        # NOTE(review): assumes cur_yuv/ref_yuv use a signed or floating
        # dtype; with an unsigned dtype `nb - center` would wrap around.
        # Confirm in MotionEST.
        center = self.cur_yuv[r * blk_sz:(r + 1) * blk_sz,
                              c * blk_sz:(c + 1) * blk_sz, 0]
        ty = np.clip(r * blk_sz + int(mvs[r, c, 0]), 0, max_y)
        tx = np.clip(c * blk_sz + int(mvs[r, c, 1]), 0, max_x)
        target = self.ref_yuv[ty:ty + blk_sz, tx:tx + blk_sz, 0]
        # The match's own SAD does not depend on the neighbour, so compute
        # it once.
        target_sad = np.sum(np.abs(target - center))
        I_row = self._neighbor_sad_gain(center, target_sad,
                                        {(ty - blk_sz, tx), (ty + blk_sz, tx)})
        I_col = self._neighbor_sad_gain(center, target_sad,
                                        {(ty, tx - blk_sz), (ty, tx + blk_sz)})
        row_diff.append(
            np.array([[I_row * I_row, I_row * I_col],
                      [I_col * I_row, I_col * I_col]]))
      localDiff.append(row_diff)
    return localDiff

  def _neighbor_sad_gain(self, center, target_sad, origins):
    """Return the per-pixel mean SAD excess of in-frame neighbours over the match.

    Args:
      center: current-frame block the SADs are measured against.
      target_sad: SAD between the matched reference block and ``center``.
      origins: candidate (row, col) top-left corners of neighbour blocks
          in the reference frame.

    Returns:
      0 when no origin is a valid in-frame block position. Otherwise the
      floor of (sum of SAD(nb, center) - target_sad) / (count * blk_sz**2).
    """
    blk_sz = self.blk_sz
    max_y = self.height - blk_sz
    max_x = self.width - blk_sz
    total = 0
    count = 0
    for y, x in origins:
      # max_y/max_x is itself a valid origin (the match is clipped to it),
      # so the upper bound is inclusive.
      if 0 <= y <= max_y and 0 <= x <= max_x:
        nb = self.ref_yuv[y:y + blk_sz, x:x + blk_sz, 0]
        total += np.sum(np.abs(nb - center)) - target_sad
        count += 1
    if count == 0:
      # No neighbour in this direction: no gradient information (this
      # previously raised ZeroDivisionError).
      return 0
    # Floor division kept from the original (with integer pixel data it
    # matches Python 2 `/`). TODO(review): true division may be intended.
    return total // (count * blk_sz * blk_sz)

  def _neighbor_average(self, uvs, r, c):
    """Return the weighted mean of the 8 neighbours of block (r, c) in ``uvs``.

    Edge neighbours weigh 1/6 and diagonal neighbours weigh 1/12, so the
    weights sum to 1. A neighbour outside the block grid is replaced by
    uvs[r, c] itself.
    """
    nb_uv = np.array([0.0, 0.0])
    for i, j in {(r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)}:
      if 0 <= i < self.num_row and 0 <= j < self.num_col:
        nb_uv += uvs[i, j] / 6.0
      else:
        nb_uv += uvs[r, c] / 6.0
    for i, j in {(r - 1, c - 1), (r - 1, c + 1), (r + 1, c - 1),
                 (r + 1, c + 1)}:
      if 0 <= i < self.num_row and 0 <= j < self.num_col:
        nb_uv += uvs[i, j] / 12.0
      else:
        nb_uv += uvs[r, c] / 12.0
    return nb_uv

  def smooth(self, uvs, mvs):
    """Run one smoothing pass and return the new field in block units.

    For each block, let nb be the neighbour average of ``uvs`` and mv the
    searched vector mvs[r, c] in pixels. The pass computes
        alpha = (block_dist(blk_sz * nb) - block_dist(mv)) /
                (block_dist(mv) + 1e-6)
    and M = alpha * localDiff[r][c]. The result u = inv(I + M) @ nb +
    inv(I + M) @ M @ (mv / blk_sz) solves (M + I) u = nb + M @ mv / blk_sz.

    Requires ``self.localDiff`` (set by motion_field_estimation).

    NOTE(review): when alpha < 0, I + M can be singular and LA.inv will
    raise LinAlgError.
    """
    sm_uvs = np.zeros(uvs.shape)
    blk_sz = self.blk_sz
    for r in range(self.num_row):
      for c in range(self.num_col):
        nb_uv = self._neighbor_average(uvs, r, c)
        ssd_nb = self.block_dist(r, c, blk_sz * nb_uv)
        mv = mvs[r, c]
        ssd_mv = self.block_dist(r, c, mv)
        # Relative penalty of using the smoothed vector instead of the
        # searched one; this scales the data term for this block.
        alpha = (ssd_nb - ssd_mv) / (ssd_mv + 1e-6)
        M = alpha * self.localDiff[r][c]
        inv_P = LA.inv(M + np.identity(2))
        sm_uvs[r, c] = np.dot(inv_P, nb_uv) + np.dot(
            np.matmul(inv_P, M), mv / blk_sz)
    return sm_uvs

  def block_matching(self):
    """Run the wrapped search's motion_field_estimation()."""
    self.search.motion_field_estimation()

  def motion_field_estimation(self):
    """Smooth ``self.search.mf`` for max_iter passes; store pixels in ``self.mf``.

    Uses whatever is currently in ``self.search.mf``; presumably that is
    filled by block_matching().
    """
    mvs = self.search.mf
    # Local structure of the reference frame around each matched block.
    self.localDiff = self.getRefLocalDiff(mvs)
    # Iterate in block units and convert back to pixels at the end.
    uvs = mvs / self.blk_sz
    for _ in range(self.max_iter):
      uvs = self.smooth(uvs, mvs)
    self.mf = uvs * self.blk_sz
118*fb1b10abSAndroid Build Coastguard Worker
119*fb1b10abSAndroid Build Coastguard Worker
# Search & Smooth model with a fixed smoothness weight.


class SearchSmoothFix(MotionEST):
  """Search & Smooth motion model with a fixed neighbour weight ``beta``.

  The motion field produced by ``search`` is refined by repeatedly replacing
  each block's vector u with the solution of
      (M + beta * I) u = beta * nb + M @ mv
  where M is the block's local difference tensor (see getRefLocalDiff), nb
  is the weighted average of the neighbouring vectors, and mv is the searched
  vector. Both nb and mv are in block units.

    Constructor:
        cur_f: current frame
        ref_f: reference frame
        blk_size: block size in pixels (forwarded to MotionEST)
        search: motion estimator; block_matching() calls its
            motion_field_estimation(), and its ``mf`` array is the field
            that gets smoothed
        beta: weight of the neighbour (smoothness) term; beta > 0 keeps
            M + beta * I invertible because M is positive semi-definite
        max_iter: number of smoothing iterations (default 100)
  """

  def __init__(self, cur_f, ref_f, blk_size, search, beta, max_iter=100):
    self.search = search
    self.max_iter = max_iter
    self.beta = beta
    super(SearchSmoothFix, self).__init__(cur_f, ref_f, blk_size)

  def getRefLocalDiff(self, mvs):
    """Return the local difference tensor of the reference frame per block.

    For block (r, c), the matched reference block starts at the block's own
    position displaced by (mvs[r, c, 0] rows, mvs[r, c, 1] columns). That
    position is clamped so the block lies inside the frame.

    I_row is the per-pixel mean, over the in-frame reference blocks directly
    above and below the match, of how much larger their SAD against the
    current block is than the match's SAD. I_col is the same quantity for the
    blocks to the left and right. A direction with no in-frame neighbour
    contributes 0.

    Args:
      mvs: motion field in pixels, indexed [r, c, 0/1].

    Returns:
      A list of num_row lists, each holding num_col 2x2 arrays
      [[I_row*I_row, I_row*I_col], [I_col*I_row, I_col*I_col]].
    """
    blk_sz = self.blk_sz
    max_y = self.height - blk_sz
    max_x = self.width - blk_sz
    localDiff = []
    for r in range(self.num_row):
      row_diff = []
      for c in range(self.num_col):
        # NOTE(review): assumes cur_yuv/ref_yuv use a signed or floating
        # dtype; with an unsigned dtype `nb - center` would wrap around.
        # Confirm in MotionEST.
        center = self.cur_yuv[r * blk_sz:(r + 1) * blk_sz,
                              c * blk_sz:(c + 1) * blk_sz, 0]
        ty = np.clip(r * blk_sz + int(mvs[r, c, 0]), 0, max_y)
        tx = np.clip(c * blk_sz + int(mvs[r, c, 1]), 0, max_x)
        target = self.ref_yuv[ty:ty + blk_sz, tx:tx + blk_sz, 0]
        # The match's own SAD does not depend on the neighbour, so compute
        # it once.
        target_sad = np.sum(np.abs(target - center))
        I_row = self._neighbor_sad_gain(center, target_sad,
                                        {(ty - blk_sz, tx), (ty + blk_sz, tx)})
        I_col = self._neighbor_sad_gain(center, target_sad,
                                        {(ty, tx - blk_sz), (ty, tx + blk_sz)})
        row_diff.append(
            np.array([[I_row * I_row, I_row * I_col],
                      [I_col * I_row, I_col * I_col]]))
      localDiff.append(row_diff)
    return localDiff

  def _neighbor_sad_gain(self, center, target_sad, origins):
    """Return the per-pixel mean SAD excess of in-frame neighbours over the match.

    Args:
      center: current-frame block the SADs are measured against.
      target_sad: SAD between the matched reference block and ``center``.
      origins: candidate (row, col) top-left corners of neighbour blocks
          in the reference frame.

    Returns:
      0 when no origin is a valid in-frame block position. Otherwise the
      floor of (sum of SAD(nb, center) - target_sad) / (count * blk_sz**2).
    """
    blk_sz = self.blk_sz
    max_y = self.height - blk_sz
    max_x = self.width - blk_sz
    total = 0
    count = 0
    for y, x in origins:
      # max_y/max_x is itself a valid origin (the match is clipped to it),
      # so the upper bound is inclusive.
      if 0 <= y <= max_y and 0 <= x <= max_x:
        nb = self.ref_yuv[y:y + blk_sz, x:x + blk_sz, 0]
        total += np.sum(np.abs(nb - center)) - target_sad
        count += 1
    if count == 0:
      # No neighbour in this direction: no gradient information (this
      # previously raised ZeroDivisionError).
      return 0
    # Floor division kept from the original (with integer pixel data it
    # matches Python 2 `/`). TODO(review): true division may be intended.
    return total // (count * blk_sz * blk_sz)

  def _neighbor_average(self, uvs, r, c):
    """Return the weighted mean of the 8 neighbours of block (r, c) in ``uvs``.

    Edge neighbours weigh 1/6 and diagonal neighbours weigh 1/12, so the
    weights sum to 1. A neighbour outside the block grid is replaced by
    uvs[r, c] itself.
    """
    nb_uv = np.array([0.0, 0.0])
    for i, j in {(r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)}:
      if 0 <= i < self.num_row and 0 <= j < self.num_col:
        nb_uv += uvs[i, j] / 6.0
      else:
        nb_uv += uvs[r, c] / 6.0
    for i, j in {(r - 1, c - 1), (r - 1, c + 1), (r + 1, c - 1),
                 (r + 1, c + 1)}:
      if 0 <= i < self.num_row and 0 <= j < self.num_col:
        nb_uv += uvs[i, j] / 12.0
      else:
        nb_uv += uvs[r, c] / 12.0
    return nb_uv

  def smooth(self, uvs, mvs):
    """Run one smoothing pass and return the new field in block units.

    For each block, with nb the neighbour average of ``uvs``, mv =
    mvs[r, c] / blk_sz and M = localDiff[r][c], the result is
        inv(M + beta*I) @ (beta * nb) + inv(M + beta*I) @ M @ mv.

    Requires ``self.localDiff`` (set by motion_field_estimation).
    """
    sm_uvs = np.zeros(uvs.shape)
    blk_sz = self.blk_sz
    # Identical for every block, so build it once per pass.
    beta_eye = self.beta * np.identity(2)
    for r in range(self.num_row):
      for c in range(self.num_col):
        nb_uv = self._neighbor_average(uvs, r, c)
        mv = mvs[r, c] / blk_sz
        M = self.localDiff[r][c]
        inv_P = LA.inv(M + beta_eye)
        sm_uvs[r, c] = np.dot(inv_P, self.beta * nb_uv) + np.dot(
            np.matmul(inv_P, M), mv)
    return sm_uvs

  def block_matching(self):
    """Run the wrapped search's motion_field_estimation()."""
    self.search.motion_field_estimation()

  def motion_field_estimation(self):
    """Smooth ``self.search.mf`` for max_iter passes; store pixels in ``self.mf``.

    Uses whatever is currently in ``self.search.mf``; presumably that is
    filled by block_matching().
    """
    mvs = self.search.mf
    # Local structure of the reference frame around each matched block.
    self.localDiff = self.getRefLocalDiff(mvs)
    # Iterate in block units and convert back to pixels at the end.
    uvs = mvs / self.blk_sz
    for _ in range(self.max_iter):
      uvs = self.smooth(uvs, mvs)
    self.mf = uvs * self.blk_sz
222