## Copyright (c) 2020 The WebM project authors. All Rights Reserved.
##
## Use of this source code is governed by a BSD-style license
## that can be found in the LICENSE file in the root of the source
## tree. An additional intellectual property rights grant can be found
## in the file PATENTS.  All contributing project authors may
## be found in the AUTHORS file in the root of the source tree.
##

# coding: utf-8
"""Exhaustive-search block motion estimators.

Three estimators are defined, all derived from MotionEST:

  Exhaust                      plain exhaustive block matching
  ExhaustNeighbor              block matching plus a neighbor-consistency term
  ExhaustNeighborFeatureScore  block matching followed by feature-weighted
                               iterative smoothing of the motion field

The classes read self.blk_sz, self.width, self.height, self.num_row,
self.num_col, self.mf, self.cur_yuv and self.block_dist(), none of which is
defined in this file; they are presumably provided by MotionEST (confirm
against MotionEST.py).
"""
import numpy as np
import numpy.linalg as LA
from Util import MSE
from MotionEST import MotionEST

# (row, col) offsets of the four edge-adjacent and the four diagonal neighbors.
# Tuples (rather than sets) keep the order of floating-point accumulation
# explicit.
_EDGE_OFFSETS = ((-1, 0), (1, 0), (0, -1), (0, 1))
_CORNER_OFFSETS = ((-1, -1), (-1, 1), (1, -1), (1, 1))


def _window_candidates(est, cur_r, cur_c):
  """Yield every candidate position inside the search window of one block.

  Args:
    est: estimator providing blk_sz, wnd_sz, width and height.
    cur_r: block row index.
    cur_c: block column index.

  Yields:
    (x, y, mv) where (x, y) is the candidate top-left pixel position and
    mv = [dy, dx] is the displacement from the block's own position.
    Candidates are produced row by row (y outer, x inner).
  """
  cur_x = cur_c * est.blk_sz
  cur_y = cur_r * est.blk_sz
  limit_x = est.width - est.blk_sz
  limit_y = est.height - est.blk_sz
  # NOTE(review): the window is half-open, so displacements cover
  # [-wnd_sz, wnd_sz - 1], and the strict upper bounds skip the position
  # exactly at width/height - blk_sz. Both are kept as in the original code;
  # confirm whether they are intentional.
  for y in range(cur_y - est.wnd_sz, cur_y + est.wnd_sz):
    if not 0 <= y < limit_y:
      continue
    for x in range(cur_x - est.wnd_sz, cur_x + est.wnd_sz):
      if 0 <= x < limit_x:
        yield x, y, [y - cur_y, x - cur_x]


def _exhaustive_search(est, cur_r, cur_c, cost):
  """Return the (x, y) position in the search window with the lowest cost.

  Args:
    est: estimator providing blk_sz, wnd_sz, width and height.
    cur_r: block row index.
    cur_c: block column index.
    cost: callable mapping a motion vector [dy, dx] to a scalar loss.

  Returns:
    (ref_x, ref_y): the best position found. The zero-motion position is the
    starting point and is only replaced by a strictly smaller loss, so ties
    keep the earliest candidate.
  """
  ref_x = cur_c * est.blk_sz
  ref_y = cur_r * est.blk_sz
  min_loss = cost([0, 0])
  for x, y, mv in _window_candidates(est, cur_r, cur_c):
    loss = cost(mv)
    if loss < min_loss:
      min_loss = loss
      ref_x = x
      ref_y = y
  return ref_x, ref_y


