xref: /aosp_15_r20/external/libvpx/tools/3D-Reconstruction/MotionEST/HornSchunck.py (revision fb1b10ab9aebc7c7068eedab379b749d7e3900be)
1*fb1b10abSAndroid Build Coastguard Worker##  Copyright (c) 2020 The WebM project authors. All Rights Reserved.
2*fb1b10abSAndroid Build Coastguard Worker##
3*fb1b10abSAndroid Build Coastguard Worker##  Use of this source code is governed by a BSD-style license
4*fb1b10abSAndroid Build Coastguard Worker##  that can be found in the LICENSE file in the root of the source
5*fb1b10abSAndroid Build Coastguard Worker##  tree. An additional intellectual property rights grant can be found
6*fb1b10abSAndroid Build Coastguard Worker##  in the file PATENTS.  All contributing project authors may
7*fb1b10abSAndroid Build Coastguard Worker##  be found in the AUTHORS file in the root of the source tree.
8*fb1b10abSAndroid Build Coastguard Worker##
9*fb1b10abSAndroid Build Coastguard Worker
10*fb1b10abSAndroid Build Coastguard Worker# coding: utf-8
11*fb1b10abSAndroid Build Coastguard Workerimport numpy as np
12*fb1b10abSAndroid Build Coastguard Workerimport numpy.linalg as LA
13*fb1b10abSAndroid Build Coastguard Workerfrom scipy.ndimage.filters import gaussian_filter
14*fb1b10abSAndroid Build Coastguard Workerfrom scipy.sparse import csc_matrix
15*fb1b10abSAndroid Build Coastguard Workerfrom scipy.sparse.linalg import inv
16*fb1b10abSAndroid Build Coastguard Workerfrom MotionEST import MotionEST
17*fb1b10abSAndroid Build Coastguard Worker"""Horn & Schunck Model"""
18*fb1b10abSAndroid Build Coastguard Worker
19*fb1b10abSAndroid Build Coastguard Worker
class HornSchunck(MotionEST):
  """Horn & Schunck optical-flow estimator operating on a block grid.

  Each `blk_sz` x `blk_sz` block is reduced to the mean of channel 0 (Y) of
  the YUV frame, the resulting low-resolution intensity images are Gaussian
  blurred, and the Horn & Schunck energy (brightness constancy plus an
  `alpha`-weighted smoothness term) is minimised either iteratively
  (`motion_field_estimation`) or by solving the linear system directly
  (`motion_field_estimation_mat`).

  Results are written into `self.mf` (presumably allocated by the `MotionEST`
  base class -- it is not created here). Channel 0 receives the vertical
  (row, Iy-driven) component and channel 1 the horizontal (column,
  Ix-driven) component; both are multiplied by `blk_sz` at the end, i.e.
  converted from block-grid units to pixels.

  Args:
    cur_f: current frame.
    ref_f: reference frame.
    blk_sz: block size in pixels.
    alpha: weight of the smoothness constraint.
    sigma: standard deviation of the Gaussian blur applied to the
      block-intensity images.
    max_iter: number of Jacobi iterations run by `motion_field_estimation`.
  """

  def __init__(self, cur_f, ref_f, blk_sz, alpha, sigma, max_iter=100):
    super(HornSchunck, self).__init__(cur_f, ref_f, blk_sz)
    self.cur_I, self.ref_I = self.getIntensity()
    # Smooth the block intensities so the finite-difference derivatives
    # computed below are less sensitive to noise.
    self.cur_I = gaussian_filter(self.cur_I, sigma=sigma)
    self.ref_I = gaussian_filter(self.ref_I, sigma=sigma)
    self.alpha = alpha
    self.max_iter = max_iter
    self.Ix, self.Iy, self.It = self.intensityDiff()

  def getIntensity(self):
    """Build per-block intensity images for both frames.

    Returns:
      (cur_I, ref_I): float arrays of shape (num_row, num_col); entry [i, j]
      is the mean of channel 0 over block (i, j) of the current / reference
      frame.
    """
    sz = self.blk_sz
    cur_I = np.zeros((self.num_row, self.num_col))
    ref_I = np.zeros((self.num_row, self.num_col))
    for i in range(self.num_row):
      r = i * sz
      for j in range(self.num_col):
        c = j * sz
        cur_I[i, j] = np.mean(self.cur_yuv[r:r + sz, c:c + sz, 0])
        ref_I[i, j] = np.mean(self.ref_yuv[r:r + sz, c:c + sz, 0])
    return cur_I, ref_I

  def intensityDiff(self):
    """Estimate the first-order spatial and temporal intensity derivatives.

    For the cell whose top-left corner is block (i, j), each derivative is
    the average of four first differences over the 2x2 neighbourhood
    {i, i+1} x {j, j+1}, taken in both frames:

      Ix: (i  ,j) <--- (i  ,j+1)      Iy: (i  ,j)    (i  ,j+1)
          (i+1,j) <--- (i+1,j+1)             ^          ^
                                             |          |
                                          (i+1,j)    (i+1,j+1)
      It: mean of (ref - cur) over the same four blocks.

    The last row and last column have no forward neighbour and stay zero.

    Returns:
      (Ix, Iy, It): float arrays of shape (num_row, num_col).
    """
    shape = (self.num_row, self.num_col)
    Ix = np.zeros(shape)
    Iy = np.zeros(shape)
    It = np.zeros(shape)
    cur = self.cur_I[:self.num_row, :self.num_col]
    ref = self.ref_I[:self.num_row, :self.num_col]
    # Differences are linear, so one difference of (cur + ref) covers both
    # frames at once.
    both = cur + ref
    diff = ref - cur
    Ix[:-1, :-1] = ((both[:-1, 1:] - both[:-1, :-1]) +
                    (both[1:, 1:] - both[1:, :-1])) / 4.0
    Iy[:-1, :-1] = ((both[1:, :-1] - both[:-1, :-1]) +
                    (both[1:, 1:] - both[:-1, 1:])) / 4.0
    It[:-1, :-1] = (diff[:-1, :-1] + diff[:-1, 1:] + diff[1:, :-1] +
                    diff[1:, 1:]) / 4.0
    return Ix, Iy, It

  def averageMV(self):
    """Return the weighted neighbour average of the current motion field.

    This is the local mean used by Horn & Schunck's discrete Laplacian
    approximation (Laplacian proportional to avg - mf). Weights:

        1/12 --- 1/6 --- 1/12
         |        |       |
        1/6  ---  0  --- 1/6
         |        |       |
        1/12 --- 1/6 --- 1/12

    Neighbours outside the grid contribute zero; the weights are not
    renormalised at the border.

    Returns:
      Float array of shape (num_row, num_col, 2).
    """
    rows, cols = self.num_row, self.num_col
    # Zero padding stands in for "skip out-of-range neighbours" and lets each
    # neighbour direction be read as one shifted slice.
    padded = np.zeros((rows + 2, cols + 2, 2))
    padded[1:-1, 1:-1] = self.mf[:rows, :cols]
    edges = (padded[:-2, 1:-1] + padded[2:, 1:-1] + padded[1:-1, :-2] +
             padded[1:-1, 2:])
    corners = (padded[:-2, :-2] + padded[:-2, 2:] + padded[2:, :-2] +
               padded[2:, 2:])
    return edges / 6.0 + corners / 12.0

  def motion_field_estimation(self):
    """Estimate `self.mf` with the iterative Horn & Schunck scheme.

    Runs `max_iter` Jacobi iterations of

        u_{n+1} = ~u_n - Ix (Ix ~u_n + Iy ~v_n + It) / (alpha^2 + Ix^2 + Iy^2)
        v_{n+1} = ~v_n - Iy (Ix ~u_n + Iy ~v_n + It) / (alpha^2 + Ix^2 + Iy^2)

    where ~u, ~v come from `averageMV`. u is channel 1 and v is channel 0 of
    `self.mf`. Iteration starts from whatever `self.mf` holds on entry, and
    the final field is multiplied by `blk_sz` in place. Note: with alpha == 0
    and a zero gradient the denominator is zero.
    """
    denom = self.alpha**2 + np.square(self.Ix) + np.square(self.Iy)
    for _ in range(self.max_iter):
      avg = self.averageMV()
      avg_v = avg[:, :, 0]
      avg_u = avg[:, :, 1]
      # Normalised brightness-constancy residual shared by both updates.
      step = (self.Ix * avg_u + self.Iy * avg_v + self.It) / denom
      self.mf[:, :, 1] = avg_u - self.Ix * step
      self.mf[:, :, 0] = avg_v - self.Iy * step
    self.mf *= self.blk_sz

  def motion_field_estimation_mat(self):
    """Estimate `self.mf` by solving the Horn & Schunck system directly.

    Unknowns are interleaved per block: index 2 * (i * num_col + j) is u
    (horizontal) and the following index is v (vertical). Each block adds
    the two Euler-Lagrange equations

        (Ix^2 + alpha^2) u + Ix Iy v - alpha^2 ~u = -Ix It
        Ix Iy u + (Iy^2 + alpha^2) v - alpha^2 ~v = -Iy It

    with ~u, ~v the same 1/6 (edge) / 1/12 (diagonal) neighbour averages used
    by `averageMV`. The sparse system is solved with `spsolve` instead of
    forming the inverse matrix, which is generally dense and far more
    expensive. The solution is written into `self.mf` (channel 0 = v,
    channel 1 = u) multiplied by `blk_sz`.
    """
    from scipy.sparse.linalg import spsolve

    num_row, num_col = self.num_row, self.num_col
    N = 2 * num_row * num_col
    alpha2 = self.alpha**2
    edge_w = -alpha2 / 6.0
    diag_w = -alpha2 / 12.0
    neighbours = ((-1, 0, edge_w), (1, 0, edge_w), (0, -1, edge_w),
                  (0, 1, edge_w), (-1, -1, diag_w), (-1, 1, diag_w),
                  (1, -1, diag_w), (1, 1, diag_w))

    row_idx = []
    col_idx = []
    data = []
    b = np.zeros(N)
    for i in range(num_row):
      for j in range(num_col):
        ix = self.Ix[i, j]
        iy = self.Iy[i, j]
        it = self.It[i, j]
        u_idx = 2 * (i * num_col + j)
        v_idx = u_idx + 1
        b[u_idx] = -ix * it
        b[v_idx] = -iy * it
        # Data term: the 2x2 block coupling u and v of this cell.
        row_idx += [u_idx, u_idx, v_idx, v_idx]
        col_idx += [u_idx, v_idx, u_idx, v_idx]
        data += [ix * ix + alpha2, ix * iy, ix * iy, iy * iy + alpha2]
        # Smoothness term: -alpha^2 * (weighted neighbour average).
        for dr, dc, w in neighbours:
          r = i + dr
          c = j + dc
          if 0 <= r < num_row and 0 <= c < num_col:
            u_nb = 2 * (r * num_col + c)
            row_idx += [u_idx, v_idx]
            col_idx += [u_nb, u_nb + 1]
            data += [w, w]
    M = csc_matrix((data, (row_idx, col_idx)), shape=(N, N))
    uv = np.asarray(spsolve(M, b)).reshape(num_row, num_col, 2)

    self.mf[:num_row, :num_col, 0] = uv[:, :, 1] * self.blk_sz
    self.mf[:num_row, :num_col, 1] = uv[:, :, 0] * self.blk_sz
213