xref: /aosp_15_r20/external/libvpx/tools/3D-Reconstruction/MotionEST/MotionEST.py (revision fb1b10ab9aebc7c7068eedab379b749d7e3900be)
##  Copyright (c) 2020 The WebM project authors. All Rights Reserved.
##
##  Use of this source code is governed by a BSD-style license
##  that can be found in the LICENSE file in the root of the source
##  tree. An additional intellectual property rights grant can be found
##  in the file PATENTS.  All contributing project authors may
##  be found in the AUTHORS file in the root of the source tree.
##

# coding: utf-8
11*fb1b10abSAndroid Build Coastguard Workerimport numpy as np
12*fb1b10abSAndroid Build Coastguard Workerimport numpy.linalg as LA
13*fb1b10abSAndroid Build Coastguard Workerimport matplotlib.pyplot as plt
14*fb1b10abSAndroid Build Coastguard Workerfrom Util import drawMF, MSE
15*fb1b10abSAndroid Build Coastguard Worker"""The Base Class of Estimators"""
16*fb1b10abSAndroid Build Coastguard Worker
17*fb1b10abSAndroid Build Coastguard Worker
class MotionEST(object):
  """Base class of block-based motion field estimators.

  Subclasses override motion_field_estimation(); the evaluation and rendering
  helpers in this class read the estimated field from self.mf, where
  self.mf[i, j] is the (row offset, column offset) motion vector of block
  (i, j).

  Args:
    cur_f: current frame. Must provide .size as (width, height) and
      .convert('YCbCr') (e.g. a PIL image).
    ref_f: reference frame, same interface as cur_f.
    blk_sz: block size in pixels.
  """

  def __init__(self, cur_f, ref_f, blk_sz):
    self.cur_f = cur_f
    self.ref_f = ref_f
    self.blk_sz = blk_sz
    # Convert both frames to YCbCr pixel arrays indexed [y, x, channel].
    # The int dtype is presumably chosen so pixel differences cannot wrap
    # around the way unsigned 8-bit values would.
    self.cur_yuv = np.array(self.cur_f.convert('YCbCr'), dtype=int)
    self.ref_yuv = np.array(self.ref_f.convert('YCbCr'), dtype=int)
    # Frame size in pixels.
    self.width = self.cur_f.size[0]
    self.height = self.cur_f.size[1]
    # Motion field size in blocks. Floor division means a partial block at
    # the right/bottom edge is not part of the field.
    self.num_row = self.height // self.blk_sz
    self.num_col = self.width // self.blk_sz
    # One (row, column) motion vector per block, initialized to zero motion.
    self.mf = np.zeros((self.num_row, self.num_col, 2))

  def motion_field_estimation(self):
    """Estimate the motion field; to be overridden by child classes.

    The base implementation does nothing and leaves self.mf unchanged.
    """
    pass

  def block_dist(self, cur_r, cur_c, mv, metric=MSE):
    """Distortion between a block and its motion-compensated reference.

    Args:
      cur_r: block row index.
      cur_c: block column index.
      mv: motion vector (row offset, column offset) in pixels. Values are
        truncated with int() when locating the reference block.
      metric: callable(cur_blk, ref_blk) returning the distortion.

    Returns:
      metric(cur_blk, ref_blk) on the YCbCr blocks. If the displaced block is
      not fully inside the reference frame, it is replaced by an all-zero
      block of the same shape.
    """
    cur_x = cur_c * self.blk_sz
    cur_y = cur_r * self.blk_sz
    # Clip the block in case it extends past the frame edge.
    h = min(self.blk_sz, self.height - cur_y)
    w = min(self.blk_sz, self.width - cur_x)
    cur_blk = self.cur_yuv[cur_y:cur_y + h, cur_x:cur_x + w, :]
    ref_x = int(cur_x + mv[1])
    ref_y = int(cur_y + mv[0])
    # The reference block covers [ref_x, ref_x + w) x [ref_y, ref_y + h), so
    # ref_x == width - w (resp. ref_y == height - h) is still fully inside
    # the frame. Using '<=' lets right/bottom border blocks match.
    if 0 <= ref_x <= self.width - w and 0 <= ref_y <= self.height - h:
      ref_blk = self.ref_yuv[ref_y:ref_y + h, ref_x:ref_x + w, :]
    else:
      ref_blk = np.zeros((h, w, 3))
    return metric(cur_blk, ref_blk)

  def distortion(self, mask=None, metric=MSE):
    """Mean block distortion of the current motion field.

    Args:
      mask: optional per-block array indexed as mask[i, j]; blocks whose
        entry is truthy are skipped.
      metric: distortion metric forwarded to block_dist().

    Returns:
      The average of block_dist() over all unmasked blocks.

    Raises:
      ZeroDivisionError: if every block is masked out.
    """
    loss = 0
    count = 0
    for i in range(self.num_row):
      for j in range(self.num_col):
        if mask is not None and mask[i, j]:
          continue
        loss += self.block_dist(i, j, self.mf[i, j], metric)
        count += 1
    if count == 0:
      raise ZeroDivisionError('no unmasked blocks to evaluate')
    return loss / count

  def motion_field_evaluation(self, ground_truth):
    """Compare the estimated motion field with a ground-truth field.

    Args:
      ground_truth: object exposing .mf (per-block motion vectors indexed
        like self.mf) and .mask (None, or per-block flags indexed as
        mask[i][j]; truthy entries are skipped).

    Returns:
      The mean Euclidean norm of (ground-truth vector - estimated vector)
      over all unmasked blocks.

    Raises:
      ZeroDivisionError: if every block is masked out.
    """
    loss = 0
    count = 0
    gt = ground_truth.mf
    mask = ground_truth.mask
    for i in range(self.num_row):
      for j in range(self.num_col):
        if mask is not None and mask[i][j]:
          continue
        loss += LA.norm(gt[i, j] - self.mf[i, j])
        count += 1
    if count == 0:
      raise ZeroDivisionError('no unmasked blocks to evaluate')
    return loss / count

  def show(self, ground_truth=None, size=10):
    """Render the estimated motion field, optionally beside a ground truth.

    Args:
      ground_truth: optional motion field passed directly to drawMF() and
        drawn in a second panel.
        NOTE(review): motion_field_evaluation() takes an object with
        .mf/.mask instead; callers presumably pass its .mf here. Confirm.
      size: figure width per panel (matplotlib figsize units); the height
        is scaled by the frame's aspect ratio.
    """
    cur_mf = drawMF(self.cur_f, self.blk_sz, self.mf)
    if ground_truth is None:
      n_panels = 1
    else:
      gt_mf = drawMF(self.cur_f, self.blk_sz, ground_truth)
      n_panels = 2
    # The panels are laid out side by side in one row.
    plt.figure(figsize=(n_panels * size, size * self.height / self.width))
    plt.subplot(1, n_panels, 1)
    plt.imshow(cur_mf)
    plt.title('Estimated Motion Field')
    if ground_truth is not None:
      plt.subplot(1, n_panels, 2)
      plt.imshow(gt_mf)
      plt.title('Ground Truth')
    plt.tight_layout()
    plt.show()
118