class Exhaust(MotionEST):
  """Exhaustive search over a square window around each block."""

  def __init__(self, cur_f, ref_f, blk_size, wnd_size, metric=MSE):
    """Constructor.

    Args:
      cur_f: current frame
      ref_f: reference frame
      blk_size: block size
      wnd_size: search window size
      metric: metric to compare the blocks distortion
    """
    self.name = 'exhaust'
    self.wnd_sz = wnd_size
    self.metric = metric
    super(Exhaust, self).__init__(cur_f, ref_f, blk_size)

  def search(self, cur_r, cur_c):
    """Find the best matching position for one block.

    Args:
      cur_r: start row (block index)
      cur_c: start column (block index)

    Returns:
      (ref_x, ref_y): position with the minimum distortion among all valid
      positions in the window.
    """

    def cost(mv):
      return self.block_dist(cur_r, cur_c, mv, self.metric)

    return _exhaustive_search(self, cur_r, cur_c, cost)

  def motion_field_estimation(self):
    """Fill self.mf with the [dy, dx] motion vector of every block."""
    for i in range(self.num_row):
      for j in range(self.num_col):
        ref_x, ref_y = self.search(i, j)
        self.mf[i, j] = np.array(
            [ref_y - i * self.blk_sz, ref_x - j * self.blk_sz])


class ExhaustNeighbor(MotionEST):
  """Exhaustive search with a neighbor constraint."""

  def __init__(self, cur_f, ref_f, blk_size, wnd_size, beta, metric=MSE):
    """Constructor.

    Args:
      cur_f: current frame
      ref_f: reference frame
      blk_size: block size
      wnd_size: search window size
      beta: neighbor loss weight
      metric: metric to compare the blocks distortion
    """
    self.name = 'exhaust + neighbor'
    self.wnd_sz = wnd_size
    self.beta = beta
    self.metric = metric
    super(ExhaustNeighbor, self).__init__(cur_f, ref_f, blk_size)
    # assign[r, c] is True once block (r, c) has received a motion vector.
    self.assign = np.zeros((self.num_row, self.num_col), dtype=bool)

  def neighborLoss(self, cur_r, cur_c, mv):
    """Estimate the neighbor loss of a candidate motion vector.

    Args:
      cur_r: current row
      cur_c: current column
      mv: current motion vector ([dy, dx]; list or array)

    Returns:
      Sum of the Euclidean distances between mv and the motion vectors of the
      edge-adjacent blocks that are inside the grid and already assigned.
    """
    mv = np.asarray(mv)
    loss = 0
    for d_r, d_c in _EDGE_OFFSETS:
      nb_r = cur_r + d_r
      nb_c = cur_c + d_c
      in_grid = 0 <= nb_r < self.num_row and 0 <= nb_c < self.num_col
      if in_grid and self.assign[nb_r, nb_c]:
        loss += LA.norm(mv - self.mf[nb_r, nb_c])
    return loss

  def search(self, cur_r, cur_c):
    """Find the best position for one block.

    Each candidate is scored by its distortion plus beta times its neighbor
    loss.

    Args:
      cur_r: start row (block index)
      cur_c: start column (block index)

    Returns:
      (ref_x, ref_y): position with the minimum combined loss among all valid
      positions in the window.
    """

    def cost(mv):
      dist_loss = self.block_dist(cur_r, cur_c, mv, self.metric)
      nb_loss = self.neighborLoss(cur_r, cur_c, mv)
      return dist_loss + self.beta * nb_loss

    return _exhaustive_search(self, cur_r, cur_c, cost)

  def motion_field_estimation(self):
    """Fill self.mf block by block, marking each block as assigned."""
    for i in range(self.num_row):
      for j in range(self.num_col):
        ref_x, ref_y = self.search(i, j)
        self.mf[i, j] = np.array(
            [ref_y - i * self.blk_sz, ref_x - j * self.blk_sz])
        # Later blocks may now use this vector in their neighbor loss.
        self.assign[i, j] = True


