# coding: utf-8
## Copyright (c) 2020 The WebM project authors. All Rights Reserved.
##
## Use of this source code is governed by a BSD-style license
## that can be found in the LICENSE file in the root of the source
## tree. An additional intellectual property rights grant can be found
## in the file PATENTS. All contributing project authors may
## be found in the AUTHORS file in the root of the source tree.
##
"""Search & Smooth motion estimators with adaptive or fixed smoothness weights.

Both estimators take the block motion field produced by a separate ``search``
estimator and iteratively blend each block's motion with a weighted average of
its neighbours, using a 2x2 local-structure matrix computed from the reference
frame around the searched match.
"""

import numpy as np
import numpy.linalg as LA
from Util import MSE
from MotionEST import MotionEST

# Offsets of the 4-connected neighbours; each contributes 1/6 of the average.
_EDGE_OFFSETS = ((-1, 0), (1, 0), (0, -1), (0, 1))
# Offsets of the diagonal neighbours; each contributes 1/12 of the average.
_CORNER_OFFSETS = ((-1, -1), (-1, 1), (1, -1), (1, 1))


def _neighbor_sad_delta(ref_yuv, center, target_sad, positions, blk_sz, max_y,
                        max_x):
    """Return the floored mean per-pixel SAD increase of neighbouring blocks.

    For every top-left ``(y, x)`` in ``positions`` whose block fits inside the
    reference frame, the SAD between that reference block (channel 0) and
    ``center`` minus ``target_sad`` is accumulated. The sum is floor-divided by
    the number of pixels examined (``count * blk_sz * blk_sz``).

    Returns 0 when no neighbour fits inside the frame, instead of dividing by
    zero.
    """
    total = 0
    count = 0
    for y, x in positions:
        # A block at (y, x) fits iff y + blk_sz <= height, i.e. y <= max_y;
        # max_y/max_x are the same inclusive bounds np.clip uses for the target.
        if 0 <= y <= max_y and 0 <= x <= max_x:
            nb = ref_yuv[y:y + blk_sz, x:x + blk_sz, 0]
            total += np.sum(np.abs(nb - center)) - target_sad
            count += 1
    if count == 0:
        # Frame too small in this direction to have any neighbour.
        return 0
    return total // (count * blk_sz * blk_sz)


def _ref_local_diff(est, mvs):
    """Build the per-block 2x2 local-structure matrices of the reference frame.

    For block (r, c), the matched reference block sits at the block's own
    pixel position offset by ``mvs[r, c, 0]`` rows and ``mvs[r, c, 1]`` columns.
    The offsets are truncated to int and the position is clamped to the frame.
    ``I_row`` / ``I_col`` measure how much worse the vertically / horizontally
    adjacent reference blocks match the current block than the matched block
    does. The stored entry is the outer product
    ``[[I_row*I_row, I_row*I_col], [I_col*I_row, I_col*I_col]]``.

    Args:
        est: estimator exposing ``blk_sz``, ``num_row``, ``num_col``,
            ``height``, ``width``, ``cur_yuv`` and ``ref_yuv``.
        mvs: motion field indexed as ``mvs[r, c, k]``, in pixels.

    Returns:
        A list of ``num_row`` lists, each holding ``num_col`` 2x2 arrays.
    """
    blk_sz = est.blk_sz
    max_y = est.height - blk_sz
    max_x = est.width - blk_sz
    local_diff = []
    for r in range(est.num_row):
        row_diffs = []
        for c in range(est.num_col):
            # NOTE(review): if cur_yuv/ref_yuv are unsigned integer arrays the
            # differences below wrap around; confirm they are signed or float.
            center = est.cur_yuv[r * blk_sz:(r + 1) * blk_sz,
                                 c * blk_sz:(c + 1) * blk_sz, 0]
            ty = np.clip(r * blk_sz + int(mvs[r, c, 0]), 0, max_y)
            tx = np.clip(c * blk_sz + int(mvs[r, c, 1]), 0, max_x)
            target = est.ref_yuv[ty:ty + blk_sz, tx:tx + blk_sz, 0]
            target_sad = np.sum(np.abs(target - center))
            i_row = _neighbor_sad_delta(est.ref_yuv, center, target_sad,
                                        ((ty - blk_sz, tx), (ty + blk_sz, tx)),
                                        blk_sz, max_y, max_x)
            i_col = _neighbor_sad_delta(est.ref_yuv, center, target_sad,
                                        ((ty, tx - blk_sz), (ty, tx + blk_sz)),
                                        blk_sz, max_y, max_x)
            row_diffs.append(
                np.array([[i_row * i_row, i_row * i_col],
                          [i_col * i_row, i_col * i_col]]))
        local_diff.append(row_diffs)
    return local_diff


