## Copyright (c) 2020 The WebM project authors. All Rights Reserved.
##
## Use of this source code is governed by a BSD-style license
## that can be found in the LICENSE file in the root of the source
## tree. An additional intellectual property rights grant can be found
## in the file PATENTS. All contributing project authors may
## be found in the AUTHORS file in the root of the source tree.
##

# coding: utf-8
import numpy as np
import numpy.linalg as LA
import matplotlib.pyplot as plt
from Util import drawMF, MSE


class MotionEST(object):
    """The base class of motion-field estimators.

    The frame is split into blk_sz x blk_sz blocks and one motion vector is
    stored per block in `self.mf`, an array of shape (num_row, num_col, 2).
    Each vector is (row offset, column offset) in pixels. Subclasses
    override `motion_field_estimation` to fill `self.mf`.

    Constructor arguments:
      cur_f: current frame. It is converted with `.convert('YCbCr')` and its
        `.size` is read as (width, height); presumably a PIL Image.
      ref_f: reference frame, converted the same way. NOTE(review): assumed
        to have the same size as cur_f; only cur_f's size is used.
      blk_sz: block size in pixels.
    """

    def __init__(self, cur_f, ref_f, blk_sz):
        self.cur_f = cur_f
        self.ref_f = ref_f
        self.blk_sz = blk_sz
        # Convert RGB to YUV (YCbCr) and keep integer pixel arrays.
        self.cur_yuv = np.array(self.cur_f.convert('YCbCr'), dtype=int)
        self.ref_yuv = np.array(self.ref_f.convert('YCbCr'), dtype=int)
        # Frame size.
        self.width = self.cur_f.size[0]
        self.height = self.cur_f.size[1]
        # Motion field size: only whole blocks are covered; any partial
        # blocks on the right/bottom edges are dropped.
        self.num_row = self.height // self.blk_sz
        self.num_col = self.width // self.blk_sz
        # Initialize the motion field to all-zero vectors.
        self.mf = np.zeros((self.num_row, self.num_col, 2))

    def motion_field_estimation(self):
        """Estimate `self.mf`.

        The base implementation does nothing. Child classes override it.
        """
        pass

    def block_dist(self, cur_r, cur_c, mv, metric=MSE):
        """Return the distortion of one block under a candidate motion vector.

        Args:
          cur_r: block row index in the motion field.
          cur_c: block column index in the motion field.
          mv: motion vector; mv[0] offsets rows (y) and mv[1] offsets
            columns (x), in pixels. Values are truncated with int().
          metric: callable(cur_blk, ref_blk) returning the distortion.

        Returns:
          metric(cur_blk, ref_blk). When the displaced reference block does
          not lie entirely inside the frame, ref_blk is an all-zero block of
          the same shape.
        """
        cur_x = cur_c * self.blk_sz
        cur_y = cur_r * self.blk_sz
        h = min(self.blk_sz, self.height - cur_y)
        w = min(self.blk_sz, self.width - cur_x)
        cur_blk = self.cur_yuv[cur_y:cur_y + h, cur_x:cur_x + w, :]
        ref_x = int(cur_x + mv[1])
        ref_y = int(cur_y + mv[0])
        # The last valid top-left corner is (height - h, width - w): a block
        # starting there ends exactly on the frame border, so it is inclusive.
        if 0 <= ref_x <= self.width - w and 0 <= ref_y <= self.height - h:
            ref_blk = self.ref_yuv[ref_y:ref_y + h, ref_x:ref_x + w, :]
        else:
            ref_blk = np.zeros((h, w, 3))
        return metric(cur_blk, ref_blk)

    def distortion(self, mask=None, metric=MSE):
        """Return the mean block distortion of the current motion field.

        Args:
          mask: optional per-block mask indexed as mask[i, j]; blocks where
            it is truthy are skipped.
          metric: distortion metric forwarded to `block_dist`.

        Raises:
          ZeroDivisionError: if no block is evaluated (every block masked or
            an empty motion field).
        """
        loss = 0
        count = 0
        for i in range(self.num_row):
            for j in range(self.num_col):
                if mask is not None and mask[i, j]:
                    continue
                loss += self.block_dist(i, j, self.mf[i, j], metric)
                count += 1
        return loss / count

    def motion_field_evaluation(self, ground_truth):
        """Return the mean Euclidean error of `self.mf` against ground truth.

        Args:
          ground_truth: object exposing `.mf` (a motion field indexed like
            `self.mf`) and `.mask` (None, or a per-block mask indexed as
            mask[i][j]; truthy entries are skipped).

        Raises:
          ZeroDivisionError: if no block is evaluated.
        """
        loss = 0
        count = 0
        gt = ground_truth.mf
        mask = ground_truth.mask
        for i in range(self.num_row):
            for j in range(self.num_col):
                if mask is not None and mask[i][j]:
                    continue
                loss += LA.norm(gt[i, j] - self.mf[i, j])
                count += 1
        return loss / count

    def show(self, ground_truth=None, size=10):
        """Render the estimated motion field, optionally beside ground truth.

        Args:
          ground_truth: optional motion field, passed directly to drawMF.
            NOTE(review): motion_field_evaluation instead expects an object
            with `.mf`; confirm which form callers pass here.
          size: figure height in inches; the width scales with the number of
            panels and the frame aspect ratio.
        """
        cur_mf = drawMF(self.cur_f, self.blk_sz, self.mf)
        if ground_truth is None:
            n_col = 1
        else:
            gt_mf = drawMF(self.cur_f, self.blk_sz, ground_truth)
            n_col = 2
        # The panels are laid out side by side in a 1 x n_col grid.
        plt.figure(figsize=(n_col * size, size * self.height / self.width))
        plt.subplot(1, n_col, 1)
        plt.imshow(cur_mf)
        plt.title('Estimated Motion Field')
        if ground_truth is not None:
            plt.subplot(1, n_col, 2)
            plt.imshow(gt_mf)
            plt.title('Ground Truth')
        plt.tight_layout()
        plt.show()