xref: /aosp_15_r20/external/libaom/av1/encoder/x86/highbd_temporal_filter_sse2.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <assert.h>
13 #include <emmintrin.h>
14 
15 #include "config/av1_rtcd.h"
16 #include "aom_dsp/mathutils.h"
17 #include "av1/encoder/encoder.h"
18 #include "av1/encoder/temporal_filter.h"
19 
20 // For the squared error buffer, keep a padding for 4 samples
21 #define SSE_STRIDE (BW + 4)
22 
23 DECLARE_ALIGNED(32, static const uint32_t, sse_bytemask_2x4[4][2][4]) = {
24   { { 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF },
25     { 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000 } },
26   { { 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF },
27     { 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000 } },
28   { { 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF },
29     { 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000 } },
30   { { 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF },
31     { 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF } }
32 };
33 
// Computes the per-pixel squared difference between two high-bitdepth blocks
// and writes it into frame_sse as 32-bit values.
//
// Row r of the result is written at frame_sse + r * dst_stride, shifted right
// by two elements: the square for column j lands at index j + 2. Slots 0-1
// and block_width + 2 .. block_width + 3 of each row are NOT written, so the
// caller can fill them as left/right padding for a 5-tap window.
//
// Pixels are processed 8 at a time, so block_width is assumed to be a multiple
// of 8. The difference is taken in 16-bit lanes and squared with a signed
// 16x16->32 multiply; this is exact whenever the true difference fits in
// int16_t (e.g. sample values below 2^15).
static void get_squared_error(const uint16_t *frame1, const unsigned int stride,
                              const uint16_t *frame2,
                              const unsigned int stride2, const int block_width,
                              const int block_height, uint32_t *frame_sse,
                              const unsigned int dst_stride) {
  const uint16_t *src1 = frame1;
  const uint16_t *src2 = frame2;
  uint32_t *dst = frame_sse;

  for (int i = 0; i < block_height; i++) {
    for (int j = 0; j < block_width; j += 8) {
      // The inputs are only read, so load through const-qualified pointers
      // instead of casting the constness away.
      const __m128i vsrc1 = _mm_loadu_si128((const __m128i *)(src1 + j));
      const __m128i vsrc2 = _mm_loadu_si128((const __m128i *)(src2 + j));

      const __m128i vdiff = _mm_sub_epi16(vsrc1, vsrc2);
      // Low and high 16-bit halves of each 32-bit square.
      const __m128i vmullo = _mm_mullo_epi16(vdiff, vdiff);
      const __m128i vmullh = _mm_mulhi_epi16(vdiff, vdiff);

      // Interleave the halves to rebuild the full 32-bit squares.
      const __m128i vres1 = _mm_unpacklo_epi16(vmullo, vmullh);
      const __m128i vres2 = _mm_unpackhi_epi16(vmullo, vmullh);

      // +2 leaves room for the two left padding slots.
      _mm_storeu_si128((__m128i *)(dst + j + 2), vres1);
      _mm_storeu_si128((__m128i *)(dst + j + 6), vres2);
    }

    src1 += stride;
    src2 += stride2;
    dst += dst_stride;
  }
}
64 
// Loads 8 consecutive 32-bit squared errors from src into two vectors:
// dstvec[0] = src[0..3] and dstvec[1] = src[4..7].
//
// src points at the padded position of column col, i.e. two slots before
// sample col (the layout produced by get_squared_error). At the block edges
// the padding slots are replaced by replicated samples:
//  - col == 0: dstvec[0] becomes [s2, s2, s2, s3] (leftmost sample repeated).
//  - col >= block_width - 4: dstvec[1] becomes [s4, s5, s5, s5] (rightmost
//    sample repeated).
// src is only read, hence the const-qualified pointer.
static void xx_load_and_pad(const uint32_t *src, __m128i *dstvec, int col,
                            int block_width) {
  const __m128i vtmp1 = _mm_loadu_si128((const __m128i *)src);
  const __m128i vtmp2 = _mm_loadu_si128((const __m128i *)(src + 4));
  // For the first column, replicate the first element twice to the left
  // (shuffle 0xEA selects lanes 2, 2, 2, 3).
  dstvec[0] = (col) ? vtmp1 : _mm_shuffle_epi32(vtmp1, 0xEA);
  // For the last column, replicate the last element twice to the right
  // (shuffle 0x54 selects lanes 0, 1, 1, 1).
  dstvec[1] = (col < block_width - 4) ? vtmp2 : _mm_shuffle_epi32(vtmp2, 0x54);
}
74 
xx_mask_and_hadd(__m128i vsum1,__m128i vsum2,int i)75 static int32_t xx_mask_and_hadd(__m128i vsum1, __m128i vsum2, int i) {
76   __m128i veca, vecb;
77   // Mask and obtain the required 5 values inside the vector
78   veca = _mm_and_si128(vsum1, *(__m128i *)sse_bytemask_2x4[i][0]);
79   vecb = _mm_and_si128(vsum2, *(__m128i *)sse_bytemask_2x4[i][1]);
80   // A = [A0+B0, A1+B1, A2+B2, A3+B3]
81   veca = _mm_add_epi32(veca, vecb);
82   // B = [A2+B2, A3+B3, 0, 0]
83   vecb = _mm_srli_si128(veca, 8);
84   // A = [A0+B0+A2+B2, A1+B1+A3+B3, X, X]
85   veca = _mm_add_epi32(veca, vecb);
86   // B = [A1+B1+A3+B3, 0, 0, 0]
87   vecb = _mm_srli_si128(veca, 4);
88   // A = [A0+B0+A2+B2+A1+B1+A3+B3, X, X, X]
89   veca = _mm_add_epi32(veca, vecb);
90   return _mm_cvtsi128_si32(veca);
91 }
92 
highbd_apply_temporal_filter(const uint16_t * frame1,const unsigned int stride,const uint16_t * frame2,const unsigned int stride2,const int block_width,const int block_height,const int * subblock_mses,unsigned int * accumulator,uint16_t * count,uint32_t * frame_sse,uint32_t * luma_sse_sum,int bd,const double inv_num_ref_pixels,const double decay_factor,const double inv_factor,const double weight_factor,double * d_factor,int tf_wgt_calc_lvl)93 static void highbd_apply_temporal_filter(
94     const uint16_t *frame1, const unsigned int stride, const uint16_t *frame2,
95     const unsigned int stride2, const int block_width, const int block_height,
96     const int *subblock_mses, unsigned int *accumulator, uint16_t *count,
97     uint32_t *frame_sse, uint32_t *luma_sse_sum, int bd,
98     const double inv_num_ref_pixels, const double decay_factor,
99     const double inv_factor, const double weight_factor, double *d_factor,
100     int tf_wgt_calc_lvl) {
101   assert(((block_width == 16) || (block_width == 32)) &&
102          ((block_height == 16) || (block_height == 32)));
103 
104   uint32_t acc_5x5_sse[BH][BW];
105 
106   get_squared_error(frame1, stride, frame2, stride2, block_width, block_height,
107                     frame_sse, SSE_STRIDE);
108 
109   __m128i vsrc[5][2];
110 
111   // Traverse 4 columns at a time
112   // First and last columns will require padding
113   for (int col = 0; col < block_width; col += 4) {
114     uint32_t *src = frame_sse + col;
115 
116     // Load and pad(for first and last col) 3 rows from the top
117     for (int i = 2; i < 5; i++) {
118       xx_load_and_pad(src, vsrc[i], col, block_width);
119       src += SSE_STRIDE;
120     }
121 
122     // Padding for top 2 rows
123     vsrc[0][0] = vsrc[2][0];
124     vsrc[0][1] = vsrc[2][1];
125     vsrc[1][0] = vsrc[2][0];
126     vsrc[1][1] = vsrc[2][1];
127 
128     for (int row = 0; row < block_height - 3; row++) {
129       __m128i vsum11 = _mm_add_epi32(vsrc[0][0], vsrc[1][0]);
130       __m128i vsum12 = _mm_add_epi32(vsrc[2][0], vsrc[3][0]);
131       __m128i vsum13 = _mm_add_epi32(vsum11, vsum12);
132       __m128i vsum1 = _mm_add_epi32(vsum13, vsrc[4][0]);
133 
134       __m128i vsum21 = _mm_add_epi32(vsrc[0][1], vsrc[1][1]);
135       __m128i vsum22 = _mm_add_epi32(vsrc[2][1], vsrc[3][1]);
136       __m128i vsum23 = _mm_add_epi32(vsum21, vsum22);
137       __m128i vsum2 = _mm_add_epi32(vsum23, vsrc[4][1]);
138 
139       vsrc[0][0] = vsrc[1][0];
140       vsrc[0][1] = vsrc[1][1];
141       vsrc[1][0] = vsrc[2][0];
142       vsrc[1][1] = vsrc[2][1];
143       vsrc[2][0] = vsrc[3][0];
144       vsrc[2][1] = vsrc[3][1];
145       vsrc[3][0] = vsrc[4][0];
146       vsrc[3][1] = vsrc[4][1];
147 
148       // Load next row
149       xx_load_and_pad(src, vsrc[4], col, block_width);
150       src += SSE_STRIDE;
151 
152       acc_5x5_sse[row][col] = xx_mask_and_hadd(vsum1, vsum2, 0);
153       acc_5x5_sse[row][col + 1] = xx_mask_and_hadd(vsum1, vsum2, 1);
154       acc_5x5_sse[row][col + 2] = xx_mask_and_hadd(vsum1, vsum2, 2);
155       acc_5x5_sse[row][col + 3] = xx_mask_and_hadd(vsum1, vsum2, 3);
156     }
157     for (int row = block_height - 3; row < block_height; row++) {
158       __m128i vsum11 = _mm_add_epi32(vsrc[0][0], vsrc[1][0]);
159       __m128i vsum12 = _mm_add_epi32(vsrc[2][0], vsrc[3][0]);
160       __m128i vsum13 = _mm_add_epi32(vsum11, vsum12);
161       __m128i vsum1 = _mm_add_epi32(vsum13, vsrc[4][0]);
162 
163       __m128i vsum21 = _mm_add_epi32(vsrc[0][1], vsrc[1][1]);
164       __m128i vsum22 = _mm_add_epi32(vsrc[2][1], vsrc[3][1]);
165       __m128i vsum23 = _mm_add_epi32(vsum21, vsum22);
166       __m128i vsum2 = _mm_add_epi32(vsum23, vsrc[4][1]);
167 
168       vsrc[0][0] = vsrc[1][0];
169       vsrc[0][1] = vsrc[1][1];
170       vsrc[1][0] = vsrc[2][0];
171       vsrc[1][1] = vsrc[2][1];
172       vsrc[2][0] = vsrc[3][0];
173       vsrc[2][1] = vsrc[3][1];
174       vsrc[3][0] = vsrc[4][0];
175       vsrc[3][1] = vsrc[4][1];
176 
177       acc_5x5_sse[row][col] = xx_mask_and_hadd(vsum1, vsum2, 0);
178       acc_5x5_sse[row][col + 1] = xx_mask_and_hadd(vsum1, vsum2, 1);
179       acc_5x5_sse[row][col + 2] = xx_mask_and_hadd(vsum1, vsum2, 2);
180       acc_5x5_sse[row][col + 3] = xx_mask_and_hadd(vsum1, vsum2, 3);
181     }
182   }
183 
184   double subblock_mses_scaled[4];
185   double d_factor_decayed[4];
186   for (int idx = 0; idx < 4; idx++) {
187     subblock_mses_scaled[idx] = subblock_mses[idx] * inv_factor;
188     d_factor_decayed[idx] = d_factor[idx] * decay_factor;
189   }
190   if (tf_wgt_calc_lvl == 0) {
191     for (int i = 0, k = 0; i < block_height; i++) {
192       const int y_blk_raster_offset = (i >= block_height / 2) * 2;
193       for (int j = 0; j < block_width; j++, k++) {
194         const int pixel_value = frame2[i * stride2 + j];
195         uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
196 
197         // Scale down the difference for high bit depth input.
198         diff_sse >>= ((bd - 8) * 2);
199 
200         const double window_error = diff_sse * inv_num_ref_pixels;
201         const int subblock_idx = y_blk_raster_offset + (j >= block_width / 2);
202 
203         const double combined_error =
204             weight_factor * window_error + subblock_mses_scaled[subblock_idx];
205 
206         double scaled_error = combined_error * d_factor_decayed[subblock_idx];
207         scaled_error = AOMMIN(scaled_error, 7);
208         const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
209 
210         count[k] += weight;
211         accumulator[k] += weight * pixel_value;
212       }
213     }
214   } else {
215     for (int i = 0, k = 0; i < block_height; i++) {
216       const int y_blk_raster_offset = (i >= block_height / 2) * 2;
217       for (int j = 0; j < block_width; j++, k++) {
218         const int pixel_value = frame2[i * stride2 + j];
219         uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
220 
221         // Scale down the difference for high bit depth input.
222         diff_sse >>= ((bd - 8) * 2);
223 
224         const double window_error = diff_sse * inv_num_ref_pixels;
225         const int subblock_idx = y_blk_raster_offset + (j >= block_width / 2);
226 
227         const double combined_error =
228             weight_factor * window_error + subblock_mses_scaled[subblock_idx];
229 
230         double scaled_error = combined_error * d_factor_decayed[subblock_idx];
231         scaled_error = AOMMIN(scaled_error, 7);
232         const float fweight =
233             approx_exp((float)-scaled_error) * TF_WEIGHT_SCALE;
234         const int weight = iroundpf(fweight);
235 
236         count[k] += weight;
237         accumulator[k] += weight * pixel_value;
238       }
239     }
240   }
241 }
242 
av1_highbd_apply_temporal_filter_sse2(const YV12_BUFFER_CONFIG * frame_to_filter,const MACROBLOCKD * mbd,const BLOCK_SIZE block_size,const int mb_row,const int mb_col,const int num_planes,const double * noise_levels,const MV * subblock_mvs,const int * subblock_mses,const int q_factor,const int filter_strength,int tf_wgt_calc_lvl,const uint8_t * pred,uint32_t * accum,uint16_t * count)243 void av1_highbd_apply_temporal_filter_sse2(
244     const YV12_BUFFER_CONFIG *frame_to_filter, const MACROBLOCKD *mbd,
245     const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
246     const int num_planes, const double *noise_levels, const MV *subblock_mvs,
247     const int *subblock_mses, const int q_factor, const int filter_strength,
248     int tf_wgt_calc_lvl, const uint8_t *pred, uint32_t *accum,
249     uint16_t *count) {
250   const int is_high_bitdepth = frame_to_filter->flags & YV12_FLAG_HIGHBITDEPTH;
251   assert(block_size == BLOCK_32X32 && "Only support 32x32 block with sse2!");
252   assert(TF_WINDOW_LENGTH == 5 && "Only support window length 5 with sse2!");
253   assert(num_planes >= 1 && num_planes <= MAX_MB_PLANE);
254   (void)is_high_bitdepth;
255 
256   const int mb_height = block_size_high[block_size];
257   const int mb_width = block_size_wide[block_size];
258   const int frame_height = frame_to_filter->y_crop_height;
259   const int frame_width = frame_to_filter->y_crop_width;
260   const int min_frame_size = AOMMIN(frame_height, frame_width);
261   // Variables to simplify combined error calculation.
262   const double inv_factor = 1.0 / ((TF_WINDOW_BLOCK_BALANCE_WEIGHT + 1) *
263                                    TF_SEARCH_ERROR_NORM_WEIGHT);
264   const double weight_factor =
265       (double)TF_WINDOW_BLOCK_BALANCE_WEIGHT * inv_factor;
266   // Adjust filtering based on q.
267   // Larger q -> stronger filtering -> larger weight.
268   // Smaller q -> weaker filtering -> smaller weight.
269   double q_decay = pow((double)q_factor / TF_Q_DECAY_THRESHOLD, 2);
270   q_decay = CLIP(q_decay, 1e-5, 1);
271   if (q_factor >= TF_QINDEX_CUTOFF) {
272     // Max q_factor is 255, therefore the upper bound of q_decay is 8.
273     // We do not need a clip here.
274     q_decay = 0.5 * pow((double)q_factor / 64, 2);
275   }
276   // Smaller strength -> smaller filtering weight.
277   double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2);
278   s_decay = CLIP(s_decay, 1e-5, 1);
279   double d_factor[4] = { 0 };
280   uint32_t frame_sse[SSE_STRIDE * BH] = { 0 };
281   uint32_t luma_sse_sum[BW * BH] = { 0 };
282   uint16_t *pred1 = CONVERT_TO_SHORTPTR(pred);
283 
284   for (int subblock_idx = 0; subblock_idx < 4; subblock_idx++) {
285     // Larger motion vector -> smaller filtering weight.
286     const MV mv = subblock_mvs[subblock_idx];
287     const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
288     double distance_threshold = min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD;
289     distance_threshold = AOMMAX(distance_threshold, 1);
290     d_factor[subblock_idx] = distance / distance_threshold;
291     d_factor[subblock_idx] = AOMMAX(d_factor[subblock_idx], 1);
292   }
293 
294   // Handle planes in sequence.
295   int plane_offset = 0;
296   for (int plane = 0; plane < num_planes; ++plane) {
297     const uint32_t plane_h = mb_height >> mbd->plane[plane].subsampling_y;
298     const uint32_t plane_w = mb_width >> mbd->plane[plane].subsampling_x;
299     const uint32_t frame_stride = frame_to_filter->strides[plane == 0 ? 0 : 1];
300     const int frame_offset = mb_row * plane_h * frame_stride + mb_col * plane_w;
301 
302     const uint16_t *ref =
303         CONVERT_TO_SHORTPTR(frame_to_filter->buffers[plane]) + frame_offset;
304     const int ss_x_shift =
305         mbd->plane[plane].subsampling_x - mbd->plane[0].subsampling_x;
306     const int ss_y_shift =
307         mbd->plane[plane].subsampling_y - mbd->plane[0].subsampling_y;
308     const int num_ref_pixels = TF_WINDOW_LENGTH * TF_WINDOW_LENGTH +
309                                ((plane) ? (1 << (ss_x_shift + ss_y_shift)) : 0);
310     const double inv_num_ref_pixels = 1.0 / num_ref_pixels;
311     // Larger noise -> larger filtering weight.
312     const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0);
313     // Decay factors for non-local mean approach.
314     const double decay_factor = 1 / (n_decay * q_decay * s_decay);
315 
316     // Filter U-plane and V-plane using Y-plane. This is because motion
317     // search is only done on Y-plane, so the information from Y-plane
318     // will be more accurate. The luma sse sum is reused in both chroma
319     // planes.
320     if (plane == AOM_PLANE_U) {
321       for (unsigned int i = 0, k = 0; i < plane_h; i++) {
322         for (unsigned int j = 0; j < plane_w; j++, k++) {
323           for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
324             for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
325               const int yy = (i << ss_y_shift) + ii;  // Y-coord on Y-plane.
326               const int xx = (j << ss_x_shift) + jj;  // X-coord on Y-plane.
327               luma_sse_sum[i * BW + j] += frame_sse[yy * SSE_STRIDE + xx + 2];
328             }
329           }
330         }
331       }
332     }
333 
334     highbd_apply_temporal_filter(
335         ref, frame_stride, pred1 + plane_offset, plane_w, plane_w, plane_h,
336         subblock_mses, accum + plane_offset, count + plane_offset, frame_sse,
337         luma_sse_sum, mbd->bd, inv_num_ref_pixels, decay_factor, inv_factor,
338         weight_factor, d_factor, tf_wgt_calc_lvl);
339     plane_offset += plane_h * plane_w;
340   }
341 }
342