def _neighbor_average(uvs, r, c, num_row, num_col):
    """Return the weighted mean of the 8 neighbours of block (r, c) in ``uvs``.

    4-connected neighbours weigh 1/6 and diagonal neighbours 1/12, so the
    weights sum to 1. Neighbours outside the ``num_row`` x ``num_col`` grid are
    replaced by ``uvs[r, c]`` itself.
    """
    avg = np.array([0.0, 0.0])
    for offsets, divisor in ((_EDGE_OFFSETS, 6.0), (_CORNER_OFFSETS, 12.0)):
        for dr, dc in offsets:
            i, j = r + dr, c + dc
            if 0 <= i < num_row and 0 <= j < num_col:
                avg += uvs[i, j] / divisor
            else:
                avg += uvs[r, c] / divisor
    return avg


class SearchSmoothAdapt(MotionEST):
    """Search & Smooth model with adaptive weights.

    Constructor:
        cur_f: current frame (forwarded to MotionEST)
        ref_f: reference frame (forwarded to MotionEST)
        blk_size: block size (forwarded to MotionEST)
        search: estimator whose ``motion_field_estimation()`` fills its ``mf``
            attribute with the initial per-block motion field (pixels)
        max_iter: number of smoothing iterations
    """

    def __init__(self, cur_f, ref_f, blk_size, search, max_iter=100):
        self.search = search
        self.max_iter = max_iter
        super(SearchSmoothAdapt, self).__init__(cur_f, ref_f, blk_size)

    def getRefLocalDiff(self, mvs):
        """Return per-block 2x2 local-structure matrices of the reference.

        ``mvs`` is a pixel motion field indexed ``mvs[r, c, k]``; the result is
        a ``num_row`` x ``num_col`` nested list of 2x2 arrays.
        """
        return _ref_local_diff(self, mvs)

    def smooth(self, uvs, mvs):
        """Run one smoothing pass and return the new block-unit motion field.

        Args:
            uvs: current motion field in block units (pixels / blk_sz).
            mvs: searched motion field in pixels (the data term).

        For each block, with ``D`` its local-structure matrix and ``nb`` the
        neighbour average of ``uvs``, this solves
        ``(alpha*D + I) uv = nb + alpha*D (mv / blk_sz)``. Here
        ``alpha = (dist(nb) - dist(mv)) / (dist(mv) + 1e-6)`` compares the
        inherited ``block_dist`` of the neighbour-average motion against that
        of the searched motion.

        Raises:
            numpy.linalg.LinAlgError: if the 2x2 system is singular, which can
                only happen when alpha is negative.
        """
        blk_sz = self.blk_sz
        sm_uvs = np.zeros(uvs.shape)
        for r in range(self.num_row):
            for c in range(self.num_col):
                nb_uv = _neighbor_average(uvs, r, c, self.num_row, self.num_col)
                ssd_nb = self.block_dist(r, c, blk_sz * nb_uv)
                mv = mvs[r, c]
                ssd_mv = self.block_dist(r, c, mv)
                # Trust the searched motion more when the smoothed motion
                # matches noticeably worse than it.
                alpha = (ssd_nb - ssd_mv) / (ssd_mv + 1e-6)
                M = alpha * self.localDiff[r][c]
                P = M + np.identity(2)
                # Equivalent to inv(P).nb + inv(P).M.(mv/blk_sz), without
                # forming the explicit inverse.
                sm_uvs[r, c] = LA.solve(P, nb_uv + np.dot(M, mv / blk_sz))
        return sm_uvs

    def block_matching(self):
        """Run the wrapped search estimator to populate ``self.search.mf``."""
        self.search.motion_field_estimation()

    def motion_field_estimation(self):
        """Smooth ``self.search.mf`` and store the result, in pixels, in ``self.mf``.

        Reads ``self.search.mf``; presumably ``block_matching()`` must have run
        first so that it is populated.
        """
        # Local structure of the reference around the searched matches.
        self.localDiff = self.getRefLocalDiff(self.search.mf)
        mvs = self.search.mf
        # Iterate in block units and convert back to pixels at the end.
        uvs = mvs / self.blk_sz
        for _ in range(self.max_iter):
            uvs = self.smooth(uvs, mvs)
        self.mf = uvs * self.blk_sz