class ExhaustNeighborFeatureScore(MotionEST):
  """Exhaustive search with neighbor constraint and feature score."""

  def __init__(self,
               cur_f,
               ref_f,
               blk_size,
               wnd_size,
               beta=1,
               max_iter=100,
               metric=MSE):
    """Constructor.

    Args:
      cur_f: current frame
      ref_f: reference frame
      blk_size: block size
      wnd_size: search window size
      beta: neighbor loss weight
      max_iter: maximum number of iterations
      metric: metric to compare the blocks distortion
    """
    self.name = 'exhaust + neighbor+feature score'
    self.wnd_sz = wnd_size
    self.beta = beta
    self.metric = metric
    self.max_iter = max_iter
    super(ExhaustNeighborFeatureScore, self).__init__(cur_f, ref_f, blk_size)
    self.fs = self.getFeatureScore()

  def getFeatureScore(self):
    """Get the feature score of each block.

    For every block the forward differences of channel 0 of self.cur_yuv are
    accumulated into the sums IxIx, IyIy and IxIy; the score is
    lambda_max * lambda_min / (1e-6 + lambda_max + lambda_min), where the
    lambdas are the eigenvalues of [[IxIx, IxIy], [IxIy, IyIy]], clamped to be
    non-negative.

    Returns:
      Array of shape (num_row, num_col) holding the per-block scores.
    """
    # Cast to float so that differences of unsigned pixel values cannot wrap
    # around. Channel 0 is presumably luma -- confirm against MotionEST.
    plane = np.asarray(self.cur_yuv[:, :, 0], dtype=float)
    fs = np.zeros((self.num_row, self.num_col))
    for r in range(self.num_row):
      for c in range(self.num_col):
        top = r * self.blk_sz
        left = c * self.blk_sz
        blk = plane[top:top + self.blk_sz, left:left + self.blk_sz]
        # Forward differences over the (blk_sz - 1) x (blk_sz - 1) interior.
        base = blk[:-1, :-1]
        Ix = blk[:-1, 1:] - base
        Iy = blk[1:, :-1] - base
        IxIx = np.sum(Ix * Ix)
        IyIy = np.sum(Iy * Iy)
        IxIy = np.sum(Ix * Iy)
        #get maximum and minimum eigenvalues
        root = np.sqrt(4 * IxIy * IxIy + (IxIx - IyIy)**2)
        lambda_max = 0.5 * ((IxIx + IyIy) + root)
        lambda_min = 0.5 * ((IxIx + IyIy) - root)
        fs[r, c] = lambda_max * lambda_min / (1e-6 + lambda_max + lambda_min)
        if fs[r, c] < 0:
          fs[r, c] = 0
    return fs

  def search(self, cur_r, cur_c):
    """Do exhaustive search for one block.

    Args:
      cur_r: start row (block index)
      cur_c: start column (block index)

    Returns:
      (ref_x, ref_y): position with the minimum distortion among all valid
      positions in the window.
    """

    def cost(mv):
      return self.block_dist(cur_r, cur_c, mv, self.metric)

    return _exhaustive_search(self, cur_r, cur_c, cost)

  def smooth(self, uvs, mvs):
    """Apply one iteration of the smoothness constraint.

    Args:
      uvs: current smoothed motion field
      mvs: motion field obtained from block matching

    Returns:
      New motion field in which each vector is
      (fs * mv + beta * avg_uv) / (beta + fs), where avg_uv sums the in-grid
      edge neighbors of uvs weighted 1/6 and the in-grid diagonal neighbors
      weighted 1/12.
    """
    sm_uvs = np.zeros(uvs.shape)
    for r in range(self.num_row):
      for c in range(self.num_col):
        avg_uv = np.array([0.0, 0.0])
        for offsets, weight in ((_EDGE_OFFSETS, 6.0), (_CORNER_OFFSETS, 12.0)):
          for d_r, d_c in offsets:
            i = r + d_r
            j = c + d_c
            if 0 <= i < self.num_row and 0 <= j < self.num_col:
              avg_uv += uvs[i, j] / weight
        sm_uvs[r, c] = (self.fs[r, c] * mvs[r, c] + self.beta * avg_uv) / (
            self.beta + self.fs[r, c])
    return sm_uvs

  def motion_field_estimation(self):
    """Estimate the motion field and store it in self.mf.

    Block matching gives an initial field; a field starting at zero is then
    smoothed max_iter times against it.
    """
    #get matching results
    mvs = np.zeros(self.mf.shape)
    for r in range(self.num_row):
      for c in range(self.num_col):
        ref_x, ref_y = self.search(r, c)
        mvs[r, c] = np.array([ref_y - r * self.blk_sz, ref_x - c * self.blk_sz])
    #add smoothness constraint
    uvs = np.zeros(self.mf.shape)
    for _ in range(self.max_iter):
      uvs = self.smooth(uvs, mvs)
    self.mf = uvs