xref: /aosp_15_r20/external/libvpx/tools/3D-Reconstruction/MotionEST/Anandan.py (revision fb1b10ab9aebc7c7068eedab379b749d7e3900be)
##  Copyright (c) 2020 The WebM project authors. All Rights Reserved.
##
##  Use of this source code is governed by a BSD-style license
##  that can be found in the LICENSE file in the root of the source
##  tree. An additional intellectual property rights grant can be found
##  in the file PATENTS.  All contributing project authors may
##  be found in the AUTHORS file in the root of the source tree.
##
9*fb1b10abSAndroid Build Coastguard Worker
10*fb1b10abSAndroid Build Coastguard Worker# coding: utf-8
11*fb1b10abSAndroid Build Coastguard Workerimport numpy as np
12*fb1b10abSAndroid Build Coastguard Workerimport numpy.linalg as LA
13*fb1b10abSAndroid Build Coastguard Workerfrom scipy.ndimage.filters import gaussian_filter
14*fb1b10abSAndroid Build Coastguard Workerfrom scipy.sparse import csc_matrix
15*fb1b10abSAndroid Build Coastguard Workerfrom scipy.sparse.linalg import inv
16*fb1b10abSAndroid Build Coastguard Workerfrom MotionEST import MotionEST
17*fb1b10abSAndroid Build Coastguard Worker"""Anandan Model"""
18*fb1b10abSAndroid Build Coastguard Worker
19*fb1b10abSAndroid Build Coastguard Worker
class Anandan(MotionEST):
  """Coarse-to-fine block motion estimator ("Anandan model").

  For every level of a Gaussian-blurred intensity hierarchy the estimator
  (1) finds, per block, the displacement with the smallest sum of squared
  differences (SSD), then (2) iteratively smooths the resulting field, trusting
  each block's match according to a confidence derived from the block's
  curvature and its matching error.

  The attributes `cur_yuv`, `ref_yuv`, `blk_sz`, `height`, `width`, `num_row`,
  `num_col` and `mf` are read/written here but are not defined in this class;
  presumably the `MotionEST` base constructor provides them (verify against
  MotionEST).

  constructor:
      cur_f: current frame
      ref_f: reference frame
      blk_sz: block size
      beta: smoothness constraint weight
      k1,k2,k3: confidence coefficients
      max_iter: maximum number of smoothing iterations per level
  """

  def __init__(self, cur_f, ref_f, blk_sz, beta, k1, k2, k3, max_iter=100):
    super(Anandan, self).__init__(cur_f, ref_f, blk_sz)
    # Levels 0..levels are processed, i.e. int(log2(blk_sz)) + 1 in total.
    self.levels = int(np.log2(blk_sz))
    self.intensity_hierarchy()
    # Per-level curvature of the current frame, indexed by level.
    self.c_maxs = []
    self.c_mins = []
    self.e_maxs = []
    self.e_mins = []
    for level in range(self.levels + 1):
      c_max, c_min, e_max, e_min = self.get_curvature(self.cur_Is[level])
      self.c_maxs.append(c_max)
      self.c_mins.append(c_min)
      self.e_maxs.append(e_max)
      self.e_mins.append(e_min)
    self.beta = beta
    self.k1, self.k2, self.k3 = k1, k2, k3
    self.max_iter = max_iter

  def intensity_hierarchy(self):
    """Build the intensity hierarchy.

    Fills `self.cur_Is` and `self.ref_Is` with `self.levels + 1` blurred copies
    of channel 0 of `self.cur_yuv` / `self.ref_yuv`. Level `n` uses a Gaussian
    with sigma = 0.56 * 2**n; the images are not down-sampled.
    """
    self.cur_Is = []
    self.ref_Is = []
    # gaussian_filter returns an array of the input's dtype by default, so an
    # integer plane would be quantised (and an unsigned one would wrap around
    # in the differences taken later). Filter in floating point instead.
    cur_plane = np.asarray(self.cur_yuv[:, :, 0], dtype=np.float64)
    ref_plane = np.asarray(self.ref_yuv[:, :, 0], dtype=np.float64)
    for level in range(self.levels + 1):
      sigma = (2**level) * 0.56
      self.ref_Is.append(gaussian_filter(ref_plane, sigma=sigma))
      self.cur_Is.append(gaussian_filter(cur_plane, sigma=sigma))

  def get_curvature(self, I):
    """Get the curvature of each block.

    For each block the forward differences Ix (along columns) and Iy (along
    rows) are accumulated into the symmetric 2x2 matrix
    [[sum Iy*Iy, sum Ix*Iy], [sum Ix*Iy, sum Ix*Ix]], which is then decomposed
    with an SVD.

      I: intensity image of one hierarchy level

    Returns (c_max, c_min, e_max, e_min): the larger / smaller singular value
    per block, shape (num_row, num_col), and the matching left singular
    vectors, shape (num_row, num_col, 2).
    """
    c_max = np.zeros((self.num_row, self.num_col))
    c_min = np.zeros((self.num_row, self.num_col))
    e_max = np.zeros((self.num_row, self.num_col, 2))
    e_min = np.zeros((self.num_row, self.num_col, 2))
    blk = self.blk_sz
    # Forward differences read pixel (i + 1, j + 1), so the last image row and
    # column cannot be used as a difference origin.
    last_row = self.height - 1
    last_col = self.width - 1
    for r in range(self.num_row):
      row_end = min((r + 1) * blk, last_row)
      for c in range(self.num_col):
        col_end = min((c + 1) * blk, last_col)
        h11, h12, h22 = 0.0, 0.0, 0.0
        for i in range(r * blk, row_end):
          for j in range(c * blk, col_end):
            # Python floats keep the subtraction safe for any input dtype.
            here = float(I[i][j])
            Ix = float(I[i][j + 1]) - here
            Iy = float(I[i + 1][j]) - here
            h11 += Iy * Iy
            h12 += Ix * Iy
            h22 += Ix * Ix
        U, S, _ = LA.svd(np.array([[h11, h12], [h12, h22]]))
        c_max[r, c], c_min[r, c] = S[0], S[1]
        e_max[r, c] = U[:, 0]
        e_min[r, c] = U[:, 1]
    return c_max, c_min, e_max, e_min

  def get_ssd(self, cur_I, ref_I, center, mv):
    """Get the SSD of a motion vector.

      cur_I: current intensity
      ref_I: reference intensity
      center: (row, col) of the block's top-left pixel
      mv: (row, col) motion vector; each component is truncated with int()

    Pixels of the block that lie outside the frame are skipped. When the
    displaced pixel lies outside the frame, the squared current intensity is
    added instead of a squared difference.
    """
    top, left = int(center[0]), int(center[1])
    d_row, d_col = int(mv[0]), int(mv[1])
    ssd = 0.0
    for r in range(max(top, 0), min(top + self.blk_sz, self.height)):
      tr = r + d_row
      row_in_frame = 0 <= tr < self.height
      for c in range(max(left, 0), min(left + self.blk_sz, self.width)):
        tc = c + d_col
        cur = float(cur_I[r, c])
        if row_in_frame and 0 <= tc < self.width:
          diff = float(ref_I[tr, tc]) - cur
          ssd += diff * diff
        else:
          ssd += cur * cur
    return ssd

  def region_match(self, l, last_mvs, radius):
    """Get the region match of level l.

      l: current level
      last_mvs: matching results of the previous level, or None for the first
        level processed
      radius: movement radius; the 5x5 search grid is spaced by this amount

    Returns (mvs, min_ssds): the best motion vector per block, shape
    (num_row, num_col, 2), and its SSD, shape (num_row, num_col).
    """
    mvs = np.zeros((self.num_row, self.num_col, 2))
    min_ssds = np.zeros((self.num_row, self.num_col))
    cur_I = self.cur_Is[l]
    ref_I = self.ref_Is[l]
    for r in range(self.num_row):
      for c in range(self.num_col):
        center = np.array([r * self.blk_sz, c * self.blk_sz])
        # Overlap hierarchy policy: start from the previous level's vectors of
        # this block and its right / lower / lower-right neighbours. A tuple
        # (not a set) keeps the visiting order, and hence tie-breaking, fixed.
        if last_mvs is None:
          init_mvs = [np.array([0, 0])]
        else:
          init_mvs = []
          for i, j in ((r, c), (r, c + 1), (r + 1, c), (r + 1, c + 1)):
            if 0 <= i < last_mvs.shape[0] and 0 <= j < last_mvs.shape[1]:
              init_mvs.append(last_mvs[i, j])
        # Use the previous results as start positions for the current level.
        min_ssd = None
        min_mv = None
        for init_mv in init_mvs:
          for i in range(-2, 3):
            for j in range(-2, 3):
              mv = init_mv + np.array([i, j]) * radius
              ssd = self.get_ssd(cur_I, ref_I, center, mv)
              # Strict '<' means the first candidate wins on equal SSD.
              if min_ssd is None or ssd < min_ssd:
                min_ssd = ssd
                min_mv = mv
        min_ssds[r, c] = min_ssd
        mvs[r, c] = min_mv
    return mvs, min_ssds

  def smooth(self, uvs, mvs, min_ssds, l):
    """Smooth the motion field based on the neighbor constraint.

      uvs: current estimation
      mvs: matching results
      min_ssds: minimum ssd of matching results
      l: current level

    Returns a new (num_row, num_col, 2) array; `uvs` is not modified.
    """
    sm_uvs = np.zeros((self.num_row, self.num_col, 2))
    c_max = self.c_maxs[l]
    c_min = self.c_mins[l]
    for r in range(self.num_row):
      for c in range(self.num_col):
        # Confidence along the two curvature directions.
        w_max = c_max[r, c] / (
            self.k1 + self.k2 * min_ssds[r, c] + self.k3 * c_max[r, c])
        w_min = c_min[r, c] / (
            self.k1 + self.k2 * min_ssds[r, c] + self.k3 * c_min[r, c])
        # 1e-6 avoids a zero denominator when both confidences are zero.
        w = w_max * w_min / (w_max + w_min + 1e-6)
        if w < 0:
          w = 0
        # Each in-frame neighbour carries a fixed weight of 0.25, so blocks on
        # the border see their missing neighbours as zero motion.
        avg_uv = np.array([0.0, 0.0])
        for i, j in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
          if 0 <= i < self.num_row and 0 <= j < self.num_col:
            avg_uv += 0.25 * uvs[i, j]
        w_sq = w * w
        sm_uvs[r, c] = (w_sq * mvs[r, c] + self.beta * avg_uv) / (
            self.beta + w_sq)
    return sm_uvs

  def motion_field_estimation(self):
    """Estimate the motion field and store it in `self.mf`.

    Levels are processed from `self.levels` down to 0 with a search radius of
    2**level. Each level's smoothed field seeds the matching of the next one.
    """
    last_mvs = None
    for level in range(self.levels, -1, -1):
      mvs, min_ssds = self.region_match(level, last_mvs, 2**level)
      # Smoothing restarts from a zero field at every level.
      uvs = np.zeros(mvs.shape)
      for _ in range(self.max_iter):
        uvs = self.smooth(uvs, mvs, min_ssds, level)
      last_mvs = uvs
    for r in range(self.num_row):
      for c in range(self.num_col):
        self.mf[r, c] = uvs[r, c]
194