class SearchSmoothFix(MotionEST):
    """Search & Smooth model with fixed weights.

    Constructor:
        cur_f: current frame (forwarded to MotionEST)
        ref_f: reference frame (forwarded to MotionEST)
        blk_size: block size (forwarded to MotionEST)
        search: estimator whose ``motion_field_estimation()`` fills its ``mf``
            attribute with the initial per-block motion field (pixels)
        beta: weight of the neighbour-average (smoothness) term
        max_iter: number of smoothing iterations
    """

    def __init__(self, cur_f, ref_f, blk_size, search, beta, max_iter=100):
        self.search = search
        self.max_iter = max_iter
        self.beta = beta
        super(SearchSmoothFix, self).__init__(cur_f, ref_f, blk_size)

    def getRefLocalDiff(self, mvs):
        """Return per-block 2x2 local-structure matrices of the reference.

        ``mvs`` is a pixel motion field indexed ``mvs[r, c, k]``; the result is
        a ``num_row`` x ``num_col`` nested list of 2x2 arrays.
        """
        return _ref_local_diff(self, mvs)

    def smooth(self, uvs, mvs):
        """Run one smoothing pass and return the new block-unit motion field.

        Args:
            uvs: current motion field in block units (pixels / blk_sz).
            mvs: searched motion field in pixels (the data term).

        For each block, with ``D`` its local-structure matrix and ``nb`` the
        neighbour average of ``uvs``, this solves
        ``(D + beta*I) uv = beta*nb + D (mv / blk_sz)``.

        Raises:
            numpy.linalg.LinAlgError: if the 2x2 system is singular, e.g. when
                ``beta == 0``, because ``D`` is an outer product of rank <= 1.
        """
        blk_sz = self.blk_sz
        sm_uvs = np.zeros(uvs.shape)
        for r in range(self.num_row):
            for c in range(self.num_col):
                nb_uv = _neighbor_average(uvs, r, c, self.num_row, self.num_col)
                mv = mvs[r, c] / blk_sz
                M = self.localDiff[r][c]
                P = M + self.beta * np.identity(2)
                # Equivalent to inv(P).(beta*nb) + inv(P).M.mv, without forming
                # the explicit inverse.
                sm_uvs[r, c] = LA.solve(P, self.beta * nb_uv + np.dot(M, mv))
        return sm_uvs

    def block_matching(self):
        """Run the wrapped search estimator to populate ``self.search.mf``."""
        self.search.motion_field_estimation()

    def motion_field_estimation(self):
        """Smooth ``self.search.mf`` and store the result, in pixels, in ``self.mf``.

        Reads ``self.search.mf``; presumably ``block_matching()`` must have run
        first so that it is populated.
        """
        # Local structure of the reference around the searched matches.
        self.localDiff = self.getRefLocalDiff(self.search.mf)
        mvs = self.search.mf
        # Iterate in block units and convert back to pixels at the end.
        uvs = mvs / self.blk_sz
        for _ in range(self.max_iter):
            uvs = self.smooth(uvs, mvs)
        self.mf = uvs * self.blk_sz