xref: /aosp_15_r20/external/libaom/av1/encoder/tx_search.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1*77c1e3ccSAndroid Build Coastguard Worker /*
2*77c1e3ccSAndroid Build Coastguard Worker  * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
3*77c1e3ccSAndroid Build Coastguard Worker  *
4*77c1e3ccSAndroid Build Coastguard Worker  * This source code is subject to the terms of the BSD 2 Clause License and
5*77c1e3ccSAndroid Build Coastguard Worker  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6*77c1e3ccSAndroid Build Coastguard Worker  * was not distributed with this source code in the LICENSE file, you can
7*77c1e3ccSAndroid Build Coastguard Worker  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8*77c1e3ccSAndroid Build Coastguard Worker  * Media Patent License 1.0 was not distributed with this source code in the
9*77c1e3ccSAndroid Build Coastguard Worker  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10*77c1e3ccSAndroid Build Coastguard Worker  */
11*77c1e3ccSAndroid Build Coastguard Worker 
12*77c1e3ccSAndroid Build Coastguard Worker #include "av1/common/cfl.h"
13*77c1e3ccSAndroid Build Coastguard Worker #include "av1/common/reconintra.h"
14*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/block.h"
15*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/hybrid_fwd_txfm.h"
16*77c1e3ccSAndroid Build Coastguard Worker #include "av1/common/idct.h"
17*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/model_rd.h"
18*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/random.h"
19*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/rdopt_utils.h"
20*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/sorting_network.h"
21*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/tx_prune_model_weights.h"
22*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/tx_search.h"
23*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/txb_rdopt.h"
24*77c1e3ccSAndroid Build Coastguard Worker 
// Offset applied to transform-type pruning probability thresholds (inferred
// from the name only; the macro is not referenced in this part of the file —
// TODO confirm its meaning at the use sites).
#define PROB_THRESH_OFFSET_TX_TYPE 100
26*77c1e3ccSAndroid Build Coastguard Worker 
// Presumably bundles the state shared across per-transform-block RD cost
// computations for one plane.
// NOTE(review): this struct is not used in the visible part of the file; the
// field comments below are inferred from names and types — confirm against the
// code that fills it in.
struct rdcost_block_args {
  const AV1_COMP *cpi;
  MACROBLOCK *x;
  // Above / left entropy contexts (MAX_MIB_SIZE entries each).
  ENTROPY_CONTEXT t_above[MAX_MIB_SIZE];
  ENTROPY_CONTEXT t_left[MAX_MIB_SIZE];
  // Accumulated RD statistics.
  RD_STATS rd_stats;
  // Presumably the running RD cost and the budget it is checked against.
  int64_t current_rd;
  int64_t best_rd;
  // Presumably early-termination flags for the search.
  int exit_early;
  int incomplete_exit;
  FAST_TX_SEARCH_MODE ftxs_mode;
  // Presumably nonzero to skip trellis coefficient optimization.
  int skip_trellis;
};
40*77c1e3ccSAndroid Build Coastguard Worker 
// Bookkeeping for one transform-type candidate: its RD cost, its transform
// block entropy context and the transform type itself.
// NOTE(review): field meanings are inferred from names; the type is not used
// in the visible part of the file — confirm at the use sites.
typedef struct {
  int64_t rd;
  int txb_entropy_ctx;
  TX_TYPE tx_type;
} TxCandidateInfo;
46*77c1e3ccSAndroid Build Coastguard Worker 
// Per-block-size coefficient thresholds used by predict_skip_txfm(), stored as
// origin_threshold * 128 / 100.
// Indexed [bd_idx][bsize]: bd_idx is 0 for 8-bit, 1 for 10-bit and 2 for any
// other bit depth. predict_skip_txfm() multiplies an entry by the DC or AC QTX
// quantizer and compares the product with |coefficient| << 7. The test is
// therefore |coefficient| >= (origin_threshold / 100) * quantizer, and the
// 128 scale keeps that fractional threshold in integer arithmetic.
static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] = {
  {
      64, 64, 64, 70, 60, 60, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 64, 64, 70, 70, 68, 68,
  },
  {
      88, 88, 88, 86, 87, 87, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 88, 88, 86, 86, 68, 68,
  },
  {
      90, 93, 93, 90, 93, 93, 74, 74, 74, 74, 74,
      74, 74, 74, 74, 74, 90, 90, 90, 90, 74, 74,
  },
};
62*77c1e3ccSAndroid Build Coastguard Worker 
// Transform size used by predict_skip_txfm() for its coefficient check, per
// block size. Each entry equals:
// int max_tx_size = max_txsize_rect_lookup[bsize];
// if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16)
//   max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16);
// In other words, both transform dimensions are kept at 16 or below.
static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] = {
  TX_4X4,   TX_4X8,   TX_8X4,   TX_8X8,   TX_8X16,  TX_16X8,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_4X16,  TX_16X4,
  TX_8X8,   TX_8X8,   TX_16X16, TX_16X16,
};
73*77c1e3ccSAndroid Build Coastguard Worker 
// Look-up table, indexed by TX_SIZE, of the square root of the number of
// pixels in a transform block, rounded up to the nearest integer.
// NOTE(review): the entries for sizes with a 64-point dimension do not follow
// that rule literally. For example, index 4 (presumably TX_64X64) holds 32
// rather than 64. They appear to count at most 32 samples along a 64-point
// dimension, presumably because only the 32 lowest-frequency coefficients are
// coded there. Confirm against the TX_SIZE ordering before relying on this
// for 64-point sizes.
static const int sqrt_tx_pixels_2d[TX_SIZES_ALL] = { 4,  8,  16, 32, 32, 6,  6,
                                                     12, 12, 23, 23, 32, 32, 8,
                                                     8,  16, 16, 23, 23 };
79*77c1e3ccSAndroid Build Coastguard Worker 
/*
 * Builds a lookup key from the luma residual of a bsize block.
 *
 * The key is a CRC32C over all block_size_wide[bsize] * block_size_high[bsize]
 * int16_t samples of x->plane[0].src_diff (2 bytes per sample). The CRC uses
 * the calculator owned by the macroblock RD record. The CRC is then shifted
 * left by 5 bits and bsize is added, so the block size also contributes to
 * the key.
 */
static inline uint32_t get_block_residue_hash(MACROBLOCK *x, BLOCK_SIZE bsize) {
  const int num_samples = block_size_wide[bsize] * block_size_high[bsize];
  uint8_t *const residual_bytes = (uint8_t *)x->plane[0].src_diff;
  const uint32_t crc = av1_get_crc32c_value(
      &x->txfm_search_info.mb_rd_record->crc_calculator, residual_bytes,
      2 * num_samples);
  return (crc << 5) + bsize;
}
89*77c1e3ccSAndroid Build Coastguard Worker 
find_mb_rd_info(const MB_RD_RECORD * const mb_rd_record,const int64_t ref_best_rd,const uint32_t hash)90*77c1e3ccSAndroid Build Coastguard Worker static inline int32_t find_mb_rd_info(const MB_RD_RECORD *const mb_rd_record,
91*77c1e3ccSAndroid Build Coastguard Worker                                       const int64_t ref_best_rd,
92*77c1e3ccSAndroid Build Coastguard Worker                                       const uint32_t hash) {
93*77c1e3ccSAndroid Build Coastguard Worker   int32_t match_index = -1;
94*77c1e3ccSAndroid Build Coastguard Worker   if (ref_best_rd != INT64_MAX) {
95*77c1e3ccSAndroid Build Coastguard Worker     for (int i = 0; i < mb_rd_record->num; ++i) {
96*77c1e3ccSAndroid Build Coastguard Worker       const int index = (mb_rd_record->index_start + i) % RD_RECORD_BUFFER_LEN;
97*77c1e3ccSAndroid Build Coastguard Worker       // If there is a match in the mb_rd_record, fetch the RD decision and
98*77c1e3ccSAndroid Build Coastguard Worker       // terminate early.
99*77c1e3ccSAndroid Build Coastguard Worker       if (mb_rd_record->mb_rd_info[index].hash_value == hash) {
100*77c1e3ccSAndroid Build Coastguard Worker         match_index = index;
101*77c1e3ccSAndroid Build Coastguard Worker         break;
102*77c1e3ccSAndroid Build Coastguard Worker       }
103*77c1e3ccSAndroid Build Coastguard Worker     }
104*77c1e3ccSAndroid Build Coastguard Worker   }
105*77c1e3ccSAndroid Build Coastguard Worker   return match_index;
106*77c1e3ccSAndroid Build Coastguard Worker }
107*77c1e3ccSAndroid Build Coastguard Worker 
/*
 * Restores a previously saved transform search result into the current block.
 *
 * Copies tx_size and inter_tx_size into the block's MB_MODE_INFO. Copies the
 * first n4 entries of the per-unit tx_type_map into xd and of blk_skip into
 * x->txfm_search_info. Copies the stored RD statistics into *rd_stats.
 */
static inline void fetch_mb_rd_info(int n4, const MB_RD_INFO *const mb_rd_info,
                                    RD_STATS *const rd_stats,
                                    MACROBLOCK *const x) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mi = xd->mi[0];

  // Block-level transform decisions.
  mi->tx_size = mb_rd_info->tx_size;
  av1_copy(mi->inter_tx_size, mb_rd_info->inter_tx_size);

  // Per-unit decisions (n4 entries each).
  av1_copy_array(xd->tx_type_map, mb_rd_info->tx_type_map, n4);
  memcpy(x->txfm_search_info.blk_skip, mb_rd_info->blk_skip,
         n4 * sizeof(*mb_rd_info->blk_skip));

  *rd_stats = mb_rd_info->rd_stats;
}
120*77c1e3ccSAndroid Build Coastguard Worker 
av1_pixel_diff_dist(const MACROBLOCK * x,int plane,int blk_row,int blk_col,const BLOCK_SIZE plane_bsize,const BLOCK_SIZE tx_bsize,unsigned int * block_mse_q8)121*77c1e3ccSAndroid Build Coastguard Worker int64_t av1_pixel_diff_dist(const MACROBLOCK *x, int plane, int blk_row,
122*77c1e3ccSAndroid Build Coastguard Worker                             int blk_col, const BLOCK_SIZE plane_bsize,
123*77c1e3ccSAndroid Build Coastguard Worker                             const BLOCK_SIZE tx_bsize,
124*77c1e3ccSAndroid Build Coastguard Worker                             unsigned int *block_mse_q8) {
125*77c1e3ccSAndroid Build Coastguard Worker   int visible_rows, visible_cols;
126*77c1e3ccSAndroid Build Coastguard Worker   const MACROBLOCKD *xd = &x->e_mbd;
127*77c1e3ccSAndroid Build Coastguard Worker   get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
128*77c1e3ccSAndroid Build Coastguard Worker                      NULL, &visible_cols, &visible_rows);
129*77c1e3ccSAndroid Build Coastguard Worker   const int diff_stride = block_size_wide[plane_bsize];
130*77c1e3ccSAndroid Build Coastguard Worker   const int16_t *diff = x->plane[plane].src_diff;
131*77c1e3ccSAndroid Build Coastguard Worker 
132*77c1e3ccSAndroid Build Coastguard Worker   diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
133*77c1e3ccSAndroid Build Coastguard Worker   uint64_t sse =
134*77c1e3ccSAndroid Build Coastguard Worker       aom_sum_squares_2d_i16(diff, diff_stride, visible_cols, visible_rows);
135*77c1e3ccSAndroid Build Coastguard Worker   if (block_mse_q8 != NULL) {
136*77c1e3ccSAndroid Build Coastguard Worker     if (visible_cols > 0 && visible_rows > 0)
137*77c1e3ccSAndroid Build Coastguard Worker       *block_mse_q8 =
138*77c1e3ccSAndroid Build Coastguard Worker           (unsigned int)((256 * sse) / (visible_cols * visible_rows));
139*77c1e3ccSAndroid Build Coastguard Worker     else
140*77c1e3ccSAndroid Build Coastguard Worker       *block_mse_q8 = UINT_MAX;
141*77c1e3ccSAndroid Build Coastguard Worker   }
142*77c1e3ccSAndroid Build Coastguard Worker   return sse;
143*77c1e3ccSAndroid Build Coastguard Worker }
144*77c1e3ccSAndroid Build Coastguard Worker 
145*77c1e3ccSAndroid Build Coastguard Worker // Computes the residual block's SSE and mean on all visible 4x4s in the
146*77c1e3ccSAndroid Build Coastguard Worker // transform block
pixel_diff_stats(MACROBLOCK * x,int plane,int blk_row,int blk_col,const BLOCK_SIZE plane_bsize,const BLOCK_SIZE tx_bsize,unsigned int * block_mse_q8,int64_t * per_px_mean,uint64_t * block_var)147*77c1e3ccSAndroid Build Coastguard Worker static inline int64_t pixel_diff_stats(
148*77c1e3ccSAndroid Build Coastguard Worker     MACROBLOCK *x, int plane, int blk_row, int blk_col,
149*77c1e3ccSAndroid Build Coastguard Worker     const BLOCK_SIZE plane_bsize, const BLOCK_SIZE tx_bsize,
150*77c1e3ccSAndroid Build Coastguard Worker     unsigned int *block_mse_q8, int64_t *per_px_mean, uint64_t *block_var) {
151*77c1e3ccSAndroid Build Coastguard Worker   int visible_rows, visible_cols;
152*77c1e3ccSAndroid Build Coastguard Worker   const MACROBLOCKD *xd = &x->e_mbd;
153*77c1e3ccSAndroid Build Coastguard Worker   get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
154*77c1e3ccSAndroid Build Coastguard Worker                      NULL, &visible_cols, &visible_rows);
155*77c1e3ccSAndroid Build Coastguard Worker   const int diff_stride = block_size_wide[plane_bsize];
156*77c1e3ccSAndroid Build Coastguard Worker   const int16_t *diff = x->plane[plane].src_diff;
157*77c1e3ccSAndroid Build Coastguard Worker 
158*77c1e3ccSAndroid Build Coastguard Worker   diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
159*77c1e3ccSAndroid Build Coastguard Worker   uint64_t sse = 0;
160*77c1e3ccSAndroid Build Coastguard Worker   int sum = 0;
161*77c1e3ccSAndroid Build Coastguard Worker   sse = aom_sum_sse_2d_i16(diff, diff_stride, visible_cols, visible_rows, &sum);
162*77c1e3ccSAndroid Build Coastguard Worker   if (visible_cols > 0 && visible_rows > 0) {
163*77c1e3ccSAndroid Build Coastguard Worker     double norm_factor = 1.0 / (visible_cols * visible_rows);
164*77c1e3ccSAndroid Build Coastguard Worker     int sign_sum = sum > 0 ? 1 : -1;
165*77c1e3ccSAndroid Build Coastguard Worker     // Conversion to transform domain
166*77c1e3ccSAndroid Build Coastguard Worker     *per_px_mean = (int64_t)(norm_factor * abs(sum)) << 7;
167*77c1e3ccSAndroid Build Coastguard Worker     *per_px_mean = sign_sum * (*per_px_mean);
168*77c1e3ccSAndroid Build Coastguard Worker     *block_mse_q8 = (unsigned int)(norm_factor * (256 * sse));
169*77c1e3ccSAndroid Build Coastguard Worker     *block_var = (uint64_t)(sse - (uint64_t)(norm_factor * sum * sum));
170*77c1e3ccSAndroid Build Coastguard Worker   } else {
171*77c1e3ccSAndroid Build Coastguard Worker     *block_mse_q8 = UINT_MAX;
172*77c1e3ccSAndroid Build Coastguard Worker   }
173*77c1e3ccSAndroid Build Coastguard Worker   return sse;
174*77c1e3ccSAndroid Build Coastguard Worker }
175*77c1e3ccSAndroid Build Coastguard Worker 
176*77c1e3ccSAndroid Build Coastguard Worker // Uses simple features on top of DCT coefficients to quickly predict
177*77c1e3ccSAndroid Build Coastguard Worker // whether optimal RD decision is to skip encoding the residual.
178*77c1e3ccSAndroid Build Coastguard Worker // The sse value is stored in dist.
predict_skip_txfm(MACROBLOCK * x,BLOCK_SIZE bsize,int64_t * dist,int reduced_tx_set)179*77c1e3ccSAndroid Build Coastguard Worker static int predict_skip_txfm(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist,
180*77c1e3ccSAndroid Build Coastguard Worker                              int reduced_tx_set) {
181*77c1e3ccSAndroid Build Coastguard Worker   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
182*77c1e3ccSAndroid Build Coastguard Worker   const int bw = block_size_wide[bsize];
183*77c1e3ccSAndroid Build Coastguard Worker   const int bh = block_size_high[bsize];
184*77c1e3ccSAndroid Build Coastguard Worker   const MACROBLOCKD *xd = &x->e_mbd;
185*77c1e3ccSAndroid Build Coastguard Worker   const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd);
186*77c1e3ccSAndroid Build Coastguard Worker 
187*77c1e3ccSAndroid Build Coastguard Worker   *dist = av1_pixel_diff_dist(x, 0, 0, 0, bsize, bsize, NULL);
188*77c1e3ccSAndroid Build Coastguard Worker 
189*77c1e3ccSAndroid Build Coastguard Worker   const int64_t mse = *dist / bw / bh;
190*77c1e3ccSAndroid Build Coastguard Worker   // Normalized quantizer takes the transform upscaling factor (8 for tx size
191*77c1e3ccSAndroid Build Coastguard Worker   // smaller than 32) into account.
192*77c1e3ccSAndroid Build Coastguard Worker   const int16_t normalized_dc_q = dc_q >> 3;
193*77c1e3ccSAndroid Build Coastguard Worker   const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8;
194*77c1e3ccSAndroid Build Coastguard Worker   // For faster early skip decision, use dist to compare against threshold so
195*77c1e3ccSAndroid Build Coastguard Worker   // that quality risk is less for the skip=1 decision. Otherwise, use mse
196*77c1e3ccSAndroid Build Coastguard Worker   // since the fwd_txfm coeff checks will take care of quality
197*77c1e3ccSAndroid Build Coastguard Worker   // TODO(any): Use dist to return 0 when skip_txfm_level is 1
198*77c1e3ccSAndroid Build Coastguard Worker   int64_t pred_err = (txfm_params->skip_txfm_level >= 2) ? *dist : mse;
199*77c1e3ccSAndroid Build Coastguard Worker   // Predict not to skip when error is larger than threshold.
200*77c1e3ccSAndroid Build Coastguard Worker   if (pred_err > mse_thresh) return 0;
201*77c1e3ccSAndroid Build Coastguard Worker   // Return as skip otherwise for aggressive early skip
202*77c1e3ccSAndroid Build Coastguard Worker   else if (txfm_params->skip_txfm_level >= 2)
203*77c1e3ccSAndroid Build Coastguard Worker     return 1;
204*77c1e3ccSAndroid Build Coastguard Worker 
205*77c1e3ccSAndroid Build Coastguard Worker   const int max_tx_size = max_predict_sf_tx_size[bsize];
206*77c1e3ccSAndroid Build Coastguard Worker   const int tx_h = tx_size_high[max_tx_size];
207*77c1e3ccSAndroid Build Coastguard Worker   const int tx_w = tx_size_wide[max_tx_size];
208*77c1e3ccSAndroid Build Coastguard Worker   DECLARE_ALIGNED(32, tran_low_t, coefs[32 * 32]);
209*77c1e3ccSAndroid Build Coastguard Worker   TxfmParam param;
210*77c1e3ccSAndroid Build Coastguard Worker   param.tx_type = DCT_DCT;
211*77c1e3ccSAndroid Build Coastguard Worker   param.tx_size = max_tx_size;
212*77c1e3ccSAndroid Build Coastguard Worker   param.bd = xd->bd;
213*77c1e3ccSAndroid Build Coastguard Worker   param.is_hbd = is_cur_buf_hbd(xd);
214*77c1e3ccSAndroid Build Coastguard Worker   param.lossless = 0;
215*77c1e3ccSAndroid Build Coastguard Worker   param.tx_set_type = av1_get_ext_tx_set_type(
216*77c1e3ccSAndroid Build Coastguard Worker       param.tx_size, is_inter_block(xd->mi[0]), reduced_tx_set);
217*77c1e3ccSAndroid Build Coastguard Worker   const int bd_idx = (xd->bd == 8) ? 0 : ((xd->bd == 10) ? 1 : 2);
218*77c1e3ccSAndroid Build Coastguard Worker   const uint32_t max_qcoef_thresh = skip_pred_threshold[bd_idx][bsize];
219*77c1e3ccSAndroid Build Coastguard Worker   const int16_t *src_diff = x->plane[0].src_diff;
220*77c1e3ccSAndroid Build Coastguard Worker   const int n_coeff = tx_w * tx_h;
221*77c1e3ccSAndroid Build Coastguard Worker   const int16_t ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
222*77c1e3ccSAndroid Build Coastguard Worker   const uint32_t dc_thresh = max_qcoef_thresh * dc_q;
223*77c1e3ccSAndroid Build Coastguard Worker   const uint32_t ac_thresh = max_qcoef_thresh * ac_q;
224*77c1e3ccSAndroid Build Coastguard Worker   for (int row = 0; row < bh; row += tx_h) {
225*77c1e3ccSAndroid Build Coastguard Worker     for (int col = 0; col < bw; col += tx_w) {
226*77c1e3ccSAndroid Build Coastguard Worker       av1_fwd_txfm(src_diff + col, coefs, bw, &param);
227*77c1e3ccSAndroid Build Coastguard Worker       // Operating on TX domain, not pixels; we want the QTX quantizers
228*77c1e3ccSAndroid Build Coastguard Worker       const uint32_t dc_coef = (((uint32_t)abs(coefs[0])) << 7);
229*77c1e3ccSAndroid Build Coastguard Worker       if (dc_coef >= dc_thresh) return 0;
230*77c1e3ccSAndroid Build Coastguard Worker       for (int i = 1; i < n_coeff; ++i) {
231*77c1e3ccSAndroid Build Coastguard Worker         const uint32_t ac_coef = (((uint32_t)abs(coefs[i])) << 7);
232*77c1e3ccSAndroid Build Coastguard Worker         if (ac_coef >= ac_thresh) return 0;
233*77c1e3ccSAndroid Build Coastguard Worker       }
234*77c1e3ccSAndroid Build Coastguard Worker     }
235*77c1e3ccSAndroid Build Coastguard Worker     src_diff += tx_h * bw;
236*77c1e3ccSAndroid Build Coastguard Worker   }
237*77c1e3ccSAndroid Build Coastguard Worker   return 1;
238*77c1e3ccSAndroid Build Coastguard Worker }
239*77c1e3ccSAndroid Build Coastguard Worker 
240*77c1e3ccSAndroid Build Coastguard Worker // Used to set proper context for early termination with skip = 1.
set_skip_txfm(MACROBLOCK * x,RD_STATS * rd_stats,BLOCK_SIZE bsize,int64_t dist)241*77c1e3ccSAndroid Build Coastguard Worker static inline void set_skip_txfm(MACROBLOCK *x, RD_STATS *rd_stats,
242*77c1e3ccSAndroid Build Coastguard Worker                                  BLOCK_SIZE bsize, int64_t dist) {
243*77c1e3ccSAndroid Build Coastguard Worker   MACROBLOCKD *const xd = &x->e_mbd;
244*77c1e3ccSAndroid Build Coastguard Worker   MB_MODE_INFO *const mbmi = xd->mi[0];
245*77c1e3ccSAndroid Build Coastguard Worker   const int n4 = bsize_to_num_blk(bsize);
246*77c1e3ccSAndroid Build Coastguard Worker   const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
247*77c1e3ccSAndroid Build Coastguard Worker   memset(xd->tx_type_map, DCT_DCT, sizeof(xd->tx_type_map[0]) * n4);
248*77c1e3ccSAndroid Build Coastguard Worker   memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
249*77c1e3ccSAndroid Build Coastguard Worker   mbmi->tx_size = tx_size;
250*77c1e3ccSAndroid Build Coastguard Worker   for (int i = 0; i < n4; ++i)
251*77c1e3ccSAndroid Build Coastguard Worker     set_blk_skip(x->txfm_search_info.blk_skip, 0, i, 1);
252*77c1e3ccSAndroid Build Coastguard Worker   rd_stats->skip_txfm = 1;
253*77c1e3ccSAndroid Build Coastguard Worker   if (is_cur_buf_hbd(xd)) dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
254*77c1e3ccSAndroid Build Coastguard Worker   rd_stats->dist = rd_stats->sse = (dist << 4);
255*77c1e3ccSAndroid Build Coastguard Worker   // Though decision is to make the block as skip based on luma stats,
256*77c1e3ccSAndroid Build Coastguard Worker   // it is possible that block becomes non skip after chroma rd. In addition
257*77c1e3ccSAndroid Build Coastguard Worker   // intermediate non skip costs calculated by caller function will be
258*77c1e3ccSAndroid Build Coastguard Worker   // incorrect, if rate is set as  zero (i.e., if zero_blk_rate is not
259*77c1e3ccSAndroid Build Coastguard Worker   // accounted). Hence intermediate rate is populated to code the luma tx blks
260*77c1e3ccSAndroid Build Coastguard Worker   // as skip, the caller function based on final rd decision (i.e., skip vs
261*77c1e3ccSAndroid Build Coastguard Worker   // non-skip) sets the final rate accordingly. Here the rate populated
262*77c1e3ccSAndroid Build Coastguard Worker   // corresponds to coding all the tx blocks with zero_blk_rate (based on max tx
263*77c1e3ccSAndroid Build Coastguard Worker   // size possible) in the current block. Eg: For 128*128 block, rate would be
264*77c1e3ccSAndroid Build Coastguard Worker   // 4 * zero_blk_rate where zero_blk_rate corresponds to coding of one 64x64 tx
265*77c1e3ccSAndroid Build Coastguard Worker   // block as 'all zeros'
266*77c1e3ccSAndroid Build Coastguard Worker   ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
267*77c1e3ccSAndroid Build Coastguard Worker   ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
268*77c1e3ccSAndroid Build Coastguard Worker   av1_get_entropy_contexts(bsize, &xd->plane[0], ctxa, ctxl);
269*77c1e3ccSAndroid Build Coastguard Worker   ENTROPY_CONTEXT *ta = ctxa;
270*77c1e3ccSAndroid Build Coastguard Worker   ENTROPY_CONTEXT *tl = ctxl;
271*77c1e3ccSAndroid Build Coastguard Worker   const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
272*77c1e3ccSAndroid Build Coastguard Worker   TXB_CTX txb_ctx;
273*77c1e3ccSAndroid Build Coastguard Worker   get_txb_ctx(bsize, tx_size, 0, ta, tl, &txb_ctx);
274*77c1e3ccSAndroid Build Coastguard Worker   const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y]
275*77c1e3ccSAndroid Build Coastguard Worker                                 .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
276*77c1e3ccSAndroid Build Coastguard Worker   rd_stats->rate = zero_blk_rate *
277*77c1e3ccSAndroid Build Coastguard Worker                    (block_size_wide[bsize] >> tx_size_wide_log2[tx_size]) *
278*77c1e3ccSAndroid Build Coastguard Worker                    (block_size_high[bsize] >> tx_size_high_log2[tx_size]);
279*77c1e3ccSAndroid Build Coastguard Worker }
280*77c1e3ccSAndroid Build Coastguard Worker 
/*
 * Stores the current block's transform search result in the macroblock RD
 * record, keyed by `hash`.
 *
 * The record is a ring buffer of RD_RECORD_BUFFER_LEN slots:
 *   - while it has room, the entry is appended after the last valid one and
 *     `num` grows;
 *   - once full, the entry overwrites the slot at index_start (the oldest
 *     entry), and index_start advances by one.
 *
 * Saved fields: hash, tx_size, inter_tx_size, the first n4 entries of
 * tx_type_map and blk_skip, and *rd_stats.
 */
static inline void save_mb_rd_info(int n4, uint32_t hash,
                                   const MACROBLOCK *const x,
                                   const RD_STATS *const rd_stats,
                                   MB_RD_RECORD *mb_rd_record) {
  const int is_full = mb_rd_record->num >= RD_RECORD_BUFFER_LEN;
  const int slot =
      is_full ? mb_rd_record->index_start
              : (mb_rd_record->index_start + mb_rd_record->num) %
                    RD_RECORD_BUFFER_LEN;
  if (is_full) {
    mb_rd_record->index_start =
        (mb_rd_record->index_start + 1) % RD_RECORD_BUFFER_LEN;
  } else {
    ++mb_rd_record->num;
  }

  const MACROBLOCKD *const xd = &x->e_mbd;
  const MB_MODE_INFO *const mi = xd->mi[0];
  MB_RD_INFO *const entry = &mb_rd_record->mb_rd_info[slot];
  entry->hash_value = hash;
  entry->tx_size = mi->tx_size;
  memcpy(entry->blk_skip, x->txfm_search_info.blk_skip,
         n4 * sizeof(*entry->blk_skip));
  av1_copy(entry->inter_tx_size, mi->inter_tx_size);
  av1_copy_array(entry->tx_type_map, xd->tx_type_map, n4);
  entry->rd_stats = *rd_stats;
}
306*77c1e3ccSAndroid Build Coastguard Worker 
get_search_init_depth(int mi_width,int mi_height,int is_inter,const SPEED_FEATURES * sf,int tx_size_search_method)307*77c1e3ccSAndroid Build Coastguard Worker static int get_search_init_depth(int mi_width, int mi_height, int is_inter,
308*77c1e3ccSAndroid Build Coastguard Worker                                  const SPEED_FEATURES *sf,
309*77c1e3ccSAndroid Build Coastguard Worker                                  int tx_size_search_method) {
310*77c1e3ccSAndroid Build Coastguard Worker   if (tx_size_search_method == USE_LARGESTALL) return MAX_VARTX_DEPTH;
311*77c1e3ccSAndroid Build Coastguard Worker 
312*77c1e3ccSAndroid Build Coastguard Worker   if (sf->tx_sf.tx_size_search_lgr_block) {
313*77c1e3ccSAndroid Build Coastguard Worker     if (mi_width > mi_size_wide[BLOCK_64X64] ||
314*77c1e3ccSAndroid Build Coastguard Worker         mi_height > mi_size_high[BLOCK_64X64])
315*77c1e3ccSAndroid Build Coastguard Worker       return MAX_VARTX_DEPTH;
316*77c1e3ccSAndroid Build Coastguard Worker   }
317*77c1e3ccSAndroid Build Coastguard Worker 
318*77c1e3ccSAndroid Build Coastguard Worker   if (is_inter) {
319*77c1e3ccSAndroid Build Coastguard Worker     return (mi_height != mi_width)
320*77c1e3ccSAndroid Build Coastguard Worker                ? sf->tx_sf.inter_tx_size_search_init_depth_rect
321*77c1e3ccSAndroid Build Coastguard Worker                : sf->tx_sf.inter_tx_size_search_init_depth_sqr;
322*77c1e3ccSAndroid Build Coastguard Worker   } else {
323*77c1e3ccSAndroid Build Coastguard Worker     return (mi_height != mi_width)
324*77c1e3ccSAndroid Build Coastguard Worker                ? sf->tx_sf.intra_tx_size_search_init_depth_rect
325*77c1e3ccSAndroid Build Coastguard Worker                : sf->tx_sf.intra_tx_size_search_init_depth_sqr;
326*77c1e3ccSAndroid Build Coastguard Worker   }
327*77c1e3ccSAndroid Build Coastguard Worker }
328*77c1e3ccSAndroid Build Coastguard Worker 
// Forward declaration: select_tx_block() is defined further down in this file
// (outside the visible portion). Declaring it here lets functions defined
// before it call it.
static inline void select_tx_block(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
    RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
    int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode);
335*77c1e3ccSAndroid Build Coastguard Worker 
336*77c1e3ccSAndroid Build Coastguard Worker // NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values
337*77c1e3ccSAndroid Build Coastguard Worker // 0: Do not collect any RD stats
338*77c1e3ccSAndroid Build Coastguard Worker // 1: Collect RD stats for transform units
339*77c1e3ccSAndroid Build Coastguard Worker // 2: Collect RD stats for partition units
340*77c1e3ccSAndroid Build Coastguard Worker #if CONFIG_COLLECT_RD_STATS
341*77c1e3ccSAndroid Build Coastguard Worker 
// Divides the bw x bh error (src - dst) of |bsize| into a 4x4 grid of equally
// sized cells, measures the squared error in every cell and reports how that
// energy is spread over the grid:
//   hordist[c] = share of the total energy lying in grid column c,
//   verdist[r] = share of the total energy lying in grid row r.
// Only entries 0..2 of each output are written unless |need_4th| is non-zero,
// in which case entry 3 is written as well. When the block has zero energy,
// every written entry is set to 0.25.
static inline void get_energy_distribution_fine(
    const AV1_COMP *cpi, BLOCK_SIZE bsize, const uint8_t *src, int src_stride,
    const uint8_t *dst, int dst_stride, int need_4th, double *hordist,
    double *verdist) {
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  // Squared error per grid cell, row-major: cell (r, c) lives in esq[4*r + c].
  unsigned int esq[16] = { 0 };

  const int per_pixel = bsize < BLOCK_16X16 ||
                        (bsize >= BLOCK_4X16 && bsize <= BLOCK_32X8);
  if (per_pixel) {
    // Cells of these blocks are too small to have 'vf' kernels, so the error
    // is accumulated pixel by pixel. The shifts map a pixel coordinate onto
    // its grid coordinate in 0..3.
    const int w_shift = (bw == 4) ? 0 : (bw == 8) ? 1 : (bw == 16) ? 2 : 3;
    const int h_shift = (bh == 4) ? 0 : (bh == 8) ? 1 : (bh == 16) ? 2 : 3;
    assert(bw <= 32);
    assert(bh <= 32);
    assert(((bw - 1) >> w_shift) + (((bh - 1) >> h_shift) << 2) == 15);
    if (cpi->common.seq_params->use_highbitdepth) {
      const uint16_t *const s16 = CONVERT_TO_SHORTPTR(src);
      const uint16_t *const d16 = CONVERT_TO_SHORTPTR(dst);
      for (int r = 0; r < bh; ++r) {
        unsigned int *const cell_row = &esq[(r >> h_shift) << 2];
        for (int c = 0; c < bw; ++c) {
          const int diff = s16[r * src_stride + c] - d16[r * dst_stride + c];
          cell_row[c >> w_shift] += diff * diff;
        }
      }
    } else {
      for (int r = 0; r < bh; ++r) {
        unsigned int *const cell_row = &esq[(r >> h_shift) << 2];
        for (int c = 0; c < bw; ++c) {
          const int diff = src[r * src_stride + c] - dst[r * dst_stride + c];
          cell_row[c >> w_shift] += diff * diff;
        }
      }
    }
  } else {
    // Every cell is itself a valid block size, a quarter of |bsize| in each
    // dimension, so its SSE is obtained from that size's 'vf' kernel.
    const int f_index =
        (bsize < BLOCK_SIZES) ? bsize - BLOCK_16X16 : bsize - BLOCK_8X16;
    assert(f_index >= 0 && f_index < BLOCK_SIZES_ALL);
    const BLOCK_SIZE subsize = (BLOCK_SIZE)f_index;
    assert(block_size_wide[bsize] == 4 * block_size_wide[subsize]);
    assert(block_size_high[bsize] == 4 * block_size_high[subsize]);
    const int cell_h = bh / 4;
    for (int r = 0; r < 4; ++r) {
      const uint8_t *const src_row = src + r * cell_h * src_stride;
      const uint8_t *const dst_row = dst + r * cell_h * dst_stride;
      for (int c = 0; c < 4; ++c) {
        const int x_off = c * bw / 4;
        cpi->ppi->fn_ptr[subsize].vf(src_row + x_off, src_stride,
                                     dst_row + x_off, dst_stride,
                                     &esq[4 * r + c]);
      }
    }
  }

  double total = 0.0;
  for (int i = 0; i < 16; ++i) total += esq[i];

  const int n_out = need_4th ? 4 : 3;
  if (total > 0) {
    const double e_recip = 1.0 / total;
    for (int k = 0; k < n_out; ++k) {
      double col_energy = 0.0;  // grid column k
      double row_energy = 0.0;  // grid row k
      for (int m = 0; m < 4; ++m) {
        col_energy += esq[4 * m + k];
        row_energy += esq[4 * k + m];
      }
      hordist[k] = col_energy * e_recip;
      verdist[k] = row_energy * e_recip;
    }
  } else {
    // No energy at all: report a flat distribution.
    for (int k = 0; k < n_out; ++k) hordist[k] = verdist[k] = 0.25;
  }
}
448*77c1e3ccSAndroid Build Coastguard Worker 
// Mean of the squared values in the w x h region of |diff| whose rows are
// |stride| elements apart (sum of squares divided by w * h).
static double get_sse_norm(const int16_t *diff, int stride, int w, int h) {
  assert(w > 0 && h > 0);
  double acc = 0.0;
  int r = 0;
  while (r < h) {
    const int16_t *const line = diff + r * stride;
    for (int c = 0; c < w; ++c) {
      const int v = line[c];
      acc += v * v;
    }
    ++r;
  }
  return acc / (w * h);
}
460*77c1e3ccSAndroid Build Coastguard Worker 
// Mean of the absolute values in the w x h region of |diff| whose rows are
// |stride| elements apart (sum of |diff| divided by w * h).
static double get_sad_norm(const int16_t *diff, int stride, int w, int h) {
  assert(w > 0 && h > 0);
  double acc = 0.0;
  for (int r = 0; r < h; ++r) {
    const int16_t *const line = &diff[r * stride];
    int c = 0;
    while (c < w) {
      acc += abs(line[c]);
      ++c;
    }
  }
  return acc / (w * h);
}
471*77c1e3ccSAndroid Build Coastguard Worker 
// Splits the |tx_bsize| block into a 2x2 grid of equal quadrants and, for
// each quadrant q (row-major: q = 2 * row + col), stores
//   sse_norm_arr[q] = sum of squared errors / pixels in the quadrant,
//   sad_norm_arr[q] = sum of absolute errors / pixels in the quadrant.
// Either output pointer may be NULL to skip that statistic.
// When |tx_bsize| has no PARTITION_SPLIT sub-size (BLOCK_INVALID), the
// statistics are computed directly from the residual |src_diff|; otherwise
// the quadrant size's 'vf' and 'sdf' kernels are run on |src| against |dst|.
static inline void get_2x2_normalized_sses_and_sads(
    const AV1_COMP *const cpi, BLOCK_SIZE tx_bsize, const uint8_t *const src,
    int src_stride, const uint8_t *const dst, int dst_stride,
    const int16_t *const src_diff, int diff_stride, double *const sse_norm_arr,
    double *const sad_norm_arr) {
  const BLOCK_SIZE tx_bsize_half =
      get_partition_subsize(tx_bsize, PARTITION_SPLIT);
  const int have_kernels = tx_bsize_half != BLOCK_INVALID;
  // Quadrant size in pixels; the size tables are only indexed with
  // tx_bsize_half when it is a real block size.
  const int quad_w = have_kernels ? block_size_wide[tx_bsize_half]
                                  : block_size_wide[tx_bsize] / 2;
  const int quad_h = have_kernels ? block_size_high[tx_bsize_half]
                                  : block_size_high[tx_bsize] / 2;
  const int quad_pixels = quad_w * quad_h;

  for (int q = 0; q < 4; ++q) {
    const int y0 = (q >> 1) * quad_h;  // first row of the quadrant
    const int x0 = (q & 1) * quad_w;   // first column of the quadrant

    if (!have_kernels) {
      // No kernel for this quadrant size: scan the residual directly.
      const int16_t *const q_diff = src_diff + y0 * diff_stride + x0;
      if (sse_norm_arr) {
        sse_norm_arr[q] = get_sse_norm(q_diff, diff_stride, quad_w, quad_h);
      }
      if (sad_norm_arr) {
        sad_norm_arr[q] = get_sad_norm(q_diff, diff_stride, quad_w, quad_h);
      }
      continue;
    }

    const uint8_t *const q_src = src + y0 * src_stride + x0;
    const uint8_t *const q_dst = dst + y0 * dst_stride + x0;
    if (sse_norm_arr) {
      unsigned int q_sse;
      cpi->ppi->fn_ptr[tx_bsize_half].vf(q_src, src_stride, q_dst, dst_stride,
                                         &q_sse);
      sse_norm_arr[q] = (double)q_sse / quad_pixels;
    }
    if (sad_norm_arr) {
      const unsigned int q_sad = cpi->ppi->fn_ptr[tx_bsize_half].sdf(
          q_src, src_stride, q_dst, dst_stride);
      sad_norm_arr[q] = (double)q_sad / quad_pixels;
    }
  }
}
523*77c1e3ccSAndroid Build Coastguard Worker 
524*77c1e3ccSAndroid Build Coastguard Worker #if CONFIG_COLLECT_RD_STATS == 1
// Arithmetic mean of the w x h region of |diff| whose rows are |stride|
// elements apart.
static double get_mean(const int16_t *diff, int stride, int w, int h) {
  assert(w > 0 && h > 0);
  double total = 0.0;
  for (int row = 0; row < h; ++row) {
    const int16_t *const line = &diff[row * stride];
    int col = 0;
    while (col < w) total += line[col++];
  }
  return total / (w * h);
}
// Appends one line of luma (plane 0) transform-unit statistics to
// "tu_stats.txt". Only calls for which lcg_rand16() % 256 == 0 are logged,
// which keeps the file small. Units whose rate or distortion is invalid
// (INT_MAX / INT64_MAX) are never logged. If the file cannot be opened, the
// sample is silently dropped.
// Space-separated fields, in order:
//   rate/px dist/px sse/px sad/px, 4 quadrant sse/px, 4 quadrant sad/px,
//   q_step tx_w tx_h row_1d_type col_1d_type,
//   model_rate/px model_dist/px (curve-fit model),
//   residual mean, horizontal corr, vertical corr,
//   4 horizontal + 4 vertical energy fractions, rdmult, rd.
static inline void PrintTransformUnitStats(
    const AV1_COMP *const cpi, MACROBLOCK *x, const RD_STATS *const rd_stats,
    int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
    TX_TYPE tx_type, int64_t rd) {
  if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;

  // Pseudo-random sampling to bound the size of the output.
  static unsigned int seed = 21743;
  if (lcg_rand16(&seed) % 256 != 0) return;

  FILE *const fout = fopen("tu_stats.txt", "a");
  if (fout == NULL) return;

  const int plane = 0;
  const MACROBLOCKD *const xd = &x->e_mbd;
  struct macroblock_plane *const p = &x->plane[plane];
  const struct macroblockd_plane *const pd = &xd->plane[plane];
  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
  const int txw = tx_size_wide[tx_size];
  const int txh = tx_size_high[tx_size];
  const int n_px = txw * txh;
  const int q_step =
      p->dequant_QTX[1] >> (is_cur_buf_hbd(xd) ? xd->bd - 5 : 3);

  fprintf(fout, "%g %g", (double)rd_stats->rate / n_px,
          (double)rd_stats->dist / n_px);

  // Top-left pixel of this transform block; blk_row/blk_col are in 4x4 units.
  const int src_stride = p->src.stride;
  const int dst_stride = pd->dst.stride;
  const uint8_t *const src =
      &p->src.buf[(blk_row * src_stride + blk_col) << MI_SIZE_LOG2];
  const uint8_t *const dst =
      &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];

  unsigned int sse;
  cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
  const unsigned int sad =
      cpi->ppi->fn_ptr[tx_bsize].sdf(src, src_stride, dst, dst_stride);
  fprintf(fout, " %g %g", (double)sse / n_px, (double)sad / n_px);

  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *const src_diff =
      &p->src_diff[(blk_row * diff_stride + blk_col) << MI_SIZE_LOG2];

  double quad_sse[4], quad_sad[4];
  get_2x2_normalized_sses_and_sads(cpi, tx_bsize, src, src_stride, dst,
                                   dst_stride, src_diff, diff_stride, quad_sse,
                                   quad_sad);
  for (int k = 0; k < 4; ++k) fprintf(fout, " %g", quad_sse[k]);
  for (int k = 0; k < 4; ++k) fprintf(fout, " %g", quad_sad[k]);

  const TX_TYPE_1D row_type_1d = htx_tab[tx_type];
  const TX_TYPE_1D col_type_1d = vtx_tab[tx_type];
  fprintf(fout, " %d %d %d %d %d", q_step, txw, txh, row_type_1d,
          col_type_1d);

  int model_rate;
  int64_t model_dist;
  model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, tx_bsize, plane, sse, n_px,
                                   &model_rate, &model_dist);
  fprintf(fout, " %g %g", (double)model_rate / n_px,
          (double)model_dist / n_px);

  const double residual_mean = get_mean(src_diff, diff_stride, txw, txh);
  float corr_h, corr_v;
  av1_get_horver_correlation_full(src_diff, diff_stride, txw, txh, &corr_h,
                                  &corr_v);
  fprintf(fout, " %g %g %g", residual_mean, corr_h, corr_v);

  double hdist[4] = { 0 };
  double vdist[4] = { 0 };
  get_energy_distribution_fine(cpi, tx_bsize, src, src_stride, dst, dst_stride,
                               1, hdist, vdist);
  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);

  fprintf(fout, " %d %" PRId64, x->rdmult, rd);

  fprintf(fout, "\n");
  fclose(fout);
}
627*77c1e3ccSAndroid Build Coastguard Worker #endif  // CONFIG_COLLECT_RD_STATS == 1
628*77c1e3ccSAndroid Build Coastguard Worker 
629*77c1e3ccSAndroid Build Coastguard Worker #if CONFIG_COLLECT_RD_STATS >= 2
get_sse(const AV1_COMP * cpi,const MACROBLOCK * x)630*77c1e3ccSAndroid Build Coastguard Worker static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x) {
631*77c1e3ccSAndroid Build Coastguard Worker   const AV1_COMMON *cm = &cpi->common;
632*77c1e3ccSAndroid Build Coastguard Worker   const int num_planes = av1_num_planes(cm);
633*77c1e3ccSAndroid Build Coastguard Worker   const MACROBLOCKD *xd = &x->e_mbd;
634*77c1e3ccSAndroid Build Coastguard Worker   const MB_MODE_INFO *mbmi = xd->mi[0];
635*77c1e3ccSAndroid Build Coastguard Worker   int64_t total_sse = 0;
636*77c1e3ccSAndroid Build Coastguard Worker   for (int plane = 0; plane < num_planes; ++plane) {
637*77c1e3ccSAndroid Build Coastguard Worker     const struct macroblock_plane *const p = &x->plane[plane];
638*77c1e3ccSAndroid Build Coastguard Worker     const struct macroblockd_plane *const pd = &xd->plane[plane];
639*77c1e3ccSAndroid Build Coastguard Worker     const BLOCK_SIZE bs =
640*77c1e3ccSAndroid Build Coastguard Worker         get_plane_block_size(mbmi->bsize, pd->subsampling_x, pd->subsampling_y);
641*77c1e3ccSAndroid Build Coastguard Worker     unsigned int sse;
642*77c1e3ccSAndroid Build Coastguard Worker 
643*77c1e3ccSAndroid Build Coastguard Worker     if (plane) continue;
644*77c1e3ccSAndroid Build Coastguard Worker 
645*77c1e3ccSAndroid Build Coastguard Worker     cpi->ppi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf,
646*77c1e3ccSAndroid Build Coastguard Worker                             pd->dst.stride, &sse);
647*77c1e3ccSAndroid Build Coastguard Worker     total_sse += sse;
648*77c1e3ccSAndroid Build Coastguard Worker   }
649*77c1e3ccSAndroid Build Coastguard Worker   total_sse <<= 4;
650*77c1e3ccSAndroid Build Coastguard Worker   return total_sse;
651*77c1e3ccSAndroid Build Coastguard Worker }
652*77c1e3ccSAndroid Build Coastguard Worker 
// Estimates residual rate and distortion for a |bsize| block from its |sse|,
// using the tile's inter-mode RD model for that block size (fields ready,
// dist_mean, a, b).
// Returns 0 and leaves the outputs untouched if that model is not ready.
// Otherwise it fills *est_residue_cost and *est_dist and returns 1:
//  - sse < dist_mean: rate 0, distortion = sse.
//  - else: distortion = round(dist_mean), and
//    rate = (sse - dist_mean) / (a * sse + b), rounded and clamped to
//    [0, INT_MAX / 2]. A denominator with |value| < 1e-2 yields INT_MAX / 2.
//    If the resulting rate is 0, distortion falls back to sse.
static int get_est_rate_dist(const TileDataEnc *tile_data, BLOCK_SIZE bsize,
                             int64_t sse, int *est_residue_cost,
                             int64_t *est_dist) {
  const InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
  if (!md->ready) return 0;

  if (sse < md->dist_mean) {
    *est_residue_cost = 0;
    *est_dist = sse;
    return 1;
  }

  *est_dist = (int64_t)round(md->dist_mean);
  const double est_ld = md->a * sse + md->b;
  // Clamp estimated rate cost by INT_MAX / 2.
  // TODO: find a better solution than clamping.
  const double max_cost = INT_MAX / 2;
  if (fabs(est_ld) < 1e-2) {
    *est_residue_cost = INT_MAX / 2;
  } else {
    const double est_residue_cost_dbl = (sse - md->dist_mean) / est_ld;
    if (est_residue_cost_dbl < 0) {
      *est_residue_cost = 0;
    } else if (!(est_residue_cost_dbl < max_cost)) {
      // Clamp in the floating-point domain *before* converting: converting a
      // double that is out of range for the target integer type (or NaN) is
      // undefined behavior, so a clamp applied after the cast cannot help.
      *est_residue_cost = INT_MAX / 2;
    } else {
      // 0 <= value < INT_MAX / 2, so round() is at most INT_MAX / 2 and the
      // conversion to int is well defined.
      *est_residue_cost = (int)round(est_residue_cost_dbl);
    }
  }
  if (*est_residue_cost <= 0) {
    *est_residue_cost = 0;
    *est_dist = sse;
  }
  return 1;
}
686*77c1e3ccSAndroid Build Coastguard Worker 
// Mean signed difference (src - dst) over the w x h region of two
// high-bitdepth buffers. |src8| and |dst8| are turned into 16-bit sample
// pointers with CONVERT_TO_SHORTPTR before reading. Strides are in samples.
static double get_highbd_diff_mean(const uint8_t *src8, int src_stride,
                                   const uint8_t *dst8, int dst_stride, int w,
                                   int h) {
  assert(w > 0 && h > 0);
  const uint16_t *const s = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *const d = CONVERT_TO_SHORTPTR(dst8);
  double total = 0.0;
  for (int row = 0; row < h; ++row) {
    const uint16_t *const s_row = s + row * src_stride;
    const uint16_t *const d_row = d + row * dst_stride;
    for (int col = 0; col < w; ++col) total += s_row[col] - d_row[col];
  }
  return total / (w * h);
}
702*77c1e3ccSAndroid Build Coastguard Worker 
// Mean signed difference (src - dst) over the w x h region of two 8-bit
// buffers with independent row strides.
static double get_diff_mean(const uint8_t *src, int src_stride,
                            const uint8_t *dst, int dst_stride, int w, int h) {
  assert(w > 0 && h > 0);
  double total = 0.0;
  int row = 0;
  while (row < h) {
    const uint8_t *const s_row = &src[row * src_stride];
    const uint8_t *const d_row = &dst[row * dst_stride];
    for (int col = 0; col < w; ++col) total += s_row[col] - d_row[col];
    ++row;
  }
  return total / (w * h);
}
715*77c1e3ccSAndroid Build Coastguard Worker 
// Debug-only dump of per-prediction-unit statistics for the luma plane.
//
// For a random subsample of calls, appends one whitespace-separated line to
// "pu_stats.txt" containing, in order:
//   - rate, distortion and RD cost, each divided by the visible pixel count;
//   - block SSE (per visible pixel) and SAD (per pixel of plane_bsize);
//   - four SSE and four SAD values from get_2x2_normalized_sses_and_sads(),
//     divided down by the bit-depth shift when bd > 8;
//   - q_step, rdmult, visible width and visible height;
//   - curve-fit model rate, distortion and RD cost (per visible pixel);
//   - mean residual (divided by 1 << (bd - 8)), horizontal and vertical
//     residual correlation;
//   - four horizontal and four vertical energy-distribution values;
//   - when inter_mode_rd_model_estimation == 1, the estimated residue cost,
//     distortion and RD cost (per visible pixel).
// Nothing is written when the RD stats are invalid (INT_MAX / INT64_MAX),
// when the inter-mode RD model for plane_bsize is required but not ready,
// when the call is not sampled, or when the output file cannot be opened.
static inline void PrintPredictionUnitStats(const AV1_COMP *const cpi,
                                            const TileDataEnc *tile_data,
                                            MACROBLOCK *x,
                                            const RD_STATS *const rd_stats,
                                            BLOCK_SIZE plane_bsize) {
  const int invalid_stats =
      rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX;
  if (invalid_stats) return;

  const int use_rd_model =
      cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1;
  if (use_rd_model) {
    if (tile_data == NULL) return;
    if (!tile_data->inter_mode_rd_models[plane_bsize].ready) return;
  }

  // Only a subsample of calls is logged, to keep the output file small; the
  // sampling period shrinks as the block gets larger.
  // NOTE(review): if num_pels_log2_lookup[plane_bsize] is 14, the period is 1,
  // the remainder is always 0 and can never equal 1, so such blocks are never
  // logged. Confirm whether that is intended.
  static unsigned int seed = 95014;
  const int sample_period = 1 << (14 - num_pels_log2_lookup[plane_bsize]);
  if (lcg_rand16(&seed) % sample_period != 1) return;

  FILE *const out = fopen("pu_stats.txt", "a");
  if (out == NULL) return;

  const int plane = 0;
  MACROBLOCKD *const xd = &x->e_mbd;
  struct macroblock_plane *const p = &x->plane[plane];
  struct macroblockd_plane *const pd = &xd->plane[plane];
  const int is_hbd = is_cur_buf_hbd(xd);
  const int bd_shift = xd->bd - 8;
  const int diff_stride = block_size_wide[plane_bsize];

  int vis_w, vis_h;
  get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL,
                     &vis_w, &vis_h);
  const int npix = vis_w * vis_h;

  // dequant_QTX[1] scaled down by a bit-depth dependent shift.
  const int q_step = p->dequant_QTX[1] >> (is_hbd ? xd->bd - 5 : 3);

  // Measured RD statistics.
  fprintf(out, "%g %g %g", (double)rd_stats->rate / npix,
          (double)rd_stats->dist / npix,
          (double)RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) / npix);

  const uint8_t *const src = p->src.buf;
  const int src_stride = p->src.stride;
  const uint8_t *const dst = pd->dst.buf;
  const int dst_stride = pd->dst.stride;
  const int16_t *const src_diff = p->src_diff;

  // Whole-block SSE and SAD.
  const int64_t sse = calculate_sse(xd, p, pd, vis_w, vis_h);
  const unsigned int sad =
      cpi->ppi->fn_ptr[plane_bsize].sdf(src, src_stride, dst, dst_stride);
  fprintf(out, " %g %g", (double)sse / npix,
          (double)sad / (1 << num_pels_log2_lookup[plane_bsize]));

  // Sub-block SSEs and SADs, divided down by the bit-depth shift when bd > 8.
  double sub_sse[4], sub_sad[4];
  get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
                                   dst_stride, src_diff, diff_stride, sub_sse,
                                   sub_sad);
  if (bd_shift != 0) {
    for (int k = 0; k < 4; ++k) {
      sub_sse[k] /= (1 << (2 * bd_shift));
      sub_sad[k] /= (1 << bd_shift);
    }
  }
  for (int k = 0; k < 4; ++k) fprintf(out, " %g", sub_sse[k]);
  for (int k = 0; k < 4; ++k) fprintf(out, " %g", sub_sad[k]);

  fprintf(out, " %d %d %d %d", q_step, x->rdmult, vis_w, vis_h);

  // Curve-fit model estimate for the same SSE.
  int model_rate;
  int64_t model_dist;
  model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, plane_bsize, plane, sse, npix,
                                   &model_rate, &model_dist);
  fprintf(out, " %g %g %g", (double)model_rate / npix,
          (double)model_dist / npix,
          (double)RDCOST(x->rdmult, model_rate, model_dist) / npix);

  // Residual mean and directional correlations.
  double mean =
      is_hbd ? get_highbd_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
                                    pd->dst.stride, vis_w, vis_h)
             : get_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
                             pd->dst.stride, vis_w, vis_h);
  mean /= (1 << bd_shift);
  float hor_corr, vert_corr;
  av1_get_horver_correlation_full(src_diff, diff_stride, vis_w, vis_h,
                                  &hor_corr, &vert_corr);
  fprintf(out, " %g %g %g", mean, hor_corr, vert_corr);

  double hdist[4] = { 0 }, vdist[4] = { 0 };
  get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
                               dst_stride, 1, hdist, vdist);
  fprintf(out, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);

  // Estimates from the per-tile inter-mode RD model, when it is in use.
  if (use_rd_model) {
    assert(tile_data->inter_mode_rd_models[plane_bsize].ready);
    int est_residue_cost = 0;
    int64_t est_dist = 0;
    get_est_rate_dist(tile_data, plane_bsize, get_sse(cpi, x),
                      &est_residue_cost, &est_dist);
    fprintf(out, " %g %g %g", (double)est_residue_cost / npix,
            (double)est_dist / npix,
            (double)RDCOST(x->rdmult, est_residue_cost, est_dist) / npix);
  }

  fprintf(out, "\n");
  fclose(out);
}
841*77c1e3ccSAndroid Build Coastguard Worker #endif  // CONFIG_COLLECT_RD_STATS >= 2
842*77c1e3ccSAndroid Build Coastguard Worker #endif  // CONFIG_COLLECT_RD_STATS
843*77c1e3ccSAndroid Build Coastguard Worker 
// Inverse-transforms the dequantized coefficients of one transform block and
// adds the result into the plane's dst buffer at (blk_row, blk_col).
// An eob of zero means there are no coefficients to apply, so dst is left
// untouched.
static inline void inverse_transform_block_facade(MACROBLOCK *const x,
                                                  int plane, int block,
                                                  int blk_row, int blk_col,
                                                  int eob, int reduced_tx_set) {
  if (eob == 0) return;

  MACROBLOCKD *const xd = &x->e_mbd;
  const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
  const TX_TYPE tx_type = av1_get_tx_type(xd, get_plane_type(plane), blk_row,
                                          blk_col, tx_size, reduced_tx_set);
  tran_low_t *const dqcoeff = x->plane[plane].dqcoeff + BLOCK_OFFSET(block);

  struct macroblockd_plane *const pd = &xd->plane[plane];
  const int stride = pd->dst.stride;
  // blk_row / blk_col are in MI units; the shift turns the offset into pixels.
  uint8_t *const dst_ptr =
      pd->dst.buf + ((blk_row * stride + blk_col) << MI_SIZE_LOG2);

  av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst_ptr,
                              stride, eob, reduced_tx_set);
}
863*77c1e3ccSAndroid Build Coastguard Worker 
// Reconstructs an intra transform block in the plane's dst buffer.
//
// Nothing is done for inter blocks, when best_eob is zero, or for the
// transform block that reaches both the bottom and the right edge of
// plane_bsize (presumably because no later transform block inside this plane
// block needs it as intra context).
//
// When do_quant is non-zero, the forward transform, quantization and
// (optionally) trellis optimization are redone with best_tx_type. Otherwise
// the quantized coefficients are assumed to already be in the dqcoeff buffer.
static inline void recon_intra(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                               int block, int blk_row, int blk_col,
                               BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                               const TXB_CTX *const txb_ctx, int skip_trellis,
                               TX_TYPE best_tx_type, int do_quant,
                               int *rate_cost, uint16_t best_eob) {
  MACROBLOCKD *const xd = &x->e_mbd;
  if (is_inter_block(xd->mi[0]) || best_eob == 0) return;

  const int reaches_bottom =
      blk_row + tx_size_high_unit[tx_size] >= mi_size_high[plane_bsize];
  const int reaches_right =
      blk_col + tx_size_wide_unit[tx_size] >= mi_size_wide[plane_bsize];
  if (reaches_bottom && reaches_right) return;

  const AV1_COMMON *const cm = &cpi->common;

  if (do_quant) {
    // B-quant is only chosen when trellis is skipped and the build enables it.
    const int quant_idx = (skip_trellis && USE_B_QUANT_NO_TRELLIS)
                              ? AV1_XFORM_QUANT_B
                              : AV1_XFORM_QUANT_FP;
    TxfmParam txfm_param;
    QUANT_PARAM quant_param;
    av1_setup_xform(cm, x, tx_size, best_tx_type, &txfm_param);
    av1_setup_quant(tx_size, !skip_trellis, quant_idx,
                    cpi->oxcf.q_cfg.quant_b_adapt, &quant_param);
    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, best_tx_type,
                      &quant_param);
    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
                    &txfm_param, &quant_param);
    if (quant_param.use_optimize_b) {
      av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type, txb_ctx,
                     rate_cost);
    }
  }

  // Read the eob after (re)quantization: this is the value actually coded.
  const int coded_eob = x->plane[plane].eobs[block];
  inverse_transform_block_facade(x, plane, block, blk_row, blk_col, coded_eob,
                                 cm->features.reduced_tx_set_used);

  // A hash collision can leave best_eob (taken from the hash table) non-zero
  // while the real eob is zero; the luma tx type must then be DCT_DCT.
  if (plane == 0 && coded_eob == 0 && best_tx_type != DCT_DCT) {
    update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
  }
}
912*77c1e3ccSAndroid Build Coastguard Worker 
pixel_dist_visible_only(const AV1_COMP * const cpi,const MACROBLOCK * x,const uint8_t * src,const int src_stride,const uint8_t * dst,const int dst_stride,const BLOCK_SIZE tx_bsize,int txb_rows,int txb_cols,int visible_rows,int visible_cols)913*77c1e3ccSAndroid Build Coastguard Worker static unsigned pixel_dist_visible_only(
914*77c1e3ccSAndroid Build Coastguard Worker     const AV1_COMP *const cpi, const MACROBLOCK *x, const uint8_t *src,
915*77c1e3ccSAndroid Build Coastguard Worker     const int src_stride, const uint8_t *dst, const int dst_stride,
916*77c1e3ccSAndroid Build Coastguard Worker     const BLOCK_SIZE tx_bsize, int txb_rows, int txb_cols, int visible_rows,
917*77c1e3ccSAndroid Build Coastguard Worker     int visible_cols) {
918*77c1e3ccSAndroid Build Coastguard Worker   unsigned sse;
919*77c1e3ccSAndroid Build Coastguard Worker 
920*77c1e3ccSAndroid Build Coastguard Worker   if (txb_rows == visible_rows && txb_cols == visible_cols) {
921*77c1e3ccSAndroid Build Coastguard Worker     cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
922*77c1e3ccSAndroid Build Coastguard Worker     return sse;
923*77c1e3ccSAndroid Build Coastguard Worker   }
924*77c1e3ccSAndroid Build Coastguard Worker 
925*77c1e3ccSAndroid Build Coastguard Worker #if CONFIG_AV1_HIGHBITDEPTH
926*77c1e3ccSAndroid Build Coastguard Worker   const MACROBLOCKD *xd = &x->e_mbd;
927*77c1e3ccSAndroid Build Coastguard Worker   if (is_cur_buf_hbd(xd)) {
928*77c1e3ccSAndroid Build Coastguard Worker     uint64_t sse64 = aom_highbd_sse_odd_size(src, src_stride, dst, dst_stride,
929*77c1e3ccSAndroid Build Coastguard Worker                                              visible_cols, visible_rows);
930*77c1e3ccSAndroid Build Coastguard Worker     return (unsigned int)ROUND_POWER_OF_TWO(sse64, (xd->bd - 8) * 2);
931*77c1e3ccSAndroid Build Coastguard Worker   }
932*77c1e3ccSAndroid Build Coastguard Worker #else
933*77c1e3ccSAndroid Build Coastguard Worker   (void)x;
934*77c1e3ccSAndroid Build Coastguard Worker #endif
935*77c1e3ccSAndroid Build Coastguard Worker   sse = aom_sse_odd_size(src, src_stride, dst, dst_stride, visible_cols,
936*77c1e3ccSAndroid Build Coastguard Worker                          visible_rows);
937*77c1e3ccSAndroid Build Coastguard Worker   return sse;
938*77c1e3ccSAndroid Build Coastguard Worker }
939*77c1e3ccSAndroid Build Coastguard Worker 
940*77c1e3ccSAndroid Build Coastguard Worker // Compute the pixel domain distortion from src and dst on all visible 4x4s in
941*77c1e3ccSAndroid Build Coastguard Worker // the
942*77c1e3ccSAndroid Build Coastguard Worker // transform block.
pixel_dist(const AV1_COMP * const cpi,const MACROBLOCK * x,int plane,const uint8_t * src,const int src_stride,const uint8_t * dst,const int dst_stride,int blk_row,int blk_col,const BLOCK_SIZE plane_bsize,const BLOCK_SIZE tx_bsize)943*77c1e3ccSAndroid Build Coastguard Worker static unsigned pixel_dist(const AV1_COMP *const cpi, const MACROBLOCK *x,
944*77c1e3ccSAndroid Build Coastguard Worker                            int plane, const uint8_t *src, const int src_stride,
945*77c1e3ccSAndroid Build Coastguard Worker                            const uint8_t *dst, const int dst_stride,
946*77c1e3ccSAndroid Build Coastguard Worker                            int blk_row, int blk_col,
947*77c1e3ccSAndroid Build Coastguard Worker                            const BLOCK_SIZE plane_bsize,
948*77c1e3ccSAndroid Build Coastguard Worker                            const BLOCK_SIZE tx_bsize) {
949*77c1e3ccSAndroid Build Coastguard Worker   int txb_rows, txb_cols, visible_rows, visible_cols;
950*77c1e3ccSAndroid Build Coastguard Worker   const MACROBLOCKD *xd = &x->e_mbd;
951*77c1e3ccSAndroid Build Coastguard Worker 
952*77c1e3ccSAndroid Build Coastguard Worker   get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize,
953*77c1e3ccSAndroid Build Coastguard Worker                      &txb_cols, &txb_rows, &visible_cols, &visible_rows);
954*77c1e3ccSAndroid Build Coastguard Worker   assert(visible_rows > 0);
955*77c1e3ccSAndroid Build Coastguard Worker   assert(visible_cols > 0);
956*77c1e3ccSAndroid Build Coastguard Worker 
957*77c1e3ccSAndroid Build Coastguard Worker   unsigned sse = pixel_dist_visible_only(cpi, x, src, src_stride, dst,
958*77c1e3ccSAndroid Build Coastguard Worker                                          dst_stride, tx_bsize, txb_rows,
959*77c1e3ccSAndroid Build Coastguard Worker                                          txb_cols, visible_rows, visible_cols);
960*77c1e3ccSAndroid Build Coastguard Worker 
961*77c1e3ccSAndroid Build Coastguard Worker   return sse;
962*77c1e3ccSAndroid Build Coastguard Worker }
963*77c1e3ccSAndroid Build Coastguard Worker 
// Pixel-domain distortion of one transform block.
//
// Copies the current dst pixels of the transform block into a local
// MAX_TX_SIZE-stride scratch buffer. It then inverse-transforms the block's
// dequantized coefficients into that copy, leaving dst unmodified, and
// measures the visible-area SSE against src. The SSE is multiplied by 16,
// presumably to match the scale of the transform-domain distortion (compare
// the shift in dist_block_tx_domain()).
static inline int64_t dist_block_px_domain(const AV1_COMP *cpi, MACROBLOCK *x,
                                           int plane, BLOCK_SIZE plane_bsize,
                                           int block, int blk_row, int blk_col,
                                           TX_SIZE tx_size) {
  assert(cpi != NULL);
  assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);

  MACROBLOCKD *const xd = &x->e_mbd;
  const struct macroblock_plane *const p = &x->plane[plane];
  const int reduced_tx_set = cpi->common.features.reduced_tx_set_used;
  const uint16_t eob = p->eobs[block];

  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
  const int tx_w = block_size_wide[tx_bsize];
  const int tx_h = block_size_high[tx_bsize];

  const int src_stride = p->src.stride;
  const int dst_stride = xd->plane[plane].dst.stride;
  // blk_row / blk_col are in MI units; the shift turns the offset into pixels.
  const uint8_t *const src =
      p->src.buf + ((blk_row * src_stride + blk_col) << MI_SIZE_LOG2);
  const uint8_t *const dst = xd->plane[plane].dst.buf +
                             ((blk_row * dst_stride + blk_col) << MI_SIZE_LOG2);

  DECLARE_ALIGNED(16, uint16_t, recon16[MAX_TX_SQUARE]);
  uint8_t *recon = (uint8_t *)recon16;
#if CONFIG_AV1_HIGHBITDEPTH
  if (is_cur_buf_hbd(xd)) {
    // High bit-depth samples are addressed through converted pointers.
    recon = CONVERT_TO_BYTEPTR(recon16);
    aom_highbd_convolve_copy(CONVERT_TO_SHORTPTR(dst), dst_stride,
                             CONVERT_TO_SHORTPTR(recon), MAX_TX_SIZE, tx_w,
                             tx_h);
  } else {
    aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, tx_w, tx_h);
  }
#else
  aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, tx_w, tx_h);
#endif

  const TX_TYPE tx_type = av1_get_tx_type(xd, get_plane_type(plane), blk_row,
                                          blk_col, tx_size, reduced_tx_set);
  av1_inverse_transform_block(xd, p->dqcoeff + BLOCK_OFFSET(block), plane,
                              tx_type, tx_size, recon, MAX_TX_SIZE, eob,
                              reduced_tx_set);

  const unsigned sse = pixel_dist(cpi, x, plane, src, src_stride, recon,
                                  MAX_TX_SIZE, blk_row, blk_col, plane_bsize,
                                  tx_bsize);
  return 16 * sse;
}
1013*77c1e3ccSAndroid Build Coastguard Worker 
// Pruning thresholds shared by prune_txk_type() and prune_txk_type_separ().
// NOTE(review): each table has one entry per pruning level; the index is
// presumably a speed-feature level. Confirm at the call sites.

// Prune factors, in units of 1/1000.
static const int prune_factors[5] = { 200, 200, 120, 80, 40 };

// Multiplication factors, in units of 1/100.
static const int mul_factors[5] = { 80, 80, 70, 50, 30 };
1017*77c1e3ccSAndroid Build Coastguard Worker 
// Sorts the first len R-D costs in rds[] into ascending order, applying the
// same permutation to txk[] so each cost stays paired with its tx type.
// The sort is a stable insertion sort: equal costs keep their original
// relative order. len <= 1 leaves both arrays unchanged.
static inline void sort_rd(int64_t rds[], int txk[], int len) {
  for (int i = 1; i < len; ++i) {
    const int64_t key_rd = rds[i];
    const int key_txk = txk[i];
    int pos = i;
    // Shift every strictly larger cost one slot to the right.
    while (pos > 0 && rds[pos - 1] > key_rd) {
      rds[pos] = rds[pos - 1];
      txk[pos] = txk[pos - 1];
      --pos;
    }
    rds[pos] = key_rd;
    txk[pos] = key_txk;
  }
}
1043*77c1e3ccSAndroid Build Coastguard Worker 
av1_block_error_qm(const tran_low_t * coeff,const tran_low_t * dqcoeff,intptr_t block_size,const qm_val_t * qmatrix,const int16_t * scan,int64_t * ssz)1044*77c1e3ccSAndroid Build Coastguard Worker static inline int64_t av1_block_error_qm(const tran_low_t *coeff,
1045*77c1e3ccSAndroid Build Coastguard Worker                                          const tran_low_t *dqcoeff,
1046*77c1e3ccSAndroid Build Coastguard Worker                                          intptr_t block_size,
1047*77c1e3ccSAndroid Build Coastguard Worker                                          const qm_val_t *qmatrix,
1048*77c1e3ccSAndroid Build Coastguard Worker                                          const int16_t *scan, int64_t *ssz) {
1049*77c1e3ccSAndroid Build Coastguard Worker   int i;
1050*77c1e3ccSAndroid Build Coastguard Worker   int64_t error = 0, sqcoeff = 0;
1051*77c1e3ccSAndroid Build Coastguard Worker 
1052*77c1e3ccSAndroid Build Coastguard Worker   for (i = 0; i < block_size; i++) {
1053*77c1e3ccSAndroid Build Coastguard Worker     int64_t weight = qmatrix[scan[i]];
1054*77c1e3ccSAndroid Build Coastguard Worker     int64_t dd = coeff[i] - dqcoeff[i];
1055*77c1e3ccSAndroid Build Coastguard Worker     dd *= weight;
1056*77c1e3ccSAndroid Build Coastguard Worker     int64_t cc = coeff[i];
1057*77c1e3ccSAndroid Build Coastguard Worker     cc *= weight;
1058*77c1e3ccSAndroid Build Coastguard Worker     // The ranges of coeff and dqcoeff are
1059*77c1e3ccSAndroid Build Coastguard Worker     //  bd8 : 18 bits (including sign)
1060*77c1e3ccSAndroid Build Coastguard Worker     //  bd10: 20 bits (including sign)
1061*77c1e3ccSAndroid Build Coastguard Worker     //  bd12: 22 bits (including sign)
1062*77c1e3ccSAndroid Build Coastguard Worker     // As AOM_QM_BITS is 5, the intermediate quantities in the calculation
1063*77c1e3ccSAndroid Build Coastguard Worker     // below should fit in 54 bits, thus no overflow should happen.
1064*77c1e3ccSAndroid Build Coastguard Worker     error += (dd * dd + (1 << (2 * AOM_QM_BITS - 1))) >> (2 * AOM_QM_BITS);
1065*77c1e3ccSAndroid Build Coastguard Worker     sqcoeff += (cc * cc + (1 << (2 * AOM_QM_BITS - 1))) >> (2 * AOM_QM_BITS);
1066*77c1e3ccSAndroid Build Coastguard Worker   }
1067*77c1e3ccSAndroid Build Coastguard Worker 
1068*77c1e3ccSAndroid Build Coastguard Worker   *ssz = sqcoeff;
1069*77c1e3ccSAndroid Build Coastguard Worker   return error;
1070*77c1e3ccSAndroid Build Coastguard Worker }
1071*77c1e3ccSAndroid Build Coastguard Worker 
// Measures the distortion of one transform block in the transform domain.
//
// The original coefficients (p->coeff) are compared against the dequantized
// coefficients (p->dqcoeff) over av1_get_max_eob(tx_size) entries. No inverse
// transform is needed, so this is cheaper than a pixel-domain measurement,
// but it is less accurate.
//
// qmatrix, scan: used only in the low-bit-depth path. When qmatrix is
//                non-NULL and x->txfm_search_params.use_qm_dist_metric is
//                set, the quantization-matrix-weighted error
//                (av1_block_error_qm) replaces the plain error.
// out_dist:      receives the distortion.
// out_sse:       receives the SSE value reported by the block-error helper.
// Both outputs are shifted right by 2 * (MAX_TX_SCALE - tx scale) to bring
// them to the scale of pixel-domain distortion values.
static inline void dist_block_tx_domain(MACROBLOCK *x, int plane, int block,
                                        TX_SIZE tx_size,
                                        const qm_val_t *qmatrix,
                                        const int16_t *scan, int64_t *out_dist,
                                        int64_t *out_sse) {
  const struct macroblock_plane *const p = &x->plane[plane];
  const int block_offset = BLOCK_OFFSET(block);
  tran_low_t *const coeff = p->coeff + block_offset;
  tran_low_t *const dqcoeff = p->dqcoeff + block_offset;
  const int num_coeffs = av1_get_max_eob(tx_size);
  // TX-domain results need to shift down to Q2/D10 to match pixel
  // domain distortion values which are in Q2^2
  const int shift = 2 * (MAX_TX_SCALE - av1_get_tx_scale(tx_size));
  int64_t raw_dist;
  int64_t raw_sse;

#if CONFIG_AV1_HIGHBITDEPTH
  MACROBLOCKD *const xd = &x->e_mbd;
  if (is_cur_buf_hbd(xd)) {
    // TODO(veluca): handle use_qm_dist_metric for HBD too.
    raw_dist = av1_highbd_block_error(coeff, dqcoeff, num_coeffs, &raw_sse,
                                      xd->bd);
  } else
#endif  // CONFIG_AV1_HIGHBITDEPTH
  {
    const int use_qm =
        qmatrix != NULL && x->txfm_search_params.use_qm_dist_metric;
    if (use_qm) {
      raw_dist = av1_block_error_qm(coeff, dqcoeff, num_coeffs, qmatrix, scan,
                                    &raw_sse);
    } else {
      raw_dist = av1_block_error(coeff, dqcoeff, num_coeffs, &raw_sse);
    }
  }

  *out_dist = RIGHT_SIGNED_SHIFT(raw_dist, shift);
  *out_sse = RIGHT_SIGNED_SHIFT(raw_sse, shift);
}
1109*77c1e3ccSAndroid Build Coastguard Worker 
// Prunes 2-D transform types by searching the two 1-D dimensions separately.
// Only valid for transform sizes whose square-up size is at most 16x16.
//
// Step 1: with the vertical 1-D transform fixed to DCT, the four horizontal
//   choices (DCT, ADST, FLIPADST, identity) are evaluated.
// Step 2: with the best horizontal choice fixed, the remaining three vertical
//   choices are evaluated.
// Step 3: the RD cost of each 2-D combination is approximated as the sum of
//   its vertical and horizontal 1-D costs. Combinations disallowed by
//   allowed_tx_mask, or built from a skipped 1-D choice, are parked at the end
//   of txk_map. The remaining candidates are sorted by approximated cost at
//   the front of txk_map.
//
// A 1-D choice is skipped when roughly 3/4 of its RD cost exceeds
// ref_best_rd, or when its cost is more than 1.2x that of the best choice in
// the same dimension.
//
// Returns a mask in which a set bit means "prune this type". Up to num_sel
// candidates are kept: the best one always, then each following one while
// 1800 * (its relative RD increase over the best) stays below prune_factor.
// Returns 0xFFFF (prune everything) when the best horizontal choice is itself
// skipped (txk_map is then left untouched) or when no combination survives.
static uint16_t prune_txk_type_separ(
    const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block, TX_SIZE tx_size,
    int blk_row, int blk_col, BLOCK_SIZE plane_bsize, int *txk_map,
    int16_t allowed_tx_mask, int prune_factor, const TXB_CTX *const txb_ctx,
    int reduced_tx_set_used, int64_t ref_best_rd, int num_sel) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;

  int idx;

  // 1-D RD costs. After sort_rd() they are ascending, and idx_v / idx_h map a
  // rank back to the 1-D choice (0: DCT, 1: ADST, 2: FLIPADST, 3: identity).
  int64_t rds_v[4];
  int64_t rds_h[4];
  int idx_v[4] = { 0, 1, 2, 3 };
  int idx_h[4] = { 0, 1, 2, 3 };
  // skip_v / skip_h are indexed by 1-D choice, not by rank.
  int skip_v[4] = { 0 };
  int skip_h[4] = { 0 };
  // The 2-D type for (vertical choice v, horizontal choice h) is
  // idx_map[4 * v + h].
  const int idx_map[16] = {
    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
  };

  // Order in which (vertical rank, horizontal rank) pairs are combined.
  const int sel_pattern_v[16] = {
    0, 0, 1, 1, 0, 2, 1, 2, 2, 0, 3, 1, 3, 2, 3, 3
  };
  const int sel_pattern_h[16] = {
    0, 1, 0, 1, 2, 0, 2, 1, 2, 3, 0, 3, 1, 3, 2, 3
  };

  QUANT_PARAM quant_param;
  TxfmParam txfm_param;
  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
  av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
                  &quant_param);
  int tx_type;
  // Use the full 16-type set so that 1-D combinations outside the current
  // block's ext_tx_set can be evaluated too. That is only valid for transform
  // sizes up to 16x16.
  assert(txsize_sqr_up_map[tx_size] <= TX_16X16);
  txfm_param.tx_set_type = EXT_TX_SET_ALL16;

  int rate_cost = 0;
  int64_t dist = 0, sse = 0;
  // Step 1: evaluate the horizontal choices with vertical DCT.
  for (idx = 0; idx < 4; ++idx) {
    tx_type = idx_map[idx];
    txfm_param.tx_type = tx_type;

    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                      &quant_param);

    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                    &quant_param);

    const SCAN_ORDER *const scan_order =
        get_scan(txfm_param.tx_size, txfm_param.tx_type);
    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
                         scan_order->scan, &dist, &sse);

    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
                                              txb_ctx, reduced_tx_set_used, 0);

    rds_h[idx] = RDCOST(x->rdmult, rate_cost, dist);

    // Skip when roughly 3/4 of the cost already exceeds ref_best_rd.
    if ((rds_h[idx] - (rds_h[idx] >> 2)) > ref_best_rd) {
      skip_h[idx] = 1;
    }
  }
  sort_rd(rds_h, idx_h, 4);
  // Also skip choices costing more than 1.2x the best horizontal choice.
  for (idx = 1; idx < 4; idx++) {
    if (rds_h[idx] > rds_h[0] * 1.2) skip_h[idx_h[idx]] = 1;
  }

  // If even the best horizontal choice is too expensive, prune every type.
  if (skip_h[idx_h[0]]) return (uint16_t)0xFFFF;

  // Step 2: evaluate the vertical choices with the best horizontal chosen.
  // rds_h[0] already is the cost of (vertical DCT, best horizontal).
  rds_v[0] = rds_h[0];
  int start_v = 1, end_v = 4;
  const int *idx_map_v = idx_map + idx_h[0];

  for (idx = start_v; idx < end_v; ++idx) {
    tx_type = idx_map_v[idx_v[idx] * 4];
    txfm_param.tx_type = tx_type;

    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                      &quant_param);

    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                    &quant_param);

    const SCAN_ORDER *const scan_order =
        get_scan(txfm_param.tx_size, txfm_param.tx_type);
    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
                         scan_order->scan, &dist, &sse);

    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
                                              txb_ctx, reduced_tx_set_used, 0);

    rds_v[idx] = RDCOST(x->rdmult, rate_cost, dist);

    if ((rds_v[idx] - (rds_v[idx] >> 2)) > ref_best_rd) {
      skip_v[idx] = 1;
    }
  }
  sort_rd(rds_v, idx_v, 4);
  for (idx = 1; idx < 4; idx++) {
    if (rds_v[idx] > rds_v[0] * 1.2) skip_v[idx_v[idx]] = 1;
  }

  // Step 3: combine rd_h and rd_v to prune tx candidates. Rejected types fill
  // txk_map from the back; surviving candidates fill it from the front.
  int i_v, i_h;
  int64_t rds[16];
  int num_cand = 0, last = TX_TYPES - 1;

  for (int i = 0; i < 16; i++) {
    i_v = sel_pattern_v[i];
    i_h = sel_pattern_h[i];
    tx_type = idx_map[idx_v[i_v] * 4 + idx_h[i_h]];
    if (!(allowed_tx_mask & (1 << tx_type)) || skip_h[idx_h[i_h]] ||
        skip_v[idx_v[i_v]]) {
      txk_map[last] = tx_type;
      last--;
    } else {
      txk_map[num_cand] = tx_type;
      rds[num_cand] = rds_v[i_v] + rds_h[i_h];
      // Clamp to 1 so the relative comparison below never divides by zero.
      if (rds[num_cand] == 0) rds[num_cand] = 1;
      num_cand++;
    }
  }

  // Every combination was rejected. txk_map[0] now holds a rejected type, so
  // do not build a mask that keeps it; prune everything, as prune_txk_type()
  // does in the same situation.
  if (num_cand == 0) return (uint16_t)0xFFFF;

  sort_rd(rds, txk_map, num_cand);

  uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
  num_sel = AOMMIN(num_sel, num_cand);

  for (int i = 1; i < num_sel; i++) {
    int64_t factor = 1800 * (rds[i] - rds[0]) / (rds[0]);
    if (factor < (int64_t)prune_factor)
      prune &= ~(1 << txk_map[i]);
    else
      break;
  }
  return prune;
}
1253*77c1e3ccSAndroid Build Coastguard Worker 
// Estimates the RD cost of every allowed 2-D transform type for one block and
// returns a mask of types to prune (a set bit means "prune").
//
// Each allowed type is forward-transformed and quantized. Its rate is
// approximated with av1_cost_coeffs_txb_laplacian() and its distortion is
// measured in the transform domain. On return, txk_map holds the allowed
// types sorted by ascending estimated RD cost, followed by the disallowed
// types in reverse enum order.
//
// The best type is always kept. Each next-best type is kept while
// 1000 * (its relative RD increase over the best) is below prune_factor; the
// scan stops at the first type that fails. Returns 0xFFFF (prune everything)
// when allowed_tx_mask permits no type.
static uint16_t prune_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                               int block, TX_SIZE tx_size, int blk_row,
                               int blk_col, BLOCK_SIZE plane_bsize,
                               int *txk_map, uint16_t allowed_tx_mask,
                               int prune_factor, const TXB_CTX *const txb_ctx,
                               int reduced_tx_set_used) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  int64_t est_rd[TX_TYPES];
  int n_allowed = 0;
  int tail = TX_TYPES - 1;

  TxfmParam txfm_param;
  QUANT_PARAM quant_param;
  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
  av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
                  &quant_param);

  for (int type = 0; type < TX_TYPES; ++type) {
    if ((allowed_tx_mask & (1 << type)) == 0) {
      // Disallowed types are parked at the end of txk_map.
      txk_map[tail--] = type;
      continue;
    }
    txfm_param.tx_type = type;
    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, type,
                      &quant_param);

    // Forward transform + quantization.
    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                    &quant_param);
    // Approximate the coefficient rate.
    const int rate = av1_cost_coeffs_txb_laplacian(
        x, plane, block, tx_size, type, txb_ctx, reduced_tx_set_used, 0);
    // Transform-domain distortion.
    const SCAN_ORDER *const scan_order =
        get_scan(txfm_param.tx_size, txfm_param.tx_type);
    int64_t dist = 0;
    int64_t sse = 0;
    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
                         scan_order->scan, &dist, &sse);

    const int64_t rd = RDCOST(x->rdmult, rate, dist);
    txk_map[n_allowed] = type;
    // Clamp to 1 so the relative comparison below never divides by zero.
    est_rd[n_allowed] = (rd == 0) ? 1 : rd;
    ++n_allowed;
  }

  if (n_allowed == 0) return (uint16_t)0xFFFF;

  sort_rd(est_rd, txk_map, n_allowed);
  uint16_t prune = (uint16_t)(~(1 << txk_map[0]));

  // 0 < prune_factor <= 1000 controls aggressiveness
  for (int k = 1; k < n_allowed; ++k) {
    const int64_t rel_increase = 1000 * (est_rd[k] - est_rd[0]) / est_rd[0];
    if (rel_increase >= (int64_t)prune_factor) break;
    prune &= ~(1 << txk_map[k]);
  }
  return prune;
}
1323*77c1e3ccSAndroid Build Coastguard Worker 
// Adaptive score thresholds, indexed first by TX_SIZE and then by pruning
// aggressiveness (the index computed in get_adaptive_thresholds()).
// These thresholds were calibrated to provide a certain number of TX types
// pruned by the model on average, i.e. selecting a threshold with index i
// will lead to pruning i+1 TX types on average.
// NULL entries are sizes without calibrated thresholds and must not be
// indexed. TX_16X16 has only 10 thresholds (max index 9), presumably because
// it is only used with the EXT_TX_SET_DTT9_IDTX_1DDCT column (max 9) of
// get_adaptive_thresholds() -- TODO confirm.
// Both the row pointers and the threshold values are read-only.
static const float *const prune_2D_adaptive_thresholds[] = {
  // TX_4X4
  (const float[]){ 0.00549f, 0.01306f, 0.02039f, 0.02747f, 0.03406f,
                   0.04065f, 0.04724f, 0.05383f, 0.06067f, 0.06799f,
                   0.07605f, 0.08533f, 0.09778f, 0.11780f },
  // TX_8X8
  (const float[]){ 0.00037f, 0.00183f, 0.00525f, 0.01038f, 0.01697f,
                   0.02502f, 0.03381f, 0.04333f, 0.05286f, 0.06287f,
                   0.07434f, 0.08850f, 0.10803f, 0.14124f },
  // TX_16X16
  (const float[]){ 0.01404f, 0.02000f, 0.04211f, 0.05164f, 0.05798f,
                   0.06335f, 0.06897f, 0.07629f, 0.08875f, 0.11169f },
  // TX_32X32
  NULL,
  // TX_64X64
  NULL,
  // TX_4X8
  (const float[]){ 0.00183f, 0.00745f, 0.01428f, 0.02185f, 0.02966f,
                   0.03723f, 0.04456f, 0.05188f, 0.05920f, 0.06702f,
                   0.07605f, 0.08704f, 0.10168f, 0.12585f },
  // TX_8X4
  (const float[]){ 0.00085f, 0.00476f, 0.01135f, 0.01892f, 0.02698f,
                   0.03528f, 0.04358f, 0.05164f, 0.05994f, 0.06848f,
                   0.07849f, 0.09021f, 0.10583f, 0.13123f },
  // TX_8X16
  (const float[]){ 0.00037f, 0.00232f, 0.00671f, 0.01257f, 0.01965f,
                   0.02722f, 0.03552f, 0.04382f, 0.05237f, 0.06189f,
                   0.07336f, 0.08728f, 0.10730f, 0.14221f },
  // TX_16X8
  (const float[]){ 0.00061f, 0.00330f, 0.00818f, 0.01453f, 0.02185f,
                   0.02966f, 0.03772f, 0.04578f, 0.05383f, 0.06262f,
                   0.07288f, 0.08582f, 0.10339f, 0.13464f },
  // TX_16X32
  NULL,
  // TX_32X16
  NULL,
  // TX_32X64
  NULL,
  // TX_64X32
  NULL,
  // TX_4X16
  (const float[]){ 0.00232f, 0.00671f, 0.01257f, 0.01941f, 0.02673f,
                   0.03430f, 0.04211f, 0.04968f, 0.05750f, 0.06580f,
                   0.07507f, 0.08655f, 0.10242f, 0.12878f },
  // TX_16X4
  (const float[]){ 0.00110f, 0.00525f, 0.01208f, 0.01990f, 0.02795f,
                   0.03601f, 0.04358f, 0.05115f, 0.05896f, 0.06702f,
                   0.07629f, 0.08752f, 0.10217f, 0.12610f },
  // TX_8X32
  NULL,
  // TX_32X8
  NULL,
  // TX_16X64
  NULL,
  // TX_64X16
  NULL,
};
1384*77c1e3ccSAndroid Build Coastguard Worker 
// Returns the score threshold from prune_2D_adaptive_thresholds[tx_size].
//
// The threshold index (pruning aggressiveness) is looked up in
// prune_aggr_table. The row is the pruning mode relative to TX_TYPE_PRUNE_1.
// The column is 0 for EXT_TX_SET_ALL16 and 1 for EXT_TX_SET_DTT9_IDTX_1DDCT.
// Any other transform set uses index 0.
// Preconditions (checked in debug builds): for those two sets the pruning
// mode must map to a valid table row (mode >= TX_TYPE_PRUNE_1), and tx_size
// must have a non-NULL threshold row.
static inline float get_adaptive_thresholds(
    TX_SIZE tx_size, TxSetType tx_set_type,
    TX_TYPE_PRUNE_MODE prune_2d_txfm_mode) {
  static const int prune_aggr_table[5][2] = {
    { 4, 1 }, { 6, 3 }, { 9, 6 }, { 9, 6 }, { 12, 9 }
  };
  int column = -1;
  if (tx_set_type == EXT_TX_SET_ALL16)
    column = 0;
  else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
    column = 1;

  int pruning_aggressiveness = 0;
  if (column >= 0) {
    const int mode_idx = prune_2d_txfm_mode - TX_TYPE_PRUNE_1;
    // TX_TYPE_PRUNE_0 (or an out-of-range mode) would index outside the table.
    assert(mode_idx >= 0 &&
           mode_idx < (int)(sizeof(prune_aggr_table) /
                            sizeof(prune_aggr_table[0])));
    pruning_aggressiveness = prune_aggr_table[mode_idx][column];
  }

  const float *const thresholds = prune_2D_adaptive_thresholds[tx_size];
  // Sizes without calibrated thresholds have NULL rows.
  assert(thresholds != NULL);
  return thresholds[pruning_aggressiveness];
}
1401*77c1e3ccSAndroid Build Coastguard Worker 
// Computes the normalized horizontal and vertical energy distributions of a
// residual block.
//
// Squared residuals are accumulated into a downscaled energy map esq. Blocks
// wider than 8 are downscaled by 2 horizontally, and blocks taller than 8 by
// 2 vertically. hordist[j] receives the fraction of the total energy in esq
// column j, and verdist[i] the fraction in esq row i.
//
// diff:    residual samples, row-major, `stride` samples between rows.
// bw, bh:  block width and height in samples. The downscaled map
//          (esq_w * esq_h) must fit in 256 entries.
// hordist: output of esq_w - 1 floats.
// verdist: output of esq_h - 1 floats.
// The last entry of each projection is not written; it is implied because
// the entries sum to 1. A block with zero energy yields uniform
// distributions.
static inline void get_energy_distribution_finer(const int16_t *diff,
                                                 int stride, int bw, int bh,
                                                 float *hordist,
                                                 float *verdist) {
  // First compute downscaled block energy values (esq); downscale factors
  // are defined by w_shift and h_shift.
  unsigned int esq[256];
  const int w_shift = bw <= 8 ? 0 : 1;
  const int h_shift = bh <= 8 ? 0 : 1;
  const int esq_w = bw >> w_shift;
  const int esq_h = bh >> h_shift;
  const int esq_sz = esq_w * esq_h;
  int i, j;
  assert(esq_sz <= (int)(sizeof(esq) / sizeof(esq[0])));
  memset(esq, 0, esq_sz * sizeof(esq[0]));
  if (w_shift) {
    for (i = 0; i < bh; i++) {
      unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
      const int16_t *cur_diff_row = diff + i * stride;
      for (j = 0; j < bw; j += 2) {
        const int d0 = cur_diff_row[j];
        const int d1 = cur_diff_row[j + 1];
        // One int16_t square always fits in int (at most 2^30). The sum of
        // two can exceed INT_MAX (e.g. both -32768), so add them as unsigned.
        cur_esq_row[j >> 1] +=
            (unsigned int)(d0 * d0) + (unsigned int)(d1 * d1);
      }
    }
  } else {
    for (i = 0; i < bh; i++) {
      unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
      const int16_t *cur_diff_row = diff + i * stride;
      for (j = 0; j < bw; j++) {
        const int d = cur_diff_row[j];
        cur_esq_row[j] += (unsigned int)(d * d);
      }
    }
  }

  uint64_t total = 0;
  for (i = 0; i < esq_sz; i++) total += esq[i];

  // Output hordist and verdist arrays are normalized 1D projections of esq.
  if (total == 0) {
    // No energy at all: report uniform distributions.
    float hor_val = 1.0f / esq_w;
    for (j = 0; j < esq_w - 1; j++) hordist[j] = hor_val;
    float ver_val = 1.0f / esq_h;
    for (i = 0; i < esq_h - 1; i++) verdist[i] = ver_val;
    return;
  }

  const float e_recip = 1.0f / (float)total;
  memset(hordist, 0, (esq_w - 1) * sizeof(hordist[0]));
  memset(verdist, 0, (esq_h - 1) * sizeof(verdist[0]));
  const unsigned int *cur_esq_row;
  for (i = 0; i < esq_h - 1; i++) {
    cur_esq_row = esq + i * esq_w;
    for (j = 0; j < esq_w - 1; j++) {
      hordist[j] += (float)cur_esq_row[j];
      verdist[i] += (float)cur_esq_row[j];
    }
    // The last column only contributes to the row sum.
    verdist[i] += (float)cur_esq_row[j];
  }
  // The last row only contributes to the column sums.
  cur_esq_row = esq + i * esq_w;
  for (j = 0; j < esq_w - 1; j++) hordist[j] += (float)cur_esq_row[j];

  for (j = 0; j < esq_w - 1; j++) hordist[j] *= e_recip;
  for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip;
}
1465*77c1e3ccSAndroid Build Coastguard Worker 
// Reports whether bit number `val` (0 = least significant) is set in `mask`.
static inline bool check_bit_mask(uint16_t mask, int val) {
  const int bit = (mask >> val) & 1;
  return bit != 0;
}
1469*77c1e3ccSAndroid Build Coastguard Worker 
// Turns on bit number `val` (0 = least significant) in *mask, leaving every
// other bit untouched.
static inline void set_bit_mask(uint16_t *mask, int val) {
  const unsigned int bit = 1u << val;
  *mask = (uint16_t)(*mask | bit);
}
1473*77c1e3ccSAndroid Build Coastguard Worker 
// Turns off bit number `val` (0 = least significant) in *mask, leaving every
// other bit untouched.
static inline void unset_bit_mask(uint16_t *mask, int val) {
  const unsigned int keep = ~(1u << val);
  *mask = (uint16_t)(*mask & keep);
}
1477*77c1e3ccSAndroid Build Coastguard Worker 
// Prunes the 2D transform types allowed for one luma transform block using a
// pair of 1D neural-network classifiers.
//
// Residual features are taken from the tx_size block at (blk_row, blk_col).
// The coordinates are in units of 4 samples inside plane 0's src_diff buffer,
// whose stride is block_size_wide[bsize]. The features feed a horizontal and
// a vertical model, each producing 4 scores. Their outer product, passed
// through a 16-way softmax, gives one probability per 2D type.
//
// Pruning rules:
// - Types scoring below an adaptive threshold are removed from
//   *allowed_tx_mask.
// - The highest-scoring type that was allowed on entry is always kept.
// - With prune_2d_txfm_mode >= TX_TYPE_PRUNE_4, the survivors are limited
//   further. The best-scoring ones are kept until they cover more than 30% of
//   the surviving probability mass, with at least two kept.
//
// Only EXT_TX_SET_ALL16 and EXT_TX_SET_DTT9_IDTX_1DDCT are handled. For other
// sets, or when no model exists for tx_size, the function returns without
// touching txk_map or *allowed_tx_mask.
//
// On return, txk_map holds 16 TX_TYPE entries and *allowed_tx_mask has one bit
// set per surviving type. NOTE(review): txk_map appears to be the order in
// which callers search the types; confirm against callers.
static void prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
                        int blk_row, int blk_col, TxSetType tx_set_type,
                        TX_TYPE_PRUNE_MODE prune_2d_txfm_mode, int *txk_map,
                        uint16_t *allowed_tx_mask) {
  // This table is used because the search order is different from the enum
  // order. Entry 4 * i + j corresponds to the 2D score vscores[i] * hscores[j]
  // computed below.
  static const int tx_type_table_2D[16] = {
    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
  };
  if (tx_set_type != EXT_TX_SET_ALL16 &&
      tx_set_type != EXT_TX_SET_DTT9_IDTX_1DDCT)
    return;
#if CONFIG_NN_V2
  NN_CONFIG_V2 *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
  NN_CONFIG_V2 *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
#else
  const NN_CONFIG *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
  const NN_CONFIG *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
#endif
  if (!nn_config_hor || !nn_config_ver) return;  // Model not established yet.

  float hfeatures[16], vfeatures[16];
  float hscores[4], vscores[4];
  float scores_2D_raw[16];
  const int bw = tx_size_wide[tx_size];
  const int bh = tx_size_high[tx_size];
  const int hfeatures_num = bw <= 8 ? bw : bw / 2;
  const int vfeatures_num = bh <= 8 ? bh : bh / 2;
  assert(hfeatures_num <= 16);
  assert(vfeatures_num <= 16);

  const struct macroblock_plane *const p = &x->plane[0];
  const int diff_stride = block_size_wide[bsize];
  const int16_t *diff = p->src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
  get_energy_distribution_finer(diff, diff_stride, bw, bh, hfeatures,
                                vfeatures);

  // The final feature slot of each direction is (over)written by the
  // horizontal/vertical correlation helper.
  av1_get_horver_correlation_full(diff, diff_stride, bw, bh,
                                  &hfeatures[hfeatures_num - 1],
                                  &vfeatures[vfeatures_num - 1]);

#if CONFIG_NN_V2
  av1_nn_predict_v2(hfeatures, nn_config_hor, 0, hscores);
  av1_nn_predict_v2(vfeatures, nn_config_ver, 0, vscores);
#else
  av1_nn_predict(hfeatures, nn_config_hor, 1, hscores);
  av1_nn_predict(vfeatures, nn_config_ver, 1, vscores);
#endif

  // Outer product of the two 1D score vectors. Row i uses vscores[i] and
  // column j uses hscores[j], matching the layout of tx_type_table_2D.
  for (int i = 0; i < 4; i++) {
    float *cur_scores_2D = scores_2D_raw + i * 4;
    cur_scores_2D[0] = vscores[i] * hscores[0];
    cur_scores_2D[1] = vscores[i] * hscores[1];
    cur_scores_2D[2] = vscores[i] * hscores[2];
    cur_scores_2D[3] = vscores[i] * hscores[3];
  }

  assert(TX_TYPES == 16);
  // This version of the function only works when there are at most 16 classes.
  // So we will need to change the optimization or use av1_nn_softmax instead if
  // this ever gets changed.
  av1_nn_fast_softmax_16(scores_2D_raw, scores_2D_raw);

  const float score_thresh =
      get_adaptive_thresholds(tx_size, tx_set_type, prune_2d_txfm_mode);

  // Always keep the TX type with the highest score, prune all others with
  // score below score_thresh.
  int max_score_i = 0;
  float max_score = 0.0f;
  uint16_t allow_bitmask = 0;
  float sum_score = 0.0f;
  // Calculate the sum of the allowed tx type scores and populate the allow bit
  // mask based on score_thresh and allowed_tx_mask. Types passing the
  // threshold are also compacted into tx_type_allowed / scores_2D. Unused slots
  // keep the sentinels (TX_TYPE_INVALID, score -1), which sort after any real
  // score, assuming the sort below is descending.
  int allow_count = 0;
  int tx_type_allowed[16] = { TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID };
  float scores_2D[16] = {
    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
  };
  for (int tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
    const int allow_tx_type =
        check_bit_mask(*allowed_tx_mask, tx_type_table_2D[tx_idx]);
    if (!allow_tx_type) {
      continue;
    }
    if (scores_2D_raw[tx_idx] > max_score) {
      max_score = scores_2D_raw[tx_idx];
      max_score_i = tx_idx;
    }
    if (scores_2D_raw[tx_idx] >= score_thresh) {
      // Set allow mask based on score_thresh
      set_bit_mask(&allow_bitmask, tx_type_table_2D[tx_idx]);

      // Accumulate score of allowed tx type
      sum_score += scores_2D_raw[tx_idx];

      scores_2D[allow_count] = scores_2D_raw[tx_idx];
      tx_type_allowed[allow_count] = tx_type_table_2D[tx_idx];
      allow_count += 1;
    }
  }
  if (!check_bit_mask(allow_bitmask, tx_type_table_2D[max_score_i])) {
    // If even the tx_type with max score is pruned, this means that no other
    // tx_type is feasible. When this happens, we force enable max_score_i and
    // end the search.
    set_bit_mask(&allow_bitmask, tx_type_table_2D[max_score_i]);
    memcpy(txk_map, tx_type_table_2D, sizeof(tx_type_table_2D));
    *allowed_tx_mask = allow_bitmask;
    return;
  }

  // Sort the allowed tx types by probability. Only the first allow_count
  // entries are real. When they fit in 8 slots, the cheaper 8-wide sort is
  // enough because slots 8..15 still hold sentinels.
  if (allow_count <= 8) {
    av1_sort_fi32_8(scores_2D, tx_type_allowed);
  } else {
    av1_sort_fi32_16(scores_2D, tx_type_allowed);
  }

  // Enable more pruning based on tx type probability and number of allowed tx
  // types
  if (prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) {
    float temp_score = 0.0f;
    float score_ratio = 0.0f;
    int tx_idx, tx_count = 0;
    // sum_score > 0 here: at least the max-score type passed the threshold,
    // and softmax outputs are positive.
    const float inv_sum_score = 100 / sum_score;
    // Get allowed tx types based on sorted probability score and tx count
    for (tx_idx = 0; tx_idx < allow_count; tx_idx++) {
      // Stop once the types kept so far cover more than 30% of the cumulative
      // probability and at least 2 tx types have been kept.
      if (score_ratio > 30.0 && tx_count >= 2) break;

      assert(check_bit_mask(allow_bitmask, tx_type_allowed[tx_idx]));
      // Calculate cumulative probability
      temp_score += scores_2D[tx_idx];

      // Calculate percentage of cumulative probability of allowed tx type
      score_ratio = temp_score * inv_sum_score;
      tx_count++;
    }
    // Set remaining tx types as pruned
    for (; tx_idx < allow_count; tx_idx++)
      unset_bit_mask(&allow_bitmask, tx_type_allowed[tx_idx]);
  }

  // Size the copy by the array actually being copied.
  memcpy(txk_map, tx_type_allowed, sizeof(tx_type_allowed));
  *allowed_tx_mask = allow_bitmask;
}
1633*77c1e3ccSAndroid Build Coastguard Worker 
get_dev(float mean,double x2_sum,int num)1634*77c1e3ccSAndroid Build Coastguard Worker static float get_dev(float mean, double x2_sum, int num) {
1635*77c1e3ccSAndroid Build Coastguard Worker   const float e_x2 = (float)(x2_sum / num);
1636*77c1e3ccSAndroid Build Coastguard Worker   const float diff = e_x2 - mean * mean;
1637*77c1e3ccSAndroid Build Coastguard Worker   const float dev = (diff > 0) ? sqrtf(diff) : 0;
1638*77c1e3ccSAndroid Build Coastguard Worker   return dev;
1639*77c1e3ccSAndroid Build Coastguard Worker }
1640*77c1e3ccSAndroid Build Coastguard Worker 
// Fills `features` with the inputs of the ML tx-split model: mean and
// standard-deviation statistics of the bw x bh block at `data` (row pitch
// `stride`) and of its sub-blocks.
//
// The block is halved along its longer side. Square blocks are halved along
// both sides, giving 4 quadrants; otherwise 2 halves result. Layout:
//   [0], [1]             mean and deviation of the whole block
//   [2 .. 2 + 2k - 1]    (mean, deviation) of each of the k sub-blocks, in
//                        raster order
//   [2 + 2k], [3 + 2k]   deviation of the sub-block means, and mean of the
//                        sub-block deviations
// Returns the number of elements written: 12 for square blocks, 8 otherwise.
// The 'features' buffer must therefore hold at least 12 elements.
static inline int get_mean_dev_features(const int16_t *data, int stride, int bw,
                                        int bh, float *features) {
  const int sub_w = (bw >= bh) ? (bw >> 1) : bw;
  const int sub_h = (bh >= bw) ? (bh >> 1) : bh;
  const int blk_pels = bw * bh;
  const int sub_pels = sub_w * sub_h;

  int out = 2;  // features[0] and features[1] are filled in after the scan.
  int sum_all = 0;
  int64_t sq_sum_all = 0;
  int n_subs = 0;
  double sq_mean_acc = 0.0;
  float dev_acc = 0.0f;

  for (int r = 0; r < bh; r += sub_h) {
    const int16_t *const row_base = data + r * stride;
    for (int c = 0; c < bw; c += sub_w) {
      int sub_sum;
      int64_t sub_sq_sum;
      // TODO(any): Write a SIMD version. Clear registers.
      // Per-sub-block sum and sum of squares of the samples.
      aom_get_blk_sse_sum(row_base + c, stride, sub_w, sub_h, &sub_sum,
                          &sub_sq_sum);
      sum_all += sub_sum;
      sq_sum_all += sub_sq_sum;

      const float sub_mean = (float)sub_sum / sub_pels;
      const float sub_dev = get_dev(sub_mean, (double)sub_sq_sum, sub_pels);
      features[out++] = sub_mean;
      features[out++] = sub_dev;
      sq_mean_acc += (double)(sub_mean * sub_mean);
      dev_acc += sub_dev;
      ++n_subs;
    }
  }

  const float whole_mean = (float)sum_all / blk_pels;
  features[0] = whole_mean;
  features[1] = get_dev(whole_mean, (double)sq_sum_all, blk_pels);

  // The sub-blocks are equal-sized and tile the block, so the whole-block mean
  // is also the mean of the sub-block means.
  features[out++] = get_dev(whole_mean, sq_mean_acc, n_subs);
  features[out++] = dev_acc / n_subs;

  return out;
}
1691*77c1e3ccSAndroid Build Coastguard Worker 
// Runs the tx-split neural network for tx_size on the luma residual of the
// transform block at (blk_row, blk_col). The coordinates are in units of 4
// samples inside plane 0's src_diff buffer, whose stride is
// block_size_wide[bsize].
//
// Returns:
// - -1 when no model exists for tx_size, or when the model output is NaN.
// - Otherwise, the network output scaled by 10000, truncated toward zero and
//   clamped to [-80000, 80000].
static int ml_predict_tx_split(MACROBLOCK *x, BLOCK_SIZE bsize, int blk_row,
                               int blk_col, TX_SIZE tx_size) {
  const NN_CONFIG *nn_config = av1_tx_split_nnconfig_map[tx_size];
  if (!nn_config) return -1;

  const int diff_stride = block_size_wide[bsize];
  const int16_t *diff =
      x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
  const int bw = tx_size_wide[tx_size];
  const int bh = tx_size_high[tx_size];

  float features[64] = { 0.0f };
  get_mean_dev_features(diff, diff_stride, bw, bh, features);

  float score = 0.0f;
  av1_nn_predict(features, nn_config, 1, &score);

  // Clamp in the float domain before converting. Converting a float whose
  // truncated value does not fit in an int, or a NaN, is undefined behavior
  // (C11 6.3.1.4), so clamping after the cast would be too late. In-range
  // values convert exactly as before.
  const float scaled_score = score * 10000;
  if (scaled_score != scaled_score) return -1;  // NaN: treat like no model.
  if (scaled_score >= 80000.0f) return 80000;
  if (scaled_score <= -80000.0f) return -80000;
  return (int)scaled_score;
}
1712*77c1e3ccSAndroid Build Coastguard Worker 
get_tx_mask(const AV1_COMP * cpi,MACROBLOCK * x,int plane,int block,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,const TXB_CTX * const txb_ctx,FAST_TX_SEARCH_MODE ftxs_mode,int64_t ref_best_rd,TX_TYPE * allowed_txk_types,int * txk_map)1713*77c1e3ccSAndroid Build Coastguard Worker static inline uint16_t get_tx_mask(
1714*77c1e3ccSAndroid Build Coastguard Worker     const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block, int blk_row,
1715*77c1e3ccSAndroid Build Coastguard Worker     int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
1716*77c1e3ccSAndroid Build Coastguard Worker     const TXB_CTX *const txb_ctx, FAST_TX_SEARCH_MODE ftxs_mode,
1717*77c1e3ccSAndroid Build Coastguard Worker     int64_t ref_best_rd, TX_TYPE *allowed_txk_types, int *txk_map) {
1718*77c1e3ccSAndroid Build Coastguard Worker   const AV1_COMMON *cm = &cpi->common;
1719*77c1e3ccSAndroid Build Coastguard Worker   MACROBLOCKD *xd = &x->e_mbd;
1720*77c1e3ccSAndroid Build Coastguard Worker   MB_MODE_INFO *mbmi = xd->mi[0];
1721*77c1e3ccSAndroid Build Coastguard Worker   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
1722*77c1e3ccSAndroid Build Coastguard Worker   const int is_inter = is_inter_block(mbmi);
1723*77c1e3ccSAndroid Build Coastguard Worker   const int fast_tx_search = ftxs_mode & FTXS_DCT_AND_1D_DCT_ONLY;
1724*77c1e3ccSAndroid Build Coastguard Worker   // if txk_allowed = TX_TYPES, >1 tx types are allowed, else, if txk_allowed <
1725*77c1e3ccSAndroid Build Coastguard Worker   // TX_TYPES, only that specific tx type is allowed.
1726*77c1e3ccSAndroid Build Coastguard Worker   TX_TYPE txk_allowed = TX_TYPES;
1727*77c1e3ccSAndroid Build Coastguard Worker 
1728*77c1e3ccSAndroid Build Coastguard Worker   const FRAME_UPDATE_TYPE update_type =
1729*77c1e3ccSAndroid Build Coastguard Worker       get_frame_update_type(&cpi->ppi->gf_group, cpi->gf_frame_index);
1730*77c1e3ccSAndroid Build Coastguard Worker   int use_actual_frame_probs = 1;
1731*77c1e3ccSAndroid Build Coastguard Worker   const int *tx_type_probs;
1732*77c1e3ccSAndroid Build Coastguard Worker #if CONFIG_FPMT_TEST
1733*77c1e3ccSAndroid Build Coastguard Worker   use_actual_frame_probs =
1734*77c1e3ccSAndroid Build Coastguard Worker       (cpi->ppi->fpmt_unit_test_cfg == PARALLEL_SIMULATION_ENCODE) ? 0 : 1;
1735*77c1e3ccSAndroid Build Coastguard Worker   if (!use_actual_frame_probs) {
1736*77c1e3ccSAndroid Build Coastguard Worker     tx_type_probs =
1737*77c1e3ccSAndroid Build Coastguard Worker         (int *)cpi->ppi->temp_frame_probs.tx_type_probs[update_type][tx_size];
1738*77c1e3ccSAndroid Build Coastguard Worker   }
1739*77c1e3ccSAndroid Build Coastguard Worker #endif
1740*77c1e3ccSAndroid Build Coastguard Worker   if (use_actual_frame_probs) {
1741*77c1e3ccSAndroid Build Coastguard Worker     tx_type_probs = cpi->ppi->frame_probs.tx_type_probs[update_type][tx_size];
1742*77c1e3ccSAndroid Build Coastguard Worker   }
1743*77c1e3ccSAndroid Build Coastguard Worker 
1744*77c1e3ccSAndroid Build Coastguard Worker   if ((!is_inter && txfm_params->use_default_intra_tx_type) ||
1745*77c1e3ccSAndroid Build Coastguard Worker       (is_inter && txfm_params->default_inter_tx_type_prob_thresh == 0)) {
1746*77c1e3ccSAndroid Build Coastguard Worker     txk_allowed =
1747*77c1e3ccSAndroid Build Coastguard Worker         get_default_tx_type(0, xd, tx_size, cpi->use_screen_content_tools);
1748*77c1e3ccSAndroid Build Coastguard Worker   } else if (is_inter &&
1749*77c1e3ccSAndroid Build Coastguard Worker              txfm_params->default_inter_tx_type_prob_thresh != INT_MAX) {
1750*77c1e3ccSAndroid Build Coastguard Worker     if (tx_type_probs[DEFAULT_INTER_TX_TYPE] >
1751*77c1e3ccSAndroid Build Coastguard Worker         txfm_params->default_inter_tx_type_prob_thresh) {
1752*77c1e3ccSAndroid Build Coastguard Worker       txk_allowed = DEFAULT_INTER_TX_TYPE;
1753*77c1e3ccSAndroid Build Coastguard Worker     } else {
1754*77c1e3ccSAndroid Build Coastguard Worker       int force_tx_type = 0;
1755*77c1e3ccSAndroid Build Coastguard Worker       int max_prob = 0;
1756*77c1e3ccSAndroid Build Coastguard Worker       const int tx_type_prob_threshold =
1757*77c1e3ccSAndroid Build Coastguard Worker           txfm_params->default_inter_tx_type_prob_thresh +
1758*77c1e3ccSAndroid Build Coastguard Worker           PROB_THRESH_OFFSET_TX_TYPE;
1759*77c1e3ccSAndroid Build Coastguard Worker       for (int i = 1; i < TX_TYPES; i++) {  // find maximum probability.
1760*77c1e3ccSAndroid Build Coastguard Worker         if (tx_type_probs[i] > max_prob) {
1761*77c1e3ccSAndroid Build Coastguard Worker           max_prob = tx_type_probs[i];
1762*77c1e3ccSAndroid Build Coastguard Worker           force_tx_type = i;
1763*77c1e3ccSAndroid Build Coastguard Worker         }
1764*77c1e3ccSAndroid Build Coastguard Worker       }
1765*77c1e3ccSAndroid Build Coastguard Worker       if (max_prob > tx_type_prob_threshold)  // force tx type with max prob.
1766*77c1e3ccSAndroid Build Coastguard Worker         txk_allowed = force_tx_type;
1767*77c1e3ccSAndroid Build Coastguard Worker       else if (x->rd_model == LOW_TXFM_RD) {
1768*77c1e3ccSAndroid Build Coastguard Worker         if (plane == 0) txk_allowed = DCT_DCT;
1769*77c1e3ccSAndroid Build Coastguard Worker       }
1770*77c1e3ccSAndroid Build Coastguard Worker     }
1771*77c1e3ccSAndroid Build Coastguard Worker   } else if (x->rd_model == LOW_TXFM_RD) {
1772*77c1e3ccSAndroid Build Coastguard Worker     if (plane == 0) txk_allowed = DCT_DCT;
1773*77c1e3ccSAndroid Build Coastguard Worker   }
1774*77c1e3ccSAndroid Build Coastguard Worker 
1775*77c1e3ccSAndroid Build Coastguard Worker   const TxSetType tx_set_type = av1_get_ext_tx_set_type(
1776*77c1e3ccSAndroid Build Coastguard Worker       tx_size, is_inter, cm->features.reduced_tx_set_used);
1777*77c1e3ccSAndroid Build Coastguard Worker 
1778*77c1e3ccSAndroid Build Coastguard Worker   TX_TYPE uv_tx_type = DCT_DCT;
1779*77c1e3ccSAndroid Build Coastguard Worker   if (plane) {
1780*77c1e3ccSAndroid Build Coastguard Worker     // tx_type of PLANE_TYPE_UV should be the same as PLANE_TYPE_Y
1781*77c1e3ccSAndroid Build Coastguard Worker     uv_tx_type = txk_allowed =
1782*77c1e3ccSAndroid Build Coastguard Worker         av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
1783*77c1e3ccSAndroid Build Coastguard Worker                         cm->features.reduced_tx_set_used);
1784*77c1e3ccSAndroid Build Coastguard Worker   }
1785*77c1e3ccSAndroid Build Coastguard Worker   PREDICTION_MODE intra_dir =
1786*77c1e3ccSAndroid Build Coastguard Worker       mbmi->filter_intra_mode_info.use_filter_intra
1787*77c1e3ccSAndroid Build Coastguard Worker           ? fimode_to_intradir[mbmi->filter_intra_mode_info.filter_intra_mode]
1788*77c1e3ccSAndroid Build Coastguard Worker           : mbmi->mode;
1789*77c1e3ccSAndroid Build Coastguard Worker   uint16_t ext_tx_used_flag =
1790*77c1e3ccSAndroid Build Coastguard Worker       cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset != 0 &&
1791*77c1e3ccSAndroid Build Coastguard Worker               tx_set_type == EXT_TX_SET_DTT4_IDTX_1DDCT
1792*77c1e3ccSAndroid Build Coastguard Worker           ? av1_reduced_intra_tx_used_flag[intra_dir]
1793*77c1e3ccSAndroid Build Coastguard Worker           : av1_ext_tx_used_flag[tx_set_type];
1794*77c1e3ccSAndroid Build Coastguard Worker 
1795*77c1e3ccSAndroid Build Coastguard Worker   if (cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset == 2)
1796*77c1e3ccSAndroid Build Coastguard Worker     ext_tx_used_flag &= av1_derived_intra_tx_used_flag[intra_dir];
1797*77c1e3ccSAndroid Build Coastguard Worker 
1798*77c1e3ccSAndroid Build Coastguard Worker   if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32 ||
1799*77c1e3ccSAndroid Build Coastguard Worker       ext_tx_used_flag == 0x0001 ||
1800*77c1e3ccSAndroid Build Coastguard Worker       (is_inter && cpi->oxcf.txfm_cfg.use_inter_dct_only) ||
1801*77c1e3ccSAndroid Build Coastguard Worker       (!is_inter && cpi->oxcf.txfm_cfg.use_intra_dct_only)) {
1802*77c1e3ccSAndroid Build Coastguard Worker     txk_allowed = DCT_DCT;
1803*77c1e3ccSAndroid Build Coastguard Worker   }
1804*77c1e3ccSAndroid Build Coastguard Worker 
1805*77c1e3ccSAndroid Build Coastguard Worker   if (cpi->oxcf.txfm_cfg.enable_flip_idtx == 0)
1806*77c1e3ccSAndroid Build Coastguard Worker     ext_tx_used_flag &= DCT_ADST_TX_MASK;
1807*77c1e3ccSAndroid Build Coastguard Worker 
1808*77c1e3ccSAndroid Build Coastguard Worker   uint16_t allowed_tx_mask = 0;  // 1: allow; 0: skip.
1809*77c1e3ccSAndroid Build Coastguard Worker   if (txk_allowed < TX_TYPES) {
1810*77c1e3ccSAndroid Build Coastguard Worker     allowed_tx_mask = 1 << txk_allowed;
1811*77c1e3ccSAndroid Build Coastguard Worker     allowed_tx_mask &= ext_tx_used_flag;
1812*77c1e3ccSAndroid Build Coastguard Worker   } else if (fast_tx_search) {
1813*77c1e3ccSAndroid Build Coastguard Worker     allowed_tx_mask = 0x0c01;  // V_DCT, H_DCT, DCT_DCT
1814*77c1e3ccSAndroid Build Coastguard Worker     allowed_tx_mask &= ext_tx_used_flag;
1815*77c1e3ccSAndroid Build Coastguard Worker   } else {
1816*77c1e3ccSAndroid Build Coastguard Worker     assert(plane == 0);
1817*77c1e3ccSAndroid Build Coastguard Worker     allowed_tx_mask = ext_tx_used_flag;
1818*77c1e3ccSAndroid Build Coastguard Worker     int num_allowed = 0;
1819*77c1e3ccSAndroid Build Coastguard Worker     int i;
1820*77c1e3ccSAndroid Build Coastguard Worker 
1821*77c1e3ccSAndroid Build Coastguard Worker     if (cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats) {
1822*77c1e3ccSAndroid Build Coastguard Worker       static const int thresh_arr[2][7] = { { 10, 15, 15, 10, 15, 15, 15 },
1823*77c1e3ccSAndroid Build Coastguard Worker                                             { 10, 17, 17, 10, 17, 17, 17 } };
1824*77c1e3ccSAndroid Build Coastguard Worker       const int thresh =
1825*77c1e3ccSAndroid Build Coastguard Worker           thresh_arr[cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats - 1]
1826*77c1e3ccSAndroid Build Coastguard Worker                     [update_type];
1827*77c1e3ccSAndroid Build Coastguard Worker       uint16_t prune = 0;
1828*77c1e3ccSAndroid Build Coastguard Worker       int max_prob = -1;
1829*77c1e3ccSAndroid Build Coastguard Worker       int max_idx = 0;
1830*77c1e3ccSAndroid Build Coastguard Worker       for (i = 0; i < TX_TYPES; i++) {
1831*77c1e3ccSAndroid Build Coastguard Worker         if (tx_type_probs[i] > max_prob && (allowed_tx_mask & (1 << i))) {
1832*77c1e3ccSAndroid Build Coastguard Worker           max_prob = tx_type_probs[i];
1833*77c1e3ccSAndroid Build Coastguard Worker           max_idx = i;
1834*77c1e3ccSAndroid Build Coastguard Worker         }
1835*77c1e3ccSAndroid Build Coastguard Worker         if (tx_type_probs[i] < thresh) prune |= (1 << i);
1836*77c1e3ccSAndroid Build Coastguard Worker       }
1837*77c1e3ccSAndroid Build Coastguard Worker       if ((prune >> max_idx) & 0x01) prune &= ~(1 << max_idx);
1838*77c1e3ccSAndroid Build Coastguard Worker       allowed_tx_mask &= (~prune);
1839*77c1e3ccSAndroid Build Coastguard Worker     }
1840*77c1e3ccSAndroid Build Coastguard Worker     for (i = 0; i < TX_TYPES; i++) {
1841*77c1e3ccSAndroid Build Coastguard Worker       if (allowed_tx_mask & (1 << i)) num_allowed++;
1842*77c1e3ccSAndroid Build Coastguard Worker     }
1843*77c1e3ccSAndroid Build Coastguard Worker     assert(num_allowed > 0);
1844*77c1e3ccSAndroid Build Coastguard Worker 
1845*77c1e3ccSAndroid Build Coastguard Worker     if (num_allowed > 2 && cpi->sf.tx_sf.tx_type_search.prune_tx_type_est_rd) {
1846*77c1e3ccSAndroid Build Coastguard Worker       int pf = prune_factors[txfm_params->prune_2d_txfm_mode];
1847*77c1e3ccSAndroid Build Coastguard Worker       int mf = mul_factors[txfm_params->prune_2d_txfm_mode];
1848*77c1e3ccSAndroid Build Coastguard Worker       if (num_allowed <= 7) {
1849*77c1e3ccSAndroid Build Coastguard Worker         const uint16_t prune =
1850*77c1e3ccSAndroid Build Coastguard Worker             prune_txk_type(cpi, x, plane, block, tx_size, blk_row, blk_col,
1851*77c1e3ccSAndroid Build Coastguard Worker                            plane_bsize, txk_map, allowed_tx_mask, pf, txb_ctx,
1852*77c1e3ccSAndroid Build Coastguard Worker                            cm->features.reduced_tx_set_used);
1853*77c1e3ccSAndroid Build Coastguard Worker         allowed_tx_mask &= (~prune);
1854*77c1e3ccSAndroid Build Coastguard Worker       } else {
1855*77c1e3ccSAndroid Build Coastguard Worker         const int num_sel = (num_allowed * mf + 50) / 100;
1856*77c1e3ccSAndroid Build Coastguard Worker         const uint16_t prune = prune_txk_type_separ(
1857*77c1e3ccSAndroid Build Coastguard Worker             cpi, x, plane, block, tx_size, blk_row, blk_col, plane_bsize,
1858*77c1e3ccSAndroid Build Coastguard Worker             txk_map, allowed_tx_mask, pf, txb_ctx,
1859*77c1e3ccSAndroid Build Coastguard Worker             cm->features.reduced_tx_set_used, ref_best_rd, num_sel);
1860*77c1e3ccSAndroid Build Coastguard Worker 
1861*77c1e3ccSAndroid Build Coastguard Worker         allowed_tx_mask &= (~prune);
1862*77c1e3ccSAndroid Build Coastguard Worker       }
1863*77c1e3ccSAndroid Build Coastguard Worker     } else {
1864*77c1e3ccSAndroid Build Coastguard Worker       assert(num_allowed > 0);
1865*77c1e3ccSAndroid Build Coastguard Worker       int allowed_tx_count =
1866*77c1e3ccSAndroid Build Coastguard Worker           (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) ? 1 : 5;
1867*77c1e3ccSAndroid Build Coastguard Worker       // !fast_tx_search && txk_end != txk_start && plane == 0
1868*77c1e3ccSAndroid Build Coastguard Worker       if (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_1 && is_inter &&
1869*77c1e3ccSAndroid Build Coastguard Worker           num_allowed > allowed_tx_count) {
1870*77c1e3ccSAndroid Build Coastguard Worker         prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
1871*77c1e3ccSAndroid Build Coastguard Worker                     txfm_params->prune_2d_txfm_mode, txk_map, &allowed_tx_mask);
1872*77c1e3ccSAndroid Build Coastguard Worker       }
1873*77c1e3ccSAndroid Build Coastguard Worker     }
1874*77c1e3ccSAndroid Build Coastguard Worker   }
1875*77c1e3ccSAndroid Build Coastguard Worker 
1876*77c1e3ccSAndroid Build Coastguard Worker   // Need to have at least one transform type allowed.
1877*77c1e3ccSAndroid Build Coastguard Worker   if (allowed_tx_mask == 0) {
1878*77c1e3ccSAndroid Build Coastguard Worker     txk_allowed = (plane ? uv_tx_type : DCT_DCT);
1879*77c1e3ccSAndroid Build Coastguard Worker     allowed_tx_mask = (1 << txk_allowed);
1880*77c1e3ccSAndroid Build Coastguard Worker   }
1881*77c1e3ccSAndroid Build Coastguard Worker 
1882*77c1e3ccSAndroid Build Coastguard Worker   assert(IMPLIES(txk_allowed < TX_TYPES, allowed_tx_mask == 1 << txk_allowed));
1883*77c1e3ccSAndroid Build Coastguard Worker   *allowed_txk_types = txk_allowed;
1884*77c1e3ccSAndroid Build Coastguard Worker   return allowed_tx_mask;
1885*77c1e3ccSAndroid Build Coastguard Worker }
1886*77c1e3ccSAndroid Build Coastguard Worker 
#if CONFIG_RD_DEBUG
// Debug-only bookkeeping: adds the coefficient rate of one transform block
// (|txb_coeff_cost|) to the running per-plane total stored in |rd_stats|.
static inline void update_txb_coeff_cost(RD_STATS *rd_stats, int plane,
                                         int txb_coeff_cost) {
  rd_stats->txb_coeff_cost[plane] =
      rd_stats->txb_coeff_cost[plane] + txb_coeff_cost;
}
#endif
1893*77c1e3ccSAndroid Build Coastguard Worker 
// Returns the rate (in the encoder's cost units) of coding the quantized
// coefficients of one transform block, as computed by av1_cost_coeffs_txb().
//
// When TXCOEFF_COST_TIMER is enabled, the wall time spent in the cost call is
// added to the common context's txcoeff_cost_timer and txcoeff_cost_count is
// incremented.
// NOTE(review): the timer branch dereferences `cpi`, which is not a parameter
// of this function; unless some file-scope `cpi` is visible here, enabling
// TXCOEFF_COST_TIMER will fail to compile. Confirm before turning it on.
static inline int cost_coeffs(MACROBLOCK *x, int plane, int block,
                              TX_SIZE tx_size, const TX_TYPE tx_type,
                              const TXB_CTX *const txb_ctx,
                              int reduced_tx_set_used) {
#if TXCOEFF_COST_TIMER
  struct aom_usec_timer cost_timer;
  aom_usec_timer_start(&cost_timer);
  const int txb_rate = av1_cost_coeffs_txb(x, plane, block, tx_size, tx_type,
                                           txb_ctx, reduced_tx_set_used);
  aom_usec_timer_mark(&cost_timer);
  AV1_COMMON *const timed_cm = (AV1_COMMON *)&cpi->common;
  timed_cm->txcoeff_cost_timer += aom_usec_timer_elapsed(&cost_timer);
  ++timed_cm->txcoeff_cost_count;
  return txb_rate;
#else
  return av1_cost_coeffs_txb(x, plane, block, tx_size, tx_type, txb_ctx,
                             reduced_tx_set_used);
#endif
}
1913*77c1e3ccSAndroid Build Coastguard Worker 
// Decides whether trellis (R-D) coefficient optimization should be skipped
// for the current transform block, based on the SATD of its transform
// coefficients.
//
// If trellis is already disabled (|skip_trellis| != 0) or the SATD check is
// disabled (|coeff_opt_satd_threshold| == UINT_MAX), |skip_trellis| is
// returned unchanged and |quant_param| is left untouched.
//
// Otherwise the measure is computed: the SATD of the block, or |coeff[0]| for
// a DC-only block. It is normalized by the transform scale and by the bit
// depth, then compared against
// coeff_opt_satd_threshold * qstep * sqrt_tx_pixels_2d[tx_size].
// |quant_param| is then re-initialized to match the decision.
//
// Returns 1 if trellis should be skipped for this block, 0 otherwise.
static int skip_trellis_opt_based_on_satd(MACROBLOCK *x,
                                          QUANT_PARAM *quant_param, int plane,
                                          int block, TX_SIZE tx_size,
                                          int quant_b_adapt, int qstep,
                                          unsigned int coeff_opt_satd_threshold,
                                          int skip_trellis, int dc_only_blk) {
  const int satd_check_disabled = (coeff_opt_satd_threshold == UINT_MAX);
  if (skip_trellis || satd_check_disabled) return skip_trellis;

  const tran_low_t *const coeffs = x->plane[plane].coeff + BLOCK_OFFSET(block);

  int satd;
  if (dc_only_blk) {
    // Only the DC coefficient is populated for DC-only blocks.
    satd = abs(coeffs[0]);
  } else {
    satd = aom_satd(coeffs, av1_get_max_eob(tx_size));
  }

  // Undo the transform-size dependent scaling, then bring high bit-depth
  // values back to the 8-bit range.
  const int scale_shift = MAX_TX_SCALE - av1_get_tx_scale(tx_size);
  satd = RIGHT_SIGNED_SHIFT(satd, scale_shift);
  satd >>= (x->e_mbd.bd - 8);

  const uint64_t satd_limit =
      (uint64_t)coeff_opt_satd_threshold * qstep * sqrt_tx_pixels_2d[tx_size];
  const int skip_block_trellis = ((uint64_t)satd > satd_limit);

  // With trellis the FP quantizer is always used; without it, the B quantizer
  // is used only when USE_B_QUANT_NO_TRELLIS is set.
  const int xform_quant_idx =
      (skip_block_trellis && USE_B_QUANT_NO_TRELLIS) ? AV1_XFORM_QUANT_B
                                                     : AV1_XFORM_QUANT_FP;
  av1_setup_quant(tx_size, !skip_block_trellis, xform_quant_idx, quant_b_adapt,
                  quant_param);

  return skip_block_trellis;
}
1945*77c1e3ccSAndroid Build Coastguard Worker 
1946*77c1e3ccSAndroid Build Coastguard Worker // Predict DC only blocks if the residual variance is below a qstep based
1947*77c1e3ccSAndroid Build Coastguard Worker // threshold.For such blocks, transform type search is bypassed.
predict_dc_only_block(MACROBLOCK * x,int plane,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,int block,int blk_row,int blk_col,RD_STATS * best_rd_stats,int64_t * block_sse,unsigned int * block_mse_q8,int64_t * per_px_mean,int * dc_only_blk)1948*77c1e3ccSAndroid Build Coastguard Worker static inline void predict_dc_only_block(
1949*77c1e3ccSAndroid Build Coastguard Worker     MACROBLOCK *x, int plane, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
1950*77c1e3ccSAndroid Build Coastguard Worker     int block, int blk_row, int blk_col, RD_STATS *best_rd_stats,
1951*77c1e3ccSAndroid Build Coastguard Worker     int64_t *block_sse, unsigned int *block_mse_q8, int64_t *per_px_mean,
1952*77c1e3ccSAndroid Build Coastguard Worker     int *dc_only_blk) {
1953*77c1e3ccSAndroid Build Coastguard Worker   MACROBLOCKD *xd = &x->e_mbd;
1954*77c1e3ccSAndroid Build Coastguard Worker   MB_MODE_INFO *mbmi = xd->mi[0];
1955*77c1e3ccSAndroid Build Coastguard Worker   const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
1956*77c1e3ccSAndroid Build Coastguard Worker   const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
1957*77c1e3ccSAndroid Build Coastguard Worker   uint64_t block_var = UINT64_MAX;
1958*77c1e3ccSAndroid Build Coastguard Worker   const int dc_qstep = x->plane[plane].dequant_QTX[0] >> 3;
1959*77c1e3ccSAndroid Build Coastguard Worker   *block_sse = pixel_diff_stats(x, plane, blk_row, blk_col, plane_bsize,
1960*77c1e3ccSAndroid Build Coastguard Worker                                 txsize_to_bsize[tx_size], block_mse_q8,
1961*77c1e3ccSAndroid Build Coastguard Worker                                 per_px_mean, &block_var);
1962*77c1e3ccSAndroid Build Coastguard Worker   assert((*block_mse_q8) != UINT_MAX);
1963*77c1e3ccSAndroid Build Coastguard Worker   uint64_t var_threshold = (uint64_t)(1.8 * qstep * qstep);
1964*77c1e3ccSAndroid Build Coastguard Worker   if (is_cur_buf_hbd(xd))
1965*77c1e3ccSAndroid Build Coastguard Worker     block_var = ROUND_POWER_OF_TWO(block_var, (xd->bd - 8) * 2);
1966*77c1e3ccSAndroid Build Coastguard Worker 
1967*77c1e3ccSAndroid Build Coastguard Worker   if (block_var >= var_threshold) return;
1968*77c1e3ccSAndroid Build Coastguard Worker   const unsigned int predict_dc_level = x->txfm_search_params.predict_dc_level;
1969*77c1e3ccSAndroid Build Coastguard Worker   assert(predict_dc_level != 0);
1970*77c1e3ccSAndroid Build Coastguard Worker 
1971*77c1e3ccSAndroid Build Coastguard Worker   // Prediction of skip block if residual mean and variance are less
1972*77c1e3ccSAndroid Build Coastguard Worker   // than qstep based threshold
1973*77c1e3ccSAndroid Build Coastguard Worker   if ((llabs(*per_px_mean) * dc_coeff_scale[tx_size]) < (dc_qstep << 12)) {
1974*77c1e3ccSAndroid Build Coastguard Worker     // If the normalized mean of residual block is less than the dc qstep and
1975*77c1e3ccSAndroid Build Coastguard Worker     // the  normalized block variance is less than ac qstep, then the block is
1976*77c1e3ccSAndroid Build Coastguard Worker     // assumed to be a skip block and its rdcost is updated accordingly.
1977*77c1e3ccSAndroid Build Coastguard Worker     best_rd_stats->skip_txfm = 1;
1978*77c1e3ccSAndroid Build Coastguard Worker 
1979*77c1e3ccSAndroid Build Coastguard Worker     x->plane[plane].eobs[block] = 0;
1980*77c1e3ccSAndroid Build Coastguard Worker 
1981*77c1e3ccSAndroid Build Coastguard Worker     if (is_cur_buf_hbd(xd))
1982*77c1e3ccSAndroid Build Coastguard Worker       *block_sse = ROUND_POWER_OF_TWO((*block_sse), (xd->bd - 8) * 2);
1983*77c1e3ccSAndroid Build Coastguard Worker 
1984*77c1e3ccSAndroid Build Coastguard Worker     best_rd_stats->dist = (*block_sse) << 4;
1985*77c1e3ccSAndroid Build Coastguard Worker     best_rd_stats->sse = best_rd_stats->dist;
1986*77c1e3ccSAndroid Build Coastguard Worker 
1987*77c1e3ccSAndroid Build Coastguard Worker     ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
1988*77c1e3ccSAndroid Build Coastguard Worker     ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
1989*77c1e3ccSAndroid Build Coastguard Worker     av1_get_entropy_contexts(plane_bsize, &xd->plane[plane], ctxa, ctxl);
1990*77c1e3ccSAndroid Build Coastguard Worker     ENTROPY_CONTEXT *ta = ctxa;
1991*77c1e3ccSAndroid Build Coastguard Worker     ENTROPY_CONTEXT *tl = ctxl;
1992*77c1e3ccSAndroid Build Coastguard Worker     const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
1993*77c1e3ccSAndroid Build Coastguard Worker     TXB_CTX txb_ctx_tmp;
1994*77c1e3ccSAndroid Build Coastguard Worker     const PLANE_TYPE plane_type = get_plane_type(plane);
1995*77c1e3ccSAndroid Build Coastguard Worker     get_txb_ctx(plane_bsize, tx_size, plane, ta, tl, &txb_ctx_tmp);
1996*77c1e3ccSAndroid Build Coastguard Worker     const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][plane_type]
1997*77c1e3ccSAndroid Build Coastguard Worker                                   .txb_skip_cost[txb_ctx_tmp.txb_skip_ctx][1];
1998*77c1e3ccSAndroid Build Coastguard Worker     best_rd_stats->rate = zero_blk_rate;
1999*77c1e3ccSAndroid Build Coastguard Worker 
2000*77c1e3ccSAndroid Build Coastguard Worker     best_rd_stats->rdcost =
2001*77c1e3ccSAndroid Build Coastguard Worker         RDCOST(x->rdmult, best_rd_stats->rate, best_rd_stats->sse);
2002*77c1e3ccSAndroid Build Coastguard Worker 
2003*77c1e3ccSAndroid Build Coastguard Worker     x->plane[plane].txb_entropy_ctx[block] = 0;
2004*77c1e3ccSAndroid Build Coastguard Worker   } else if (predict_dc_level > 1) {
2005*77c1e3ccSAndroid Build Coastguard Worker     // Predict DC only blocks based on residual variance.
2006*77c1e3ccSAndroid Build Coastguard Worker     // For chroma plane, this prediction is disabled for intra blocks.
2007*77c1e3ccSAndroid Build Coastguard Worker     if ((plane == 0) || (plane > 0 && is_inter_block(mbmi))) *dc_only_blk = 1;
2008*77c1e3ccSAndroid Build Coastguard Worker   }
2009*77c1e3ccSAndroid Build Coastguard Worker }
2010*77c1e3ccSAndroid Build Coastguard Worker 
2011*77c1e3ccSAndroid Build Coastguard Worker // Search for the best transform type for a given transform block.
2012*77c1e3ccSAndroid Build Coastguard Worker // This function can be used for both inter and intra, both luma and chroma.
search_tx_type(const AV1_COMP * cpi,MACROBLOCK * x,int plane,int block,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,const TXB_CTX * const txb_ctx,FAST_TX_SEARCH_MODE ftxs_mode,int skip_trellis,int64_t ref_best_rd,RD_STATS * best_rd_stats)2013*77c1e3ccSAndroid Build Coastguard Worker static void search_tx_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
2014*77c1e3ccSAndroid Build Coastguard Worker                            int block, int blk_row, int blk_col,
2015*77c1e3ccSAndroid Build Coastguard Worker                            BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
2016*77c1e3ccSAndroid Build Coastguard Worker                            const TXB_CTX *const txb_ctx,
2017*77c1e3ccSAndroid Build Coastguard Worker                            FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis,
2018*77c1e3ccSAndroid Build Coastguard Worker                            int64_t ref_best_rd, RD_STATS *best_rd_stats) {
2019*77c1e3ccSAndroid Build Coastguard Worker   const AV1_COMMON *cm = &cpi->common;
2020*77c1e3ccSAndroid Build Coastguard Worker   MACROBLOCKD *xd = &x->e_mbd;
2021*77c1e3ccSAndroid Build Coastguard Worker   MB_MODE_INFO *mbmi = xd->mi[0];
2022*77c1e3ccSAndroid Build Coastguard Worker   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
2023*77c1e3ccSAndroid Build Coastguard Worker   int64_t best_rd = INT64_MAX;
2024*77c1e3ccSAndroid Build Coastguard Worker   uint16_t best_eob = 0;
2025*77c1e3ccSAndroid Build Coastguard Worker   TX_TYPE best_tx_type = DCT_DCT;
2026*77c1e3ccSAndroid Build Coastguard Worker   int rate_cost = 0;
2027*77c1e3ccSAndroid Build Coastguard Worker   struct macroblock_plane *const p = &x->plane[plane];
2028*77c1e3ccSAndroid Build Coastguard Worker   tran_low_t *orig_dqcoeff = p->dqcoeff;
2029*77c1e3ccSAndroid Build Coastguard Worker   tran_low_t *best_dqcoeff = x->dqcoeff_buf;
2030*77c1e3ccSAndroid Build Coastguard Worker   const int tx_type_map_idx =
2031*77c1e3ccSAndroid Build Coastguard Worker       plane ? 0 : blk_row * xd->tx_type_map_stride + blk_col;
2032*77c1e3ccSAndroid Build Coastguard Worker   av1_invalid_rd_stats(best_rd_stats);
2033*77c1e3ccSAndroid Build Coastguard Worker 
2034*77c1e3ccSAndroid Build Coastguard Worker   skip_trellis |= !is_trellis_used(cpi->optimize_seg_arr[xd->mi[0]->segment_id],
2035*77c1e3ccSAndroid Build Coastguard Worker                                    DRY_RUN_NORMAL);
2036*77c1e3ccSAndroid Build Coastguard Worker 
2037*77c1e3ccSAndroid Build Coastguard Worker   uint8_t best_txb_ctx = 0;
2038*77c1e3ccSAndroid Build Coastguard Worker   // txk_allowed = TX_TYPES: >1 tx types are allowed
2039*77c1e3ccSAndroid Build Coastguard Worker   // txk_allowed < TX_TYPES: only that specific tx type is allowed.
2040*77c1e3ccSAndroid Build Coastguard Worker   TX_TYPE txk_allowed = TX_TYPES;
2041*77c1e3ccSAndroid Build Coastguard Worker   int txk_map[TX_TYPES] = {
2042*77c1e3ccSAndroid Build Coastguard Worker     0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
2043*77c1e3ccSAndroid Build Coastguard Worker   };
2044*77c1e3ccSAndroid Build Coastguard Worker   const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
2045*77c1e3ccSAndroid Build Coastguard Worker   const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
2046*77c1e3ccSAndroid Build Coastguard Worker 
2047*77c1e3ccSAndroid Build Coastguard Worker   const uint8_t txw = tx_size_wide[tx_size];
2048*77c1e3ccSAndroid Build Coastguard Worker   const uint8_t txh = tx_size_high[tx_size];
2049*77c1e3ccSAndroid Build Coastguard Worker   int64_t block_sse;
2050*77c1e3ccSAndroid Build Coastguard Worker   unsigned int block_mse_q8;
2051*77c1e3ccSAndroid Build Coastguard Worker   int dc_only_blk = 0;
2052*77c1e3ccSAndroid Build Coastguard Worker   const bool predict_dc_block =
2053*77c1e3ccSAndroid Build Coastguard Worker       txfm_params->predict_dc_level >= 1 && txw != 64 && txh != 64;
2054*77c1e3ccSAndroid Build Coastguard Worker   int64_t per_px_mean = INT64_MAX;
2055*77c1e3ccSAndroid Build Coastguard Worker   if (predict_dc_block) {
2056*77c1e3ccSAndroid Build Coastguard Worker     predict_dc_only_block(x, plane, plane_bsize, tx_size, block, blk_row,
2057*77c1e3ccSAndroid Build Coastguard Worker                           blk_col, best_rd_stats, &block_sse, &block_mse_q8,
2058*77c1e3ccSAndroid Build Coastguard Worker                           &per_px_mean, &dc_only_blk);
2059*77c1e3ccSAndroid Build Coastguard Worker     if (best_rd_stats->skip_txfm == 1) {
2060*77c1e3ccSAndroid Build Coastguard Worker       const TX_TYPE tx_type = DCT_DCT;
2061*77c1e3ccSAndroid Build Coastguard Worker       if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
2062*77c1e3ccSAndroid Build Coastguard Worker       return;
2063*77c1e3ccSAndroid Build Coastguard Worker     }
2064*77c1e3ccSAndroid Build Coastguard Worker   } else {
2065*77c1e3ccSAndroid Build Coastguard Worker     block_sse = av1_pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize,
2066*77c1e3ccSAndroid Build Coastguard Worker                                     txsize_to_bsize[tx_size], &block_mse_q8);
2067*77c1e3ccSAndroid Build Coastguard Worker     assert(block_mse_q8 != UINT_MAX);
2068*77c1e3ccSAndroid Build Coastguard Worker   }
2069*77c1e3ccSAndroid Build Coastguard Worker 
2070*77c1e3ccSAndroid Build Coastguard Worker   // Bit mask to indicate which transform types are allowed in the RD search.
2071*77c1e3ccSAndroid Build Coastguard Worker   uint16_t tx_mask;
2072*77c1e3ccSAndroid Build Coastguard Worker 
2073*77c1e3ccSAndroid Build Coastguard Worker   // Use DCT_DCT transform for DC only block.
2074*77c1e3ccSAndroid Build Coastguard Worker   if (dc_only_blk || cpi->sf.rt_sf.dct_only_palette_nonrd == 1)
2075*77c1e3ccSAndroid Build Coastguard Worker     tx_mask = 1 << DCT_DCT;
2076*77c1e3ccSAndroid Build Coastguard Worker   else
2077*77c1e3ccSAndroid Build Coastguard Worker     tx_mask = get_tx_mask(cpi, x, plane, block, blk_row, blk_col, plane_bsize,
2078*77c1e3ccSAndroid Build Coastguard Worker                           tx_size, txb_ctx, ftxs_mode, ref_best_rd,
2079*77c1e3ccSAndroid Build Coastguard Worker                           &txk_allowed, txk_map);
2080*77c1e3ccSAndroid Build Coastguard Worker   const uint16_t allowed_tx_mask = tx_mask;
2081*77c1e3ccSAndroid Build Coastguard Worker 
2082*77c1e3ccSAndroid Build Coastguard Worker   if (is_cur_buf_hbd(xd)) {
2083*77c1e3ccSAndroid Build Coastguard Worker     block_sse = ROUND_POWER_OF_TWO(block_sse, (xd->bd - 8) * 2);
2084*77c1e3ccSAndroid Build Coastguard Worker     block_mse_q8 = ROUND_POWER_OF_TWO(block_mse_q8, (xd->bd - 8) * 2);
2085*77c1e3ccSAndroid Build Coastguard Worker   }
2086*77c1e3ccSAndroid Build Coastguard Worker   block_sse *= 16;
2087*77c1e3ccSAndroid Build Coastguard Worker   // Use mse / qstep^2 based threshold logic to take decision of R-D
2088*77c1e3ccSAndroid Build Coastguard Worker   // optimization of coeffs. For smaller residuals, coeff optimization
2089*77c1e3ccSAndroid Build Coastguard Worker   // would be helpful. For larger residuals, R-D optimization may not be
2090*77c1e3ccSAndroid Build Coastguard Worker   // effective.
2091*77c1e3ccSAndroid Build Coastguard Worker   // TODO(any): Experiment with variance and mean based thresholds
2092*77c1e3ccSAndroid Build Coastguard Worker   const int perform_block_coeff_opt =
2093*77c1e3ccSAndroid Build Coastguard Worker       ((uint64_t)block_mse_q8 <=
2094*77c1e3ccSAndroid Build Coastguard Worker        (uint64_t)txfm_params->coeff_opt_thresholds[0] * qstep * qstep);
2095*77c1e3ccSAndroid Build Coastguard Worker   skip_trellis |= !perform_block_coeff_opt;
2096*77c1e3ccSAndroid Build Coastguard Worker 
2097*77c1e3ccSAndroid Build Coastguard Worker   // Flag to indicate if distortion should be calculated in transform domain or
2098*77c1e3ccSAndroid Build Coastguard Worker   // not during iterating through transform type candidates.
2099*77c1e3ccSAndroid Build Coastguard Worker   // Transform domain distortion is accurate for higher residuals.
2100*77c1e3ccSAndroid Build Coastguard Worker   // TODO(any): Experiment with variance and mean based thresholds
2101*77c1e3ccSAndroid Build Coastguard Worker   int use_transform_domain_distortion =
2102*77c1e3ccSAndroid Build Coastguard Worker       (txfm_params->use_transform_domain_distortion > 0) &&
2103*77c1e3ccSAndroid Build Coastguard Worker       (block_mse_q8 >= txfm_params->tx_domain_dist_threshold) &&
2104*77c1e3ccSAndroid Build Coastguard Worker       // Any 64-pt transforms only preserves half the coefficients.
2105*77c1e3ccSAndroid Build Coastguard Worker       // Therefore transform domain distortion is not valid for these
2106*77c1e3ccSAndroid Build Coastguard Worker       // transform sizes.
2107*77c1e3ccSAndroid Build Coastguard Worker       (txsize_sqr_up_map[tx_size] != TX_64X64) &&
2108*77c1e3ccSAndroid Build Coastguard Worker       // Use pixel domain distortion for DC only blocks
2109*77c1e3ccSAndroid Build Coastguard Worker       !dc_only_blk;
2110*77c1e3ccSAndroid Build Coastguard Worker   // Flag to indicate if an extra calculation of distortion in the pixel domain
2111*77c1e3ccSAndroid Build Coastguard Worker   // should be performed at the end, after the best transform type has been
2112*77c1e3ccSAndroid Build Coastguard Worker   // decided.
2113*77c1e3ccSAndroid Build Coastguard Worker   int calc_pixel_domain_distortion_final =
2114*77c1e3ccSAndroid Build Coastguard Worker       txfm_params->use_transform_domain_distortion == 1 &&
2115*77c1e3ccSAndroid Build Coastguard Worker       use_transform_domain_distortion && x->rd_model != LOW_TXFM_RD;
2116*77c1e3ccSAndroid Build Coastguard Worker   if (calc_pixel_domain_distortion_final &&
2117*77c1e3ccSAndroid Build Coastguard Worker       (txk_allowed < TX_TYPES || allowed_tx_mask == 0x0001))
2118*77c1e3ccSAndroid Build Coastguard Worker     calc_pixel_domain_distortion_final = use_transform_domain_distortion = 0;
2119*77c1e3ccSAndroid Build Coastguard Worker 
2120*77c1e3ccSAndroid Build Coastguard Worker   const uint16_t *eobs_ptr = x->plane[plane].eobs;
2121*77c1e3ccSAndroid Build Coastguard Worker 
2122*77c1e3ccSAndroid Build Coastguard Worker   TxfmParam txfm_param;
2123*77c1e3ccSAndroid Build Coastguard Worker   QUANT_PARAM quant_param;
2124*77c1e3ccSAndroid Build Coastguard Worker   int skip_trellis_based_on_satd[TX_TYPES] = { 0 };
2125*77c1e3ccSAndroid Build Coastguard Worker   av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
2126*77c1e3ccSAndroid Build Coastguard Worker   av1_setup_quant(tx_size, !skip_trellis,
2127*77c1e3ccSAndroid Build Coastguard Worker                   skip_trellis ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
2128*77c1e3ccSAndroid Build Coastguard Worker                                                          : AV1_XFORM_QUANT_FP)
2129*77c1e3ccSAndroid Build Coastguard Worker                                : AV1_XFORM_QUANT_FP,
2130*77c1e3ccSAndroid Build Coastguard Worker                   cpi->oxcf.q_cfg.quant_b_adapt, &quant_param);
2131*77c1e3ccSAndroid Build Coastguard Worker 
2132*77c1e3ccSAndroid Build Coastguard Worker   // Iterate through all transform type candidates.
2133*77c1e3ccSAndroid Build Coastguard Worker   for (int idx = 0; idx < TX_TYPES; ++idx) {
2134*77c1e3ccSAndroid Build Coastguard Worker     const TX_TYPE tx_type = (TX_TYPE)txk_map[idx];
2135*77c1e3ccSAndroid Build Coastguard Worker     if (tx_type == TX_TYPE_INVALID || !check_bit_mask(allowed_tx_mask, tx_type))
2136*77c1e3ccSAndroid Build Coastguard Worker       continue;
2137*77c1e3ccSAndroid Build Coastguard Worker     txfm_param.tx_type = tx_type;
2138*77c1e3ccSAndroid Build Coastguard Worker     if (av1_use_qmatrix(&cm->quant_params, xd, mbmi->segment_id)) {
2139*77c1e3ccSAndroid Build Coastguard Worker       av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
2140*77c1e3ccSAndroid Build Coastguard Worker                         &quant_param);
2141*77c1e3ccSAndroid Build Coastguard Worker     }
2142*77c1e3ccSAndroid Build Coastguard Worker     if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
2143*77c1e3ccSAndroid Build Coastguard Worker     RD_STATS this_rd_stats;
2144*77c1e3ccSAndroid Build Coastguard Worker     av1_invalid_rd_stats(&this_rd_stats);
2145*77c1e3ccSAndroid Build Coastguard Worker 
2146*77c1e3ccSAndroid Build Coastguard Worker     if (!dc_only_blk)
2147*77c1e3ccSAndroid Build Coastguard Worker       av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param);
2148*77c1e3ccSAndroid Build Coastguard Worker     else
2149*77c1e3ccSAndroid Build Coastguard Worker       av1_xform_dc_only(x, plane, block, &txfm_param, per_px_mean);
2150*77c1e3ccSAndroid Build Coastguard Worker 
2151*77c1e3ccSAndroid Build Coastguard Worker     skip_trellis_based_on_satd[tx_type] = skip_trellis_opt_based_on_satd(
2152*77c1e3ccSAndroid Build Coastguard Worker         x, &quant_param, plane, block, tx_size, cpi->oxcf.q_cfg.quant_b_adapt,
2153*77c1e3ccSAndroid Build Coastguard Worker         qstep, txfm_params->coeff_opt_thresholds[1], skip_trellis, dc_only_blk);
2154*77c1e3ccSAndroid Build Coastguard Worker 
2155*77c1e3ccSAndroid Build Coastguard Worker     av1_quant(x, plane, block, &txfm_param, &quant_param);
2156*77c1e3ccSAndroid Build Coastguard Worker 
2157*77c1e3ccSAndroid Build Coastguard Worker     // Calculate rate cost of quantized coefficients.
2158*77c1e3ccSAndroid Build Coastguard Worker     if (quant_param.use_optimize_b) {
2159*77c1e3ccSAndroid Build Coastguard Worker       // TODO(aomedia:3209): update Trellis quantization to take into account
2160*77c1e3ccSAndroid Build Coastguard Worker       // quantization matrices.
2161*77c1e3ccSAndroid Build Coastguard Worker       av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx,
2162*77c1e3ccSAndroid Build Coastguard Worker                      &rate_cost);
2163*77c1e3ccSAndroid Build Coastguard Worker     } else {
2164*77c1e3ccSAndroid Build Coastguard Worker       rate_cost = cost_coeffs(x, plane, block, tx_size, tx_type, txb_ctx,
2165*77c1e3ccSAndroid Build Coastguard Worker                               cm->features.reduced_tx_set_used);
2166*77c1e3ccSAndroid Build Coastguard Worker     }
2167*77c1e3ccSAndroid Build Coastguard Worker 
2168*77c1e3ccSAndroid Build Coastguard Worker     // If rd cost based on coeff rate alone is already more than best_rd,
2169*77c1e3ccSAndroid Build Coastguard Worker     // terminate early.
2170*77c1e3ccSAndroid Build Coastguard Worker     if (RDCOST(x->rdmult, rate_cost, 0) > best_rd) continue;
2171*77c1e3ccSAndroid Build Coastguard Worker 
2172*77c1e3ccSAndroid Build Coastguard Worker     // Calculate distortion.
2173*77c1e3ccSAndroid Build Coastguard Worker     if (eobs_ptr[block] == 0) {
2174*77c1e3ccSAndroid Build Coastguard Worker       // When eob is 0, pixel domain distortion is more efficient and accurate.
2175*77c1e3ccSAndroid Build Coastguard Worker       this_rd_stats.dist = this_rd_stats.sse = block_sse;
2176*77c1e3ccSAndroid Build Coastguard Worker     } else if (dc_only_blk) {
2177*77c1e3ccSAndroid Build Coastguard Worker       this_rd_stats.sse = block_sse;
2178*77c1e3ccSAndroid Build Coastguard Worker       this_rd_stats.dist = dist_block_px_domain(
2179*77c1e3ccSAndroid Build Coastguard Worker           cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2180*77c1e3ccSAndroid Build Coastguard Worker     } else if (use_transform_domain_distortion) {
2181*77c1e3ccSAndroid Build Coastguard Worker       const SCAN_ORDER *const scan_order =
2182*77c1e3ccSAndroid Build Coastguard Worker           get_scan(txfm_param.tx_size, txfm_param.tx_type);
2183*77c1e3ccSAndroid Build Coastguard Worker       dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
2184*77c1e3ccSAndroid Build Coastguard Worker                            scan_order->scan, &this_rd_stats.dist,
2185*77c1e3ccSAndroid Build Coastguard Worker                            &this_rd_stats.sse);
2186*77c1e3ccSAndroid Build Coastguard Worker     } else {
2187*77c1e3ccSAndroid Build Coastguard Worker       int64_t sse_diff = INT64_MAX;
2188*77c1e3ccSAndroid Build Coastguard Worker       // high_energy threshold assumes that every pixel within a txfm block
2189*77c1e3ccSAndroid Build Coastguard Worker       // has a residue energy of at least 25% of the maximum, i.e. 128 * 128
2190*77c1e3ccSAndroid Build Coastguard Worker       // for 8 bit.
2191*77c1e3ccSAndroid Build Coastguard Worker       const int64_t high_energy_thresh =
2192*77c1e3ccSAndroid Build Coastguard Worker           ((int64_t)128 * 128 * tx_size_2d[tx_size]);
2193*77c1e3ccSAndroid Build Coastguard Worker       const int is_high_energy = (block_sse >= high_energy_thresh);
2194*77c1e3ccSAndroid Build Coastguard Worker       if (tx_size == TX_64X64 || is_high_energy) {
2195*77c1e3ccSAndroid Build Coastguard Worker         // Because 3 out 4 quadrants of transform coefficients are forced to
2196*77c1e3ccSAndroid Build Coastguard Worker         // zero, the inverse transform has a tendency to overflow. sse_diff
2197*77c1e3ccSAndroid Build Coastguard Worker         // is effectively the energy of those 3 quadrants, here we use it
2198*77c1e3ccSAndroid Build Coastguard Worker         // to decide if we should do pixel domain distortion. If the energy
2199*77c1e3ccSAndroid Build Coastguard Worker         // is mostly in first quadrant, then it is unlikely that we have
2200*77c1e3ccSAndroid Build Coastguard Worker         // overflow issue in inverse transform.
2201*77c1e3ccSAndroid Build Coastguard Worker         const SCAN_ORDER *const scan_order =
2202*77c1e3ccSAndroid Build Coastguard Worker             get_scan(txfm_param.tx_size, txfm_param.tx_type);
2203*77c1e3ccSAndroid Build Coastguard Worker         dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
2204*77c1e3ccSAndroid Build Coastguard Worker                              scan_order->scan, &this_rd_stats.dist,
2205*77c1e3ccSAndroid Build Coastguard Worker                              &this_rd_stats.sse);
2206*77c1e3ccSAndroid Build Coastguard Worker         sse_diff = block_sse - this_rd_stats.sse;
2207*77c1e3ccSAndroid Build Coastguard Worker       }
2208*77c1e3ccSAndroid Build Coastguard Worker       if (tx_size != TX_64X64 || !is_high_energy ||
2209*77c1e3ccSAndroid Build Coastguard Worker           (sse_diff * 2) < this_rd_stats.sse) {
2210*77c1e3ccSAndroid Build Coastguard Worker         const int64_t tx_domain_dist = this_rd_stats.dist;
2211*77c1e3ccSAndroid Build Coastguard Worker         this_rd_stats.dist = dist_block_px_domain(
2212*77c1e3ccSAndroid Build Coastguard Worker             cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2213*77c1e3ccSAndroid Build Coastguard Worker         // For high energy blocks, occasionally, the pixel domain distortion
2214*77c1e3ccSAndroid Build Coastguard Worker         // can be artificially low due to clamping at reconstruction stage
2215*77c1e3ccSAndroid Build Coastguard Worker         // even when inverse transform output is hugely different from the
2216*77c1e3ccSAndroid Build Coastguard Worker         // actual residue.
2217*77c1e3ccSAndroid Build Coastguard Worker         if (is_high_energy && this_rd_stats.dist < tx_domain_dist)
2218*77c1e3ccSAndroid Build Coastguard Worker           this_rd_stats.dist = tx_domain_dist;
2219*77c1e3ccSAndroid Build Coastguard Worker       } else {
2220*77c1e3ccSAndroid Build Coastguard Worker         assert(sse_diff < INT64_MAX);
2221*77c1e3ccSAndroid Build Coastguard Worker         this_rd_stats.dist += sse_diff;
2222*77c1e3ccSAndroid Build Coastguard Worker       }
2223*77c1e3ccSAndroid Build Coastguard Worker       this_rd_stats.sse = block_sse;
2224*77c1e3ccSAndroid Build Coastguard Worker     }
2225*77c1e3ccSAndroid Build Coastguard Worker 
2226*77c1e3ccSAndroid Build Coastguard Worker     this_rd_stats.rate = rate_cost;
2227*77c1e3ccSAndroid Build Coastguard Worker 
2228*77c1e3ccSAndroid Build Coastguard Worker     const int64_t rd =
2229*77c1e3ccSAndroid Build Coastguard Worker         RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
2230*77c1e3ccSAndroid Build Coastguard Worker 
2231*77c1e3ccSAndroid Build Coastguard Worker     if (rd < best_rd) {
2232*77c1e3ccSAndroid Build Coastguard Worker       best_rd = rd;
2233*77c1e3ccSAndroid Build Coastguard Worker       *best_rd_stats = this_rd_stats;
2234*77c1e3ccSAndroid Build Coastguard Worker       best_tx_type = tx_type;
2235*77c1e3ccSAndroid Build Coastguard Worker       best_txb_ctx = x->plane[plane].txb_entropy_ctx[block];
2236*77c1e3ccSAndroid Build Coastguard Worker       best_eob = x->plane[plane].eobs[block];
2237*77c1e3ccSAndroid Build Coastguard Worker       // Swap dqcoeff buffers
2238*77c1e3ccSAndroid Build Coastguard Worker       tran_low_t *const tmp_dqcoeff = best_dqcoeff;
2239*77c1e3ccSAndroid Build Coastguard Worker       best_dqcoeff = p->dqcoeff;
2240*77c1e3ccSAndroid Build Coastguard Worker       p->dqcoeff = tmp_dqcoeff;
2241*77c1e3ccSAndroid Build Coastguard Worker     }
2242*77c1e3ccSAndroid Build Coastguard Worker 
2243*77c1e3ccSAndroid Build Coastguard Worker #if CONFIG_COLLECT_RD_STATS == 1
2244*77c1e3ccSAndroid Build Coastguard Worker     if (plane == 0) {
2245*77c1e3ccSAndroid Build Coastguard Worker       PrintTransformUnitStats(cpi, x, &this_rd_stats, blk_row, blk_col,
2246*77c1e3ccSAndroid Build Coastguard Worker                               plane_bsize, tx_size, tx_type, rd);
2247*77c1e3ccSAndroid Build Coastguard Worker     }
2248*77c1e3ccSAndroid Build Coastguard Worker #endif  // CONFIG_COLLECT_RD_STATS == 1
2249*77c1e3ccSAndroid Build Coastguard Worker 
2250*77c1e3ccSAndroid Build Coastguard Worker #if COLLECT_TX_SIZE_DATA
2251*77c1e3ccSAndroid Build Coastguard Worker     // Generate small sample to restrict output size.
2252*77c1e3ccSAndroid Build Coastguard Worker     static unsigned int seed = 21743;
2253*77c1e3ccSAndroid Build Coastguard Worker     if (lcg_rand16(&seed) % 200 == 0) {
2254*77c1e3ccSAndroid Build Coastguard Worker       FILE *fp = NULL;
2255*77c1e3ccSAndroid Build Coastguard Worker 
2256*77c1e3ccSAndroid Build Coastguard Worker       if (within_border) {
2257*77c1e3ccSAndroid Build Coastguard Worker         fp = fopen(av1_tx_size_data_output_file, "a");
2258*77c1e3ccSAndroid Build Coastguard Worker       }
2259*77c1e3ccSAndroid Build Coastguard Worker 
2260*77c1e3ccSAndroid Build Coastguard Worker       if (fp) {
2261*77c1e3ccSAndroid Build Coastguard Worker         // Transform info and RD
2262*77c1e3ccSAndroid Build Coastguard Worker         const int txb_w = tx_size_wide[tx_size];
2263*77c1e3ccSAndroid Build Coastguard Worker         const int txb_h = tx_size_high[tx_size];
2264*77c1e3ccSAndroid Build Coastguard Worker 
2265*77c1e3ccSAndroid Build Coastguard Worker         // Residue signal.
2266*77c1e3ccSAndroid Build Coastguard Worker         const int diff_stride = block_size_wide[plane_bsize];
2267*77c1e3ccSAndroid Build Coastguard Worker         struct macroblock_plane *const p = &x->plane[plane];
2268*77c1e3ccSAndroid Build Coastguard Worker         const int16_t *src_diff =
2269*77c1e3ccSAndroid Build Coastguard Worker             &p->src_diff[(blk_row * diff_stride + blk_col) * 4];
2270*77c1e3ccSAndroid Build Coastguard Worker 
2271*77c1e3ccSAndroid Build Coastguard Worker         for (int r = 0; r < txb_h; ++r) {
2272*77c1e3ccSAndroid Build Coastguard Worker           for (int c = 0; c < txb_w; ++c) {
2273*77c1e3ccSAndroid Build Coastguard Worker             fprintf(fp, "%d,", src_diff[c]);
2274*77c1e3ccSAndroid Build Coastguard Worker           }
2275*77c1e3ccSAndroid Build Coastguard Worker           src_diff += diff_stride;
2276*77c1e3ccSAndroid Build Coastguard Worker         }
2277*77c1e3ccSAndroid Build Coastguard Worker 
2278*77c1e3ccSAndroid Build Coastguard Worker         fprintf(fp, "%d,%d,%d,%" PRId64, txb_w, txb_h, tx_type, rd);
2279*77c1e3ccSAndroid Build Coastguard Worker         fprintf(fp, "\n");
2280*77c1e3ccSAndroid Build Coastguard Worker         fclose(fp);
2281*77c1e3ccSAndroid Build Coastguard Worker       }
2282*77c1e3ccSAndroid Build Coastguard Worker     }
2283*77c1e3ccSAndroid Build Coastguard Worker #endif  // COLLECT_TX_SIZE_DATA
2284*77c1e3ccSAndroid Build Coastguard Worker 
2285*77c1e3ccSAndroid Build Coastguard Worker     // If the current best RD cost is much worse than the reference RD cost,
2286*77c1e3ccSAndroid Build Coastguard Worker     // terminate early.
2287*77c1e3ccSAndroid Build Coastguard Worker     if (cpi->sf.tx_sf.adaptive_txb_search_level) {
2288*77c1e3ccSAndroid Build Coastguard Worker       if ((best_rd - (best_rd >> cpi->sf.tx_sf.adaptive_txb_search_level)) >
2289*77c1e3ccSAndroid Build Coastguard Worker           ref_best_rd) {
2290*77c1e3ccSAndroid Build Coastguard Worker         break;
2291*77c1e3ccSAndroid Build Coastguard Worker       }
2292*77c1e3ccSAndroid Build Coastguard Worker     }
2293*77c1e3ccSAndroid Build Coastguard Worker 
2294*77c1e3ccSAndroid Build Coastguard Worker     // Terminate transform type search if the block has been quantized to
2295*77c1e3ccSAndroid Build Coastguard Worker     // all zero.
2296*77c1e3ccSAndroid Build Coastguard Worker     if (cpi->sf.tx_sf.tx_type_search.skip_tx_search && !best_eob) break;
2297*77c1e3ccSAndroid Build Coastguard Worker   }
2298*77c1e3ccSAndroid Build Coastguard Worker 
2299*77c1e3ccSAndroid Build Coastguard Worker   assert(best_rd != INT64_MAX);
2300*77c1e3ccSAndroid Build Coastguard Worker 
2301*77c1e3ccSAndroid Build Coastguard Worker   best_rd_stats->skip_txfm = best_eob == 0;
2302*77c1e3ccSAndroid Build Coastguard Worker   if (plane == 0) update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
2303*77c1e3ccSAndroid Build Coastguard Worker   x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx;
2304*77c1e3ccSAndroid Build Coastguard Worker   x->plane[plane].eobs[block] = best_eob;
2305*77c1e3ccSAndroid Build Coastguard Worker   skip_trellis = skip_trellis_based_on_satd[best_tx_type];
2306*77c1e3ccSAndroid Build Coastguard Worker 
2307*77c1e3ccSAndroid Build Coastguard Worker   // Point dqcoeff to the quantized coefficients corresponding to the best
2308*77c1e3ccSAndroid Build Coastguard Worker   // transform type, then we can skip transform and quantization, e.g. in the
2309*77c1e3ccSAndroid Build Coastguard Worker   // final pixel domain distortion calculation and recon_intra().
2310*77c1e3ccSAndroid Build Coastguard Worker   p->dqcoeff = best_dqcoeff;
2311*77c1e3ccSAndroid Build Coastguard Worker 
2312*77c1e3ccSAndroid Build Coastguard Worker   if (calc_pixel_domain_distortion_final && best_eob) {
2313*77c1e3ccSAndroid Build Coastguard Worker     best_rd_stats->dist = dist_block_px_domain(
2314*77c1e3ccSAndroid Build Coastguard Worker         cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2315*77c1e3ccSAndroid Build Coastguard Worker     best_rd_stats->sse = block_sse;
2316*77c1e3ccSAndroid Build Coastguard Worker   }
2317*77c1e3ccSAndroid Build Coastguard Worker 
2318*77c1e3ccSAndroid Build Coastguard Worker   // Intra mode needs decoded pixels such that the next transform block
2319*77c1e3ccSAndroid Build Coastguard Worker   // can use them for prediction.
2320*77c1e3ccSAndroid Build Coastguard Worker   recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
2321*77c1e3ccSAndroid Build Coastguard Worker               txb_ctx, skip_trellis, best_tx_type, 0, &rate_cost, best_eob);
2322*77c1e3ccSAndroid Build Coastguard Worker   p->dqcoeff = orig_dqcoeff;
2323*77c1e3ccSAndroid Build Coastguard Worker }
2324*77c1e3ccSAndroid Build Coastguard Worker 
2325*77c1e3ccSAndroid Build Coastguard Worker // Pick transform type for a luma transform block of tx_size. Note this function
2326*77c1e3ccSAndroid Build Coastguard Worker // is used only for inter-predicted blocks.
tx_type_rd(const AV1_COMP * cpi,MACROBLOCK * x,TX_SIZE tx_size,int blk_row,int blk_col,int block,int plane_bsize,TXB_CTX * txb_ctx,RD_STATS * rd_stats,FAST_TX_SEARCH_MODE ftxs_mode,int64_t ref_rdcost)2327*77c1e3ccSAndroid Build Coastguard Worker static inline void tx_type_rd(const AV1_COMP *cpi, MACROBLOCK *x,
2328*77c1e3ccSAndroid Build Coastguard Worker                               TX_SIZE tx_size, int blk_row, int blk_col,
2329*77c1e3ccSAndroid Build Coastguard Worker                               int block, int plane_bsize, TXB_CTX *txb_ctx,
2330*77c1e3ccSAndroid Build Coastguard Worker                               RD_STATS *rd_stats, FAST_TX_SEARCH_MODE ftxs_mode,
2331*77c1e3ccSAndroid Build Coastguard Worker                               int64_t ref_rdcost) {
2332*77c1e3ccSAndroid Build Coastguard Worker   assert(is_inter_block(x->e_mbd.mi[0]));
2333*77c1e3ccSAndroid Build Coastguard Worker   RD_STATS this_rd_stats;
2334*77c1e3ccSAndroid Build Coastguard Worker   const int skip_trellis = 0;
2335*77c1e3ccSAndroid Build Coastguard Worker   search_tx_type(cpi, x, 0, block, blk_row, blk_col, plane_bsize, tx_size,
2336*77c1e3ccSAndroid Build Coastguard Worker                  txb_ctx, ftxs_mode, skip_trellis, ref_rdcost, &this_rd_stats);
2337*77c1e3ccSAndroid Build Coastguard Worker 
2338*77c1e3ccSAndroid Build Coastguard Worker   av1_merge_rd_stats(rd_stats, &this_rd_stats);
2339*77c1e3ccSAndroid Build Coastguard Worker }
2340*77c1e3ccSAndroid Build Coastguard Worker 
// Evaluates coding the luma transform block at (blk_row, blk_col) as a single
// tx_size transform (no further partitioning).
//
// The block's RD statistics are written to rd_stats. A summary (RD cost,
// entropy context, chosen transform type) is written to no_split.
//
// Unless the segment is lossless, the block is forced to "skip" when either:
//  - the type search already reported skip_txfm == 1, or
//  - coding it costs at least as much as signalling an all-zero block.
// When forced to skip, its eob is cleared and its transform type is reset to
// DCT_DCT.
static inline void try_tx_block_no_split(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize,
    const ENTROPY_CONTEXT *ta, const ENTROPY_CONTEXT *tl,
    int txfm_partition_ctx, RD_STATS *rd_stats, int64_t ref_best_rd,
    FAST_TX_SEARCH_MODE ftxs_mode, TxCandidateInfo *no_split) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  struct macroblock_plane *const luma = &x->plane[0];
  const int mi_wide = mi_size_wide[plane_bsize];

  // Entropy context of this transform block from its above/left contexts.
  TXB_CTX txb_ctx;
  get_txb_ctx(plane_bsize, tx_size, 0, ta + blk_col, tl + blk_row, &txb_ctx);

  // Rate of signalling the whole block as all-zero.
  const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
  const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y]
                                .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
  rd_stats->zero_rate = zero_blk_rate;

  mbmi->inter_tx_size[av1_get_txb_size_index(plane_bsize, blk_row, blk_col)] =
      tx_size;
  tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
             rd_stats, ftxs_mode, ref_best_rd);
  assert(rd_stats->rate < INT_MAX);

  // Skipping is never chosen for lossless segments.
  int pick_skip_txfm = 0;
  if (!xd->lossless[mbmi->segment_id]) {
    if (rd_stats->skip_txfm == 1) {
      pick_skip_txfm = 1;
    } else {
      const int64_t coded_rd =
          RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
      const int64_t zero_rd = RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse);
      pick_skip_txfm = coded_rd >= zero_rd;
    }
  }

  if (pick_skip_txfm) {
#if CONFIG_RD_DEBUG
    update_txb_coeff_cost(rd_stats, 0, zero_blk_rate - rd_stats->rate);
#endif  // CONFIG_RD_DEBUG
    rd_stats->rate = zero_blk_rate;
    rd_stats->dist = rd_stats->sse;
    luma->eobs[block] = 0;
    update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
  }
  rd_stats->skip_txfm = pick_skip_txfm;
  set_blk_skip(x->txfm_search_info.blk_skip, 0, blk_row * mi_wide + blk_col,
               pick_skip_txfm);

  // When a further split would be allowed, charge the partition flag cost at
  // index 0 (the split path in try_tx_block_split() charges index 1).
  const int can_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH;
  if (can_split) {
    rd_stats->rate += x->mode_costs.txfm_partition_cost[txfm_partition_ctx][0];
  }

  no_split->rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
  no_split->txb_entropy_ctx = luma->txb_entropy_ctx[block];
  const int type_idx = blk_row * xd->tx_type_map_stride + blk_col;
  no_split->tx_type = xd->tx_type_map[type_idx];
}
2391*77c1e3ccSAndroid Build Coastguard Worker 
// Evaluates splitting the luma transform block at (blk_row, blk_col) into
// sub_tx_size_map[tx_size] sub-blocks, recursing into select_tx_block() for
// each sub-block that lies within max_block_high/max_block_wide.
//
// Accumulated statistics, starting from the rate of the partition flag at
// index 1, are written to split_rd_stats. split_rd_stats->rdcost is set to
// INT64_MAX as soon as either:
//  - a sub-block search reports an invalid cost, or
//  - the running cost exceeds ref_best_rd.
static inline void try_tx_block_split(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
    int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd,
    FAST_TX_SEARCH_MODE ftxs_mode, RD_STATS *split_rd_stats) {
  assert(tx_size < TX_SIZES_ALL);
  MACROBLOCKD *const xd = &x->e_mbd;
  const int rows_limit = max_block_high(xd, plane_bsize, 0);
  const int cols_limit = max_block_wide(xd, plane_bsize, 0);

  // Geometry of the parent block and of the sub-blocks it splits into.
  const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
  const int span_w = tx_size_wide_unit[tx_size];
  const int span_h = tx_size_high_unit[tx_size];
  const int step_w = tx_size_wide_unit[sub_txs];
  const int step_h = tx_size_high_unit[sub_txs];
  const int nblks = (span_h / step_h) * (span_w / step_w);
  assert(nblks > 0);

  av1_init_rd_stats(split_rd_stats);
  split_rd_stats->rate =
      x->mode_costs.txfm_partition_cost[txfm_partition_ctx][1];
  // Each sub-block receives an equal share of the unsplit RD cost as its
  // prev_level_rd.
  const int64_t sub_prev_rd = no_split_rd / nblks;

  int sub_idx = 0;
  for (int dr = 0; dr < span_h; dr += step_h) {
    const int sub_row = blk_row + dr;
    if (sub_row >= rows_limit) break;
    for (int dc = 0; dc < span_w; dc += step_w) {
      assert(sub_idx < 4);
      ++sub_idx;
      const int sub_col = blk_col + dc;
      // Sub-blocks at or beyond cols_limit are neither searched nor counted
      // in the coefficient block index.
      if (sub_col >= cols_limit) continue;

      RD_STATS sub_stats;
      int sub_valid = 1;
      select_tx_block(cpi, x, sub_row, sub_col, block, sub_txs, depth + 1,
                      plane_bsize, ta, tl, tx_above, tx_left, &sub_stats,
                      sub_prev_rd, ref_best_rd - split_rd_stats->rdcost,
                      &sub_valid, ftxs_mode);
      if (sub_valid) {
        av1_merge_rd_stats(split_rd_stats, &sub_stats);
        split_rd_stats->rdcost =
            RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist);
      }
      if (!sub_valid || split_rd_stats->rdcost > ref_best_rd) {
        // Abandon the split candidate.
        split_rd_stats->rdcost = INT64_MAX;
        return;
      }
      block += step_w * step_h;
    }
  }
}
2444*77c1e3ccSAndroid Build Coastguard Worker 
// Returns the population variance of num samples, given their mean and the
// sum of their squares: E[x^2] - mean^2, evaluated in single precision.
// The result is not clamped, so it can be negative when the inputs are
// inconsistent or rounding dominates.
static float get_var(float mean, double x2_sum, int num) {
  return (float)(x2_sum / num) - mean * mean;
}
2450*77c1e3ccSAndroid Build Coastguard Worker 
// Computes spread statistics of a residual block and its sub-blocks.
//
// The block is halved along every dimension that is not shorter than the
// other one. Square blocks therefore yield 4 sub-blocks; rectangular blocks
// yield 2 (halves of the longer side). The means and variances of the
// sub-blocks, plus those of the whole block, form (blk_idx + 1) samples.
// From these samples:
//   *dev_of_mean = deviation of the means (via get_dev()),
//   *var_of_vars = variance of the variances.
// Both outputs are written only when more than one sub-block was visited;
// otherwise the caller's values are left untouched.
//
// data: residual samples with the given stride (in int16_t elements).
// bw, bh: block width/height in samples.
static inline void get_blk_var_dev(const int16_t *data, int stride, int bw,
                                   int bh, float *dev_of_mean,
                                   float *var_of_vars) {
  const int16_t *const data_ptr = &data[0];
  const int subh = (bh >= bw) ? (bh >> 1) : bh;
  const int subw = (bw >= bh) ? (bw >> 1) : bw;
  const int num = bw * bh;
  const int sub_num = subw * subh;
  int total_x_sum = 0;
  int64_t total_x2_sum = 0;
  int blk_idx = 0;
  float var_sum = 0.0f;
  float mean_sum = 0.0f;
  double var2_sum = 0.0;
  double mean2_sum = 0.0;

  for (int row = 0; row < bh; row += subh) {
    for (int col = 0; col < bw; col += subw) {
      int x_sum;
      int64_t x2_sum;
      aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
                          &x_sum, &x2_sum);
      // Whole-block totals are accumulated from the sub-block sums.
      total_x_sum += x_sum;
      total_x2_sum += x2_sum;

      const float mean = (float)x_sum / sub_num;
      const float var = get_var(mean, (double)x2_sum, sub_num);
      mean_sum += mean;
      mean2_sum += (double)(mean * mean);
      var_sum += var;
      var2_sum += var * var;
      blk_idx++;
    }
  }

  // The whole block's own mean and variance count as one more sample.
  const float lvl0_mean = (float)total_x_sum / num;
  const float block_var = get_var(lvl0_mean, (double)total_x2_sum, num);
  mean_sum += lvl0_mean;
  mean2_sum += (double)(lvl0_mean * lvl0_mean);
  var_sum += block_var;
  var2_sum += block_var * block_var;

  if (blk_idx > 1) {
    // Number of samples: every sub-block plus the whole block. This is 5 for
    // square blocks but only 3 for rectangular ones. The mean of the means
    // must use the same divisor as the mean of the variances; a hard-coded 5
    // is wrong for rectangular blocks.
    const int num_stats = blk_idx + 1;
    // Deviation of means.
    const float av_mean = mean_sum / num_stats;
    *dev_of_mean = get_dev(av_mean, mean2_sum, num_stats);
    // Variance of variances.
    const float mean_var = var_sum / num_stats;
    *var_of_vars = get_var(mean_var, var2_sum, num_stats);
  }
}
2502*77c1e3ccSAndroid Build Coastguard Worker 
// Heuristically prunes the split and/or no-split transform partition
// candidates for the luma transform block of size tx_size at
// (blk_row, blk_col) within a block of size bsize.
//
// Residual statistics from get_blk_var_dev() are compared against thresholds
// derived from the plane's DC/AC dequantizers (each shifted right by 3):
//  - *try_split is cleared when the deviation of means and the scaled
//    variance of variances are both within the threshold;
//  - *try_no_split is cleared when both exceed their scaled thresholds.
// The flags are only ever cleared here, never set.
// pruning_level selects the threshold scales and must be in [0, 3].
static void prune_tx_split_no_split(MACROBLOCK *x, BLOCK_SIZE bsize,
                                    int blk_row, int blk_col, TX_SIZE tx_size,
                                    int *try_no_split, int *try_split,
                                    int pruning_level) {
  // Threshold scales, indexed by pruning_level.
  static const int no_split_thresh_scales[4] = { 0, 24, 8, 8 };
  static const int split_thresh_scales[4] = { 0, 24, 10, 8 };
  assert(pruning_level >= 0 && pruning_level < 4);

  const int diff_stride = block_size_wide[bsize];
  // The factor of 4 converts blk_row/blk_col into sample offsets within
  // src_diff (presumably they are in 4x4 units).
  const int16_t *diff =
      x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
  const int bw = tx_size_wide[tx_size];
  const int bh = tx_size_high[tx_size];
  // These stay 0 if get_blk_var_dev() does not write them.
  float dev_of_means = 0.0f;
  float var_of_vars = 0.0f;

  // Compute the deviation of the means and the variance of the pixel
  // variances of the block and its sub-blocks.
  get_blk_var_dev(diff, diff_stride, bw, bh, &dev_of_means, &var_of_vars);
  const int dc_q = x->plane[0].dequant_QTX[0] >> 3;
  const int ac_q = x->plane[0].dequant_QTX[1] >> 3;
  const int no_split_thresh_scale = no_split_thresh_scales[pruning_level];
  const int split_thresh_scale = split_thresh_scales[pruning_level];

  // Sub-block statistics are close relative to the quantizer step: do not
  // try the split.
  if ((dev_of_means <= dc_q) &&
      (split_thresh_scale * var_of_vars <= ac_q * ac_q)) {
    *try_split = 0;
  }
  // Sub-block statistics differ strongly: do not try the unsplit transform.
  if ((dev_of_means > no_split_thresh_scale * dc_q) &&
      (var_of_vars > no_split_thresh_scale * ac_q * ac_q)) {
    *try_no_split = 0;
  }
}
2534*77c1e3ccSAndroid Build Coastguard Worker 
2535*77c1e3ccSAndroid Build Coastguard Worker // Search for the best transform partition(recursive)/type for a given
2536*77c1e3ccSAndroid Build Coastguard Worker // inter-predicted luma block. The obtained transform selection will be saved
2537*77c1e3ccSAndroid Build Coastguard Worker // in xd->mi[0], the corresponding RD stats will be saved in rd_stats.
select_tx_block(const AV1_COMP * cpi,MACROBLOCK * x,int blk_row,int blk_col,int block,TX_SIZE tx_size,int depth,BLOCK_SIZE plane_bsize,ENTROPY_CONTEXT * ta,ENTROPY_CONTEXT * tl,TXFM_CONTEXT * tx_above,TXFM_CONTEXT * tx_left,RD_STATS * rd_stats,int64_t prev_level_rd,int64_t ref_best_rd,int * is_cost_valid,FAST_TX_SEARCH_MODE ftxs_mode)2538*77c1e3ccSAndroid Build Coastguard Worker static inline void select_tx_block(
2539*77c1e3ccSAndroid Build Coastguard Worker     const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2540*77c1e3ccSAndroid Build Coastguard Worker     TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
2541*77c1e3ccSAndroid Build Coastguard Worker     ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
2542*77c1e3ccSAndroid Build Coastguard Worker     RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
2543*77c1e3ccSAndroid Build Coastguard Worker     int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode) {
2544*77c1e3ccSAndroid Build Coastguard Worker   assert(tx_size < TX_SIZES_ALL);
2545*77c1e3ccSAndroid Build Coastguard Worker   av1_init_rd_stats(rd_stats);
2546*77c1e3ccSAndroid Build Coastguard Worker   if (ref_best_rd < 0) {
2547*77c1e3ccSAndroid Build Coastguard Worker     *is_cost_valid = 0;
2548*77c1e3ccSAndroid Build Coastguard Worker     return;
2549*77c1e3ccSAndroid Build Coastguard Worker   }
2550*77c1e3ccSAndroid Build Coastguard Worker 
2551*77c1e3ccSAndroid Build Coastguard Worker   MACROBLOCKD *const xd = &x->e_mbd;
2552*77c1e3ccSAndroid Build Coastguard Worker   assert(blk_row < max_block_high(xd, plane_bsize, 0) &&
2553*77c1e3ccSAndroid Build Coastguard Worker          blk_col < max_block_wide(xd, plane_bsize, 0));
2554*77c1e3ccSAndroid Build Coastguard Worker   MB_MODE_INFO *const mbmi = xd->mi[0];
2555*77c1e3ccSAndroid Build Coastguard Worker   const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
2556*77c1e3ccSAndroid Build Coastguard Worker                                          mbmi->bsize, tx_size);
2557*77c1e3ccSAndroid Build Coastguard Worker   struct macroblock_plane *const p = &x->plane[0];
2558*77c1e3ccSAndroid Build Coastguard Worker 
2559*77c1e3ccSAndroid Build Coastguard Worker   int try_no_split = (cpi->oxcf.txfm_cfg.enable_tx64 ||
2560*77c1e3ccSAndroid Build Coastguard Worker                       txsize_sqr_up_map[tx_size] != TX_64X64) &&
2561*77c1e3ccSAndroid Build Coastguard Worker                      (cpi->oxcf.txfm_cfg.enable_rect_tx ||
2562*77c1e3ccSAndroid Build Coastguard Worker                       tx_size_wide[tx_size] == tx_size_high[tx_size]);
2563*77c1e3ccSAndroid Build Coastguard Worker   int try_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH;
2564*77c1e3ccSAndroid Build Coastguard Worker   TxCandidateInfo no_split = { INT64_MAX, 0, TX_TYPES };
2565*77c1e3ccSAndroid Build Coastguard Worker 
2566*77c1e3ccSAndroid Build Coastguard Worker   // Prune tx_split and no-split based on sub-block properties.
2567*77c1e3ccSAndroid Build Coastguard Worker   if (tx_size != TX_4X4 && try_split == 1 && try_no_split == 1 &&
2568*77c1e3ccSAndroid Build Coastguard Worker       cpi->sf.tx_sf.prune_tx_size_level > 0) {
2569*77c1e3ccSAndroid Build Coastguard Worker     prune_tx_split_no_split(x, plane_bsize, blk_row, blk_col, tx_size,
2570*77c1e3ccSAndroid Build Coastguard Worker                             &try_no_split, &try_split,
2571*77c1e3ccSAndroid Build Coastguard Worker                             cpi->sf.tx_sf.prune_tx_size_level);
2572*77c1e3ccSAndroid Build Coastguard Worker   }
2573*77c1e3ccSAndroid Build Coastguard Worker 
2574*77c1e3ccSAndroid Build Coastguard Worker   if (cpi->sf.rt_sf.skip_tx_no_split_var_based_partition) {
2575*77c1e3ccSAndroid Build Coastguard Worker     if (x->try_merge_partition && try_split && p->eobs[block]) try_no_split = 0;
2576*77c1e3ccSAndroid Build Coastguard Worker   }
2577*77c1e3ccSAndroid Build Coastguard Worker 
2578*77c1e3ccSAndroid Build Coastguard Worker   // Try using current block as a single transform block without split.
2579*77c1e3ccSAndroid Build Coastguard Worker   if (try_no_split) {
2580*77c1e3ccSAndroid Build Coastguard Worker     try_tx_block_no_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
2581*77c1e3ccSAndroid Build Coastguard Worker                           plane_bsize, ta, tl, ctx, rd_stats, ref_best_rd,
2582*77c1e3ccSAndroid Build Coastguard Worker                           ftxs_mode, &no_split);
2583*77c1e3ccSAndroid Build Coastguard Worker 
2584*77c1e3ccSAndroid Build Coastguard Worker     // Speed features for early termination.
2585*77c1e3ccSAndroid Build Coastguard Worker     const int search_level = cpi->sf.tx_sf.adaptive_txb_search_level;
2586*77c1e3ccSAndroid Build Coastguard Worker     if (search_level) {
2587*77c1e3ccSAndroid Build Coastguard Worker       if ((no_split.rd - (no_split.rd >> (1 + search_level))) > ref_best_rd) {
2588*77c1e3ccSAndroid Build Coastguard Worker         *is_cost_valid = 0;
2589*77c1e3ccSAndroid Build Coastguard Worker         return;
2590*77c1e3ccSAndroid Build Coastguard Worker       }
2591*77c1e3ccSAndroid Build Coastguard Worker       if (no_split.rd - (no_split.rd >> (2 + search_level)) > prev_level_rd) {
2592*77c1e3ccSAndroid Build Coastguard Worker         try_split = 0;
2593*77c1e3ccSAndroid Build Coastguard Worker       }
2594*77c1e3ccSAndroid Build Coastguard Worker     }
2595*77c1e3ccSAndroid Build Coastguard Worker     if (cpi->sf.tx_sf.txb_split_cap) {
2596*77c1e3ccSAndroid Build Coastguard Worker       if (p->eobs[block] == 0) try_split = 0;
2597*77c1e3ccSAndroid Build Coastguard Worker     }
2598*77c1e3ccSAndroid Build Coastguard Worker   }
2599*77c1e3ccSAndroid Build Coastguard Worker 
2600*77c1e3ccSAndroid Build Coastguard Worker   // ML based speed feature to skip searching for split transform blocks.
2601*77c1e3ccSAndroid Build Coastguard Worker   if (x->e_mbd.bd == 8 && try_split &&
2602*77c1e3ccSAndroid Build Coastguard Worker       !(ref_best_rd == INT64_MAX && no_split.rd == INT64_MAX)) {
2603*77c1e3ccSAndroid Build Coastguard Worker     const int threshold = cpi->sf.tx_sf.tx_type_search.ml_tx_split_thresh;
2604*77c1e3ccSAndroid Build Coastguard Worker     if (threshold >= 0) {
2605*77c1e3ccSAndroid Build Coastguard Worker       const int split_score =
2606*77c1e3ccSAndroid Build Coastguard Worker           ml_predict_tx_split(x, plane_bsize, blk_row, blk_col, tx_size);
2607*77c1e3ccSAndroid Build Coastguard Worker       if (split_score < -threshold) try_split = 0;
2608*77c1e3ccSAndroid Build Coastguard Worker     }
2609*77c1e3ccSAndroid Build Coastguard Worker   }
2610*77c1e3ccSAndroid Build Coastguard Worker 
2611*77c1e3ccSAndroid Build Coastguard Worker   RD_STATS split_rd_stats;
2612*77c1e3ccSAndroid Build Coastguard Worker   split_rd_stats.rdcost = INT64_MAX;
2613*77c1e3ccSAndroid Build Coastguard Worker   // Try splitting current block into smaller transform blocks.
2614*77c1e3ccSAndroid Build Coastguard Worker   if (try_split) {
2615*77c1e3ccSAndroid Build Coastguard Worker     try_tx_block_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
2616*77c1e3ccSAndroid Build Coastguard Worker                        plane_bsize, ta, tl, tx_above, tx_left, ctx, no_split.rd,
2617*77c1e3ccSAndroid Build Coastguard Worker                        AOMMIN(no_split.rd, ref_best_rd), ftxs_mode,
2618*77c1e3ccSAndroid Build Coastguard Worker                        &split_rd_stats);
2619*77c1e3ccSAndroid Build Coastguard Worker   }
2620*77c1e3ccSAndroid Build Coastguard Worker 
2621*77c1e3ccSAndroid Build Coastguard Worker   if (no_split.rd < split_rd_stats.rdcost) {
2622*77c1e3ccSAndroid Build Coastguard Worker     ENTROPY_CONTEXT *pta = ta + blk_col;
2623*77c1e3ccSAndroid Build Coastguard Worker     ENTROPY_CONTEXT *ptl = tl + blk_row;
2624*77c1e3ccSAndroid Build Coastguard Worker     p->txb_entropy_ctx[block] = no_split.txb_entropy_ctx;
2625*77c1e3ccSAndroid Build Coastguard Worker     av1_set_txb_context(x, 0, block, tx_size, pta, ptl);
2626*77c1e3ccSAndroid Build Coastguard Worker     txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
2627*77c1e3ccSAndroid Build Coastguard Worker                           tx_size);
2628*77c1e3ccSAndroid Build Coastguard Worker     for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
2629*77c1e3ccSAndroid Build Coastguard Worker       for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
2630*77c1e3ccSAndroid Build Coastguard Worker         const int index =
2631*77c1e3ccSAndroid Build Coastguard Worker             av1_get_txb_size_index(plane_bsize, blk_row + idy, blk_col + idx);
2632*77c1e3ccSAndroid Build Coastguard Worker         mbmi->inter_tx_size[index] = tx_size;
2633*77c1e3ccSAndroid Build Coastguard Worker       }
2634*77c1e3ccSAndroid Build Coastguard Worker     }
2635*77c1e3ccSAndroid Build Coastguard Worker     mbmi->tx_size = tx_size;
2636*77c1e3ccSAndroid Build Coastguard Worker     update_txk_array(xd, blk_row, blk_col, tx_size, no_split.tx_type);
2637*77c1e3ccSAndroid Build Coastguard Worker     const int bw = mi_size_wide[plane_bsize];
2638*77c1e3ccSAndroid Build Coastguard Worker     set_blk_skip(x->txfm_search_info.blk_skip, 0, blk_row * bw + blk_col,
2639*77c1e3ccSAndroid Build Coastguard Worker                  rd_stats->skip_txfm);
2640*77c1e3ccSAndroid Build Coastguard Worker   } else {
2641*77c1e3ccSAndroid Build Coastguard Worker     *rd_stats = split_rd_stats;
2642*77c1e3ccSAndroid Build Coastguard Worker     if (split_rd_stats.rdcost == INT64_MAX) *is_cost_valid = 0;
2643*77c1e3ccSAndroid Build Coastguard Worker   }
2644*77c1e3ccSAndroid Build Coastguard Worker }
2645*77c1e3ccSAndroid Build Coastguard Worker 
// Evaluates the luma RD cost of bs using the largest transform size that the
// tx mode allows (tx_size_from_tx_mode()).
//
// The size is clamped according to the encoder config:
//  - if 64-point transforms are disabled, it is clamped to at most 32 points
//    per side;
//  - if rectangular transforms are disabled, it is clamped to the largest
//    square size that fits.
// The chosen size is stored in xd->mi[0]->tx_size, and the coefficient RD
// statistics are written to rd_stats by av1_txfm_rd_in_plane().
static inline void choose_largest_tx_size(const AV1_COMP *const cpi,
                                          MACROBLOCK *x, RD_STATS *rd_stats,
                                          int64_t ref_best_rd, BLOCK_SIZE bs) {
  // Clamp tables indexed by the unconstrained TX_SIZE.
  // No 64-point transforms, rectangular allowed.
  static const TX_SIZE tx_size_max_32[TX_SIZES_ALL] = {
    TX_4X4,    // 4x4 transform
    TX_8X8,    // 8x8 transform
    TX_16X16,  // 16x16 transform
    TX_32X32,  // 32x32 transform
    TX_32X32,  // 64x64 transform
    TX_4X8,    // 4x8 transform
    TX_8X4,    // 8x4 transform
    TX_8X16,   // 8x16 transform
    TX_16X8,   // 16x8 transform
    TX_16X32,  // 16x32 transform
    TX_32X16,  // 32x16 transform
    TX_32X32,  // 32x64 transform
    TX_32X32,  // 64x32 transform
    TX_4X16,   // 4x16 transform
    TX_16X4,   // 16x4 transform
    TX_8X32,   // 8x32 transform
    TX_32X8,   // 32x8 transform
    TX_16X32,  // 16x64 transform
    TX_32X16,  // 64x16 transform
  };
  // 64-point transforms allowed, square only.
  static const TX_SIZE tx_size_max_square[TX_SIZES_ALL] = {
    TX_4X4,    // 4x4 transform
    TX_8X8,    // 8x8 transform
    TX_16X16,  // 16x16 transform
    TX_32X32,  // 32x32 transform
    TX_64X64,  // 64x64 transform
    TX_4X4,    // 4x8 transform
    TX_4X4,    // 8x4 transform
    TX_8X8,    // 8x16 transform
    TX_8X8,    // 16x8 transform
    TX_16X16,  // 16x32 transform
    TX_16X16,  // 32x16 transform
    TX_32X32,  // 32x64 transform
    TX_32X32,  // 64x32 transform
    TX_4X4,    // 4x16 transform
    TX_4X4,    // 16x4 transform
    TX_8X8,    // 8x32 transform
    TX_8X8,    // 32x8 transform
    TX_16X16,  // 16x64 transform
    TX_16X16,  // 64x16 transform
  };
  // No 64-point transforms, square only.
  static const TX_SIZE tx_size_max_32_square[TX_SIZES_ALL] = {
    TX_4X4,    // 4x4 transform
    TX_8X8,    // 8x8 transform
    TX_16X16,  // 16x16 transform
    TX_32X32,  // 32x32 transform
    TX_32X32,  // 64x64 transform
    TX_4X4,    // 4x8 transform
    TX_4X4,    // 8x4 transform
    TX_8X8,    // 8x16 transform
    TX_8X8,    // 16x8 transform
    TX_16X16,  // 16x32 transform
    TX_16X16,  // 32x16 transform
    TX_32X32,  // 32x64 transform
    TX_32X32,  // 64x32 transform
    TX_4X4,    // 4x16 transform
    TX_4X4,    // 16x4 transform
    TX_8X8,    // 8x32 transform
    TX_8X8,    // 32x8 transform
    TX_16X16,  // 16x64 transform
    TX_16X16,  // 64x16 transform
  };

  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  mbmi->tx_size =
      tx_size_from_tx_mode(bs, x->txfm_search_params.tx_mode_search_type);

  // Pick the clamp table matching the enabled transform kinds; with both
  // 64-point and rectangular transforms enabled no clamping is needed.
  const int allow_tx64 = cpi->oxcf.txfm_cfg.enable_tx64;
  const int allow_rect = cpi->oxcf.txfm_cfg.enable_rect_tx;
  const TX_SIZE *clamp_table = NULL;
  if (!allow_tx64) {
    clamp_table = allow_rect ? tx_size_max_32 : tx_size_max_32_square;
  } else if (!allow_rect) {
    clamp_table = tx_size_max_square;
  }
  if (clamp_table != NULL) mbmi->tx_size = clamp_table[mbmi->tx_size];

  // Header-only cost of the cheaper skip flag value. Only inter blocks can be
  // coded as skip, so intra blocks always use the no-skip cost.
  const int skip_ctx = av1_get_skip_txfm_context(xd);
  int64_t header_rd =
      RDCOST(x->rdmult, x->mode_costs.skip_txfm_cost[skip_ctx][0], 0);
  if (is_inter_block(mbmi)) {
    const int64_t skip_header_rd =
        RDCOST(x->rdmult, x->mode_costs.skip_txfm_cost[skip_ctx][1], 0);
    header_rd = AOMMIN(header_rd, skip_header_rd);
  }

  av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, header_rd, AOM_PLANE_Y,
                       bs, mbmi->tx_size, FTXS_NONE, /*skip_trellis=*/0);
}
2741*77c1e3ccSAndroid Build Coastguard Worker 
// Evaluates the luma RD cost of bs using the smallest transform size, TX_4X4.
// The size is stored in xd->mi[0]->tx_size, and the RD statistics are written
// to rd_stats by av1_txfm_rd_in_plane().
static inline void choose_smallest_tx_size(const AV1_COMP *const cpi,
                                           MACROBLOCK *x, RD_STATS *rd_stats,
                                           int64_t ref_best_rd, BLOCK_SIZE bs) {
  MB_MODE_INFO *const mbmi = x->e_mbd.mi[0];
  const TX_SIZE smallest_tx_size = TX_4X4;
  mbmi->tx_size = smallest_tx_size;

  // TODO(any) : Pass this_rd based on skip/non-skip cost
  // (0 is passed for now).
  av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, 0, AOM_PLANE_Y, bs,
                       smallest_tx_size, FTXS_NONE, /*skip_trellis=*/0);
}
2754*77c1e3ccSAndroid Build Coastguard Worker 
2755*77c1e3ccSAndroid Build Coastguard Worker #if !CONFIG_REALTIME_ONLY
ml_predict_intra_tx_depth_prune(MACROBLOCK * x,int blk_row,int blk_col,BLOCK_SIZE bsize,TX_SIZE tx_size)2756*77c1e3ccSAndroid Build Coastguard Worker static void ml_predict_intra_tx_depth_prune(MACROBLOCK *x, int blk_row,
2757*77c1e3ccSAndroid Build Coastguard Worker                                             int blk_col, BLOCK_SIZE bsize,
2758*77c1e3ccSAndroid Build Coastguard Worker                                             TX_SIZE tx_size) {
2759*77c1e3ccSAndroid Build Coastguard Worker   const MACROBLOCKD *const xd = &x->e_mbd;
2760*77c1e3ccSAndroid Build Coastguard Worker   const MB_MODE_INFO *const mbmi = xd->mi[0];
2761*77c1e3ccSAndroid Build Coastguard Worker 
2762*77c1e3ccSAndroid Build Coastguard Worker   // Disable the pruning logic using NN model for the following cases:
2763*77c1e3ccSAndroid Build Coastguard Worker   // 1) Lossless coding as only 4x4 transform is evaluated in this case
2764*77c1e3ccSAndroid Build Coastguard Worker   // 2) When transform and current block sizes do not match as the features are
2765*77c1e3ccSAndroid Build Coastguard Worker   // obtained over the current block
2766*77c1e3ccSAndroid Build Coastguard Worker   // 3) When operating bit-depth is not 8-bit as the input features are not
2767*77c1e3ccSAndroid Build Coastguard Worker   // scaled according to bit-depth.
2768*77c1e3ccSAndroid Build Coastguard Worker   if (xd->lossless[mbmi->segment_id] || txsize_to_bsize[tx_size] != bsize ||
2769*77c1e3ccSAndroid Build Coastguard Worker       xd->bd != 8)
2770*77c1e3ccSAndroid Build Coastguard Worker     return;
2771*77c1e3ccSAndroid Build Coastguard Worker 
2772*77c1e3ccSAndroid Build Coastguard Worker   // Currently NN model based pruning is supported only when largest transform
2773*77c1e3ccSAndroid Build Coastguard Worker   // size is 8x8
2774*77c1e3ccSAndroid Build Coastguard Worker   if (tx_size != TX_8X8) return;
2775*77c1e3ccSAndroid Build Coastguard Worker 
2776*77c1e3ccSAndroid Build Coastguard Worker   // Neural network model is a sequential neural net and was trained using SGD
2777*77c1e3ccSAndroid Build Coastguard Worker   // optimizer. The model can be further improved in terms of speed/quality by
2778*77c1e3ccSAndroid Build Coastguard Worker   // considering the following experiments:
2779*77c1e3ccSAndroid Build Coastguard Worker   // 1) Generate ML model by training with balanced data for different learning
2780*77c1e3ccSAndroid Build Coastguard Worker   // rates and optimizers.
2781*77c1e3ccSAndroid Build Coastguard Worker   // 2) Experiment with ML model by adding features related to the statistics of
2782*77c1e3ccSAndroid Build Coastguard Worker   // top and left pixels to capture the accuracy of reconstructed neighbouring
2783*77c1e3ccSAndroid Build Coastguard Worker   // pixels for 4x4 blocks numbered 1, 2, 3 in 8x8 block, source variance of 4x4
2784*77c1e3ccSAndroid Build Coastguard Worker   // sub-blocks, etc.
2785*77c1e3ccSAndroid Build Coastguard Worker   // 3) Generate ML models for transform blocks other than 8x8.
2786*77c1e3ccSAndroid Build Coastguard Worker   const NN_CONFIG *const nn_config = &av1_intra_tx_split_nnconfig_8x8;
2787*77c1e3ccSAndroid Build Coastguard Worker   const float *const intra_tx_prune_thresh = av1_intra_tx_prune_nn_thresh_8x8;
2788*77c1e3ccSAndroid Build Coastguard Worker 
2789*77c1e3ccSAndroid Build Coastguard Worker   float features[NUM_INTRA_TX_SPLIT_FEATURES] = { 0.0f };
2790*77c1e3ccSAndroid Build Coastguard Worker   const int diff_stride = block_size_wide[bsize];
2791*77c1e3ccSAndroid Build Coastguard Worker 
2792*77c1e3ccSAndroid Build Coastguard Worker   const int16_t *diff = x->plane[0].src_diff + MI_SIZE * blk_row * diff_stride +
2793*77c1e3ccSAndroid Build Coastguard Worker                         MI_SIZE * blk_col;
2794*77c1e3ccSAndroid Build Coastguard Worker   const int bw = tx_size_wide[tx_size];
2795*77c1e3ccSAndroid Build Coastguard Worker   const int bh = tx_size_high[tx_size];
2796*77c1e3ccSAndroid Build Coastguard Worker 
2797*77c1e3ccSAndroid Build Coastguard Worker   int feature_idx = get_mean_dev_features(diff, diff_stride, bw, bh, features);
2798*77c1e3ccSAndroid Build Coastguard Worker 
2799*77c1e3ccSAndroid Build Coastguard Worker   features[feature_idx++] = log1pf((float)x->source_variance);
2800*77c1e3ccSAndroid Build Coastguard Worker 
2801*77c1e3ccSAndroid Build Coastguard Worker   const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
2802*77c1e3ccSAndroid Build Coastguard Worker   const float log_dc_q_square = log1pf((float)(dc_q * dc_q) / 256.0f);
2803*77c1e3ccSAndroid Build Coastguard Worker   features[feature_idx++] = log_dc_q_square;
2804*77c1e3ccSAndroid Build Coastguard Worker   assert(feature_idx == NUM_INTRA_TX_SPLIT_FEATURES);
2805*77c1e3ccSAndroid Build Coastguard Worker   for (int i = 0; i < NUM_INTRA_TX_SPLIT_FEATURES; i++) {
2806*77c1e3ccSAndroid Build Coastguard Worker     features[i] = (features[i] - av1_intra_tx_split_8x8_mean[i]) /
2807*77c1e3ccSAndroid Build Coastguard Worker                   av1_intra_tx_split_8x8_std[i];
2808*77c1e3ccSAndroid Build Coastguard Worker   }
2809*77c1e3ccSAndroid Build Coastguard Worker 
2810*77c1e3ccSAndroid Build Coastguard Worker   float score;
2811*77c1e3ccSAndroid Build Coastguard Worker   av1_nn_predict(features, nn_config, 1, &score);
2812*77c1e3ccSAndroid Build Coastguard Worker 
2813*77c1e3ccSAndroid Build Coastguard Worker   TxfmSearchParams *const txfm_params = &x->txfm_search_params;
2814*77c1e3ccSAndroid Build Coastguard Worker   if (score <= intra_tx_prune_thresh[0])
2815*77c1e3ccSAndroid Build Coastguard Worker     txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_SPLIT;
2816*77c1e3ccSAndroid Build Coastguard Worker   else if (score > intra_tx_prune_thresh[1])
2817*77c1e3ccSAndroid Build Coastguard Worker     txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_LARGEST;
2818*77c1e3ccSAndroid Build Coastguard Worker }
2819*77c1e3ccSAndroid Build Coastguard Worker #endif  // !CONFIG_REALTIME_ONLY
2820*77c1e3ccSAndroid Build Coastguard Worker 
2821*77c1e3ccSAndroid Build Coastguard Worker /*!\brief Transform type search for luma macroblock with fixed transform size.
2822*77c1e3ccSAndroid Build Coastguard Worker  *
2823*77c1e3ccSAndroid Build Coastguard Worker  * \ingroup transform_search
2824*77c1e3ccSAndroid Build Coastguard Worker  * Search for the best transform type and return the transform coefficients RD
2825*77c1e3ccSAndroid Build Coastguard Worker  * cost of current luma macroblock with the given uniform transform size.
2826*77c1e3ccSAndroid Build Coastguard Worker  *
2827*77c1e3ccSAndroid Build Coastguard Worker  * \param[in]    x              Pointer to structure holding the data for the
2828*77c1e3ccSAndroid Build Coastguard Worker                                 current encoding macroblock
2829*77c1e3ccSAndroid Build Coastguard Worker  * \param[in]    cpi            Top-level encoder structure
2830*77c1e3ccSAndroid Build Coastguard Worker  * \param[in]    rd_stats       Pointer to struct to keep track of the RD stats
2831*77c1e3ccSAndroid Build Coastguard Worker  * \param[in]    ref_best_rd    Best RD cost seen for this block so far
2832*77c1e3ccSAndroid Build Coastguard Worker  * \param[in]    bs             Size of the current macroblock
2833*77c1e3ccSAndroid Build Coastguard Worker  * \param[in]    tx_size        The given transform size
2834*77c1e3ccSAndroid Build Coastguard Worker  * \param[in]    ftxs_mode      Transform search mode specifying desired speed
2835*77c1e3ccSAndroid Build Coastguard Worker                                 and quality tradeoff
2836*77c1e3ccSAndroid Build Coastguard Worker  * \param[in]    skip_trellis   Binary flag indicating if trellis optimization
2837*77c1e3ccSAndroid Build Coastguard Worker                                 should be skipped
2838*77c1e3ccSAndroid Build Coastguard Worker  * \return       An int64_t value that is the best RD cost found.
2839*77c1e3ccSAndroid Build Coastguard Worker  */
uniform_txfm_yrd(const AV1_COMP * const cpi,MACROBLOCK * x,RD_STATS * rd_stats,int64_t ref_best_rd,BLOCK_SIZE bs,TX_SIZE tx_size,FAST_TX_SEARCH_MODE ftxs_mode,int skip_trellis)2840*77c1e3ccSAndroid Build Coastguard Worker static int64_t uniform_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
2841*77c1e3ccSAndroid Build Coastguard Worker                                 RD_STATS *rd_stats, int64_t ref_best_rd,
2842*77c1e3ccSAndroid Build Coastguard Worker                                 BLOCK_SIZE bs, TX_SIZE tx_size,
2843*77c1e3ccSAndroid Build Coastguard Worker                                 FAST_TX_SEARCH_MODE ftxs_mode,
2844*77c1e3ccSAndroid Build Coastguard Worker                                 int skip_trellis) {
2845*77c1e3ccSAndroid Build Coastguard Worker   assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs)));
2846*77c1e3ccSAndroid Build Coastguard Worker   MACROBLOCKD *const xd = &x->e_mbd;
2847*77c1e3ccSAndroid Build Coastguard Worker   MB_MODE_INFO *const mbmi = xd->mi[0];
2848*77c1e3ccSAndroid Build Coastguard Worker   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
2849*77c1e3ccSAndroid Build Coastguard Worker   const ModeCosts *mode_costs = &x->mode_costs;
2850*77c1e3ccSAndroid Build Coastguard Worker   const int is_inter = is_inter_block(mbmi);
2851*77c1e3ccSAndroid Build Coastguard Worker   const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
2852*77c1e3ccSAndroid Build Coastguard Worker                         block_signals_txsize(mbmi->bsize);
2853*77c1e3ccSAndroid Build Coastguard Worker   int tx_size_rate = 0;
2854*77c1e3ccSAndroid Build Coastguard Worker   if (tx_select) {
2855*77c1e3ccSAndroid Build Coastguard Worker     const int ctx = txfm_partition_context(
2856*77c1e3ccSAndroid Build Coastguard Worker         xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size);
2857*77c1e3ccSAndroid Build Coastguard Worker     tx_size_rate = is_inter ? mode_costs->txfm_partition_cost[ctx][0]
2858*77c1e3ccSAndroid Build Coastguard Worker                             : tx_size_cost(x, bs, tx_size);
2859*77c1e3ccSAndroid Build Coastguard Worker   }
2860*77c1e3ccSAndroid Build Coastguard Worker   const int skip_ctx = av1_get_skip_txfm_context(xd);
2861*77c1e3ccSAndroid Build Coastguard Worker   const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0];
2862*77c1e3ccSAndroid Build Coastguard Worker   const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1];
2863*77c1e3ccSAndroid Build Coastguard Worker   const int64_t skip_txfm_rd =
2864*77c1e3ccSAndroid Build Coastguard Worker       is_inter ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX;
2865*77c1e3ccSAndroid Build Coastguard Worker   const int64_t no_this_rd =
2866*77c1e3ccSAndroid Build Coastguard Worker       RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0);
2867*77c1e3ccSAndroid Build Coastguard Worker 
2868*77c1e3ccSAndroid Build Coastguard Worker   mbmi->tx_size = tx_size;
2869*77c1e3ccSAndroid Build Coastguard Worker   av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
2870*77c1e3ccSAndroid Build Coastguard Worker                        AOMMIN(no_this_rd, skip_txfm_rd), AOM_PLANE_Y, bs,
2871*77c1e3ccSAndroid Build Coastguard Worker                        tx_size, ftxs_mode, skip_trellis);
2872*77c1e3ccSAndroid Build Coastguard Worker   if (rd_stats->rate == INT_MAX) return INT64_MAX;
2873*77c1e3ccSAndroid Build Coastguard Worker 
2874*77c1e3ccSAndroid Build Coastguard Worker   int64_t rd;
2875*77c1e3ccSAndroid Build Coastguard Worker   // rdstats->rate should include all the rate except skip/non-skip cost as the
2876*77c1e3ccSAndroid Build Coastguard Worker   // same is accounted in the caller functions after rd evaluation of all
2877*77c1e3ccSAndroid Build Coastguard Worker   // planes. However the decisions should be done after considering the
2878*77c1e3ccSAndroid Build Coastguard Worker   // skip/non-skip header cost
2879*77c1e3ccSAndroid Build Coastguard Worker   if (rd_stats->skip_txfm && is_inter) {
2880*77c1e3ccSAndroid Build Coastguard Worker     rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
2881*77c1e3ccSAndroid Build Coastguard Worker   } else {
2882*77c1e3ccSAndroid Build Coastguard Worker     // Intra blocks are always signalled as non-skip
2883*77c1e3ccSAndroid Build Coastguard Worker     rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate,
2884*77c1e3ccSAndroid Build Coastguard Worker                 rd_stats->dist);
2885*77c1e3ccSAndroid Build Coastguard Worker     rd_stats->rate += tx_size_rate;
2886*77c1e3ccSAndroid Build Coastguard Worker   }
2887*77c1e3ccSAndroid Build Coastguard Worker   // Check if forcing the block to skip transform leads to smaller RD cost.
2888*77c1e3ccSAndroid Build Coastguard Worker   if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) {
2889*77c1e3ccSAndroid Build Coastguard Worker     int64_t temp_skip_txfm_rd =
2890*77c1e3ccSAndroid Build Coastguard Worker         RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
2891*77c1e3ccSAndroid Build Coastguard Worker     if (temp_skip_txfm_rd <= rd) {
2892*77c1e3ccSAndroid Build Coastguard Worker       rd = temp_skip_txfm_rd;
2893*77c1e3ccSAndroid Build Coastguard Worker       rd_stats->rate = 0;
2894*77c1e3ccSAndroid Build Coastguard Worker       rd_stats->dist = rd_stats->sse;
2895*77c1e3ccSAndroid Build Coastguard Worker       rd_stats->skip_txfm = 1;
2896*77c1e3ccSAndroid Build Coastguard Worker     }
2897*77c1e3ccSAndroid Build Coastguard Worker   }
2898*77c1e3ccSAndroid Build Coastguard Worker 
2899*77c1e3ccSAndroid Build Coastguard Worker   return rd;
2900*77c1e3ccSAndroid Build Coastguard Worker }
2901*77c1e3ccSAndroid Build Coastguard Worker 
2902*77c1e3ccSAndroid Build Coastguard Worker // Search for the best uniform transform size and type for current coding block.
choose_tx_size_type_from_rd(const AV1_COMP * const cpi,MACROBLOCK * x,RD_STATS * rd_stats,int64_t ref_best_rd,BLOCK_SIZE bs)2903*77c1e3ccSAndroid Build Coastguard Worker static inline void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
2904*77c1e3ccSAndroid Build Coastguard Worker                                                MACROBLOCK *x,
2905*77c1e3ccSAndroid Build Coastguard Worker                                                RD_STATS *rd_stats,
2906*77c1e3ccSAndroid Build Coastguard Worker                                                int64_t ref_best_rd,
2907*77c1e3ccSAndroid Build Coastguard Worker                                                BLOCK_SIZE bs) {
2908*77c1e3ccSAndroid Build Coastguard Worker   av1_invalid_rd_stats(rd_stats);
2909*77c1e3ccSAndroid Build Coastguard Worker 
2910*77c1e3ccSAndroid Build Coastguard Worker   MACROBLOCKD *const xd = &x->e_mbd;
2911*77c1e3ccSAndroid Build Coastguard Worker   MB_MODE_INFO *const mbmi = xd->mi[0];
2912*77c1e3ccSAndroid Build Coastguard Worker   TxfmSearchParams *const txfm_params = &x->txfm_search_params;
2913*77c1e3ccSAndroid Build Coastguard Worker   const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bs];
2914*77c1e3ccSAndroid Build Coastguard Worker   const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT;
2915*77c1e3ccSAndroid Build Coastguard Worker   int start_tx;
2916*77c1e3ccSAndroid Build Coastguard Worker   // The split depth can be at most MAX_TX_DEPTH, so the init_depth controls
2917*77c1e3ccSAndroid Build Coastguard Worker   // how many times of splitting is allowed during the RD search.
2918*77c1e3ccSAndroid Build Coastguard Worker   int init_depth;
2919*77c1e3ccSAndroid Build Coastguard Worker 
2920*77c1e3ccSAndroid Build Coastguard Worker   if (tx_select) {
2921*77c1e3ccSAndroid Build Coastguard Worker     start_tx = max_rect_tx_size;
2922*77c1e3ccSAndroid Build Coastguard Worker     init_depth = get_search_init_depth(mi_size_wide[bs], mi_size_high[bs],
2923*77c1e3ccSAndroid Build Coastguard Worker                                        is_inter_block(mbmi), &cpi->sf,
2924*77c1e3ccSAndroid Build Coastguard Worker                                        txfm_params->tx_size_search_method);
2925*77c1e3ccSAndroid Build Coastguard Worker     if (init_depth == MAX_TX_DEPTH && !cpi->oxcf.txfm_cfg.enable_tx64 &&
2926*77c1e3ccSAndroid Build Coastguard Worker         txsize_sqr_up_map[start_tx] == TX_64X64) {
2927*77c1e3ccSAndroid Build Coastguard Worker       start_tx = sub_tx_size_map[start_tx];
2928*77c1e3ccSAndroid Build Coastguard Worker     }
2929*77c1e3ccSAndroid Build Coastguard Worker   } else {
2930*77c1e3ccSAndroid Build Coastguard Worker     const TX_SIZE chosen_tx_size =
2931*77c1e3ccSAndroid Build Coastguard Worker         tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type);
2932*77c1e3ccSAndroid Build Coastguard Worker     start_tx = chosen_tx_size;
2933*77c1e3ccSAndroid Build Coastguard Worker     init_depth = MAX_TX_DEPTH;
2934*77c1e3ccSAndroid Build Coastguard Worker   }
2935*77c1e3ccSAndroid Build Coastguard Worker 
2936*77c1e3ccSAndroid Build Coastguard Worker   const int skip_trellis = 0;
2937*77c1e3ccSAndroid Build Coastguard Worker   uint8_t best_txk_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
2938*77c1e3ccSAndroid Build Coastguard Worker   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
2939*77c1e3ccSAndroid Build Coastguard Worker   TX_SIZE best_tx_size = max_rect_tx_size;
2940*77c1e3ccSAndroid Build Coastguard Worker   int64_t best_rd = INT64_MAX;
2941*77c1e3ccSAndroid Build Coastguard Worker   const int num_blks = bsize_to_num_blk(bs);
2942*77c1e3ccSAndroid Build Coastguard Worker   x->rd_model = FULL_TXFM_RD;
2943*77c1e3ccSAndroid Build Coastguard Worker   int64_t rd[MAX_TX_DEPTH + 1] = { INT64_MAX, INT64_MAX, INT64_MAX };
2944*77c1e3ccSAndroid Build Coastguard Worker   TxfmSearchInfo *txfm_info = &x->txfm_search_info;
2945*77c1e3ccSAndroid Build Coastguard Worker   for (int tx_size = start_tx, depth = init_depth; depth <= MAX_TX_DEPTH;
2946*77c1e3ccSAndroid Build Coastguard Worker        depth++, tx_size = sub_tx_size_map[tx_size]) {
2947*77c1e3ccSAndroid Build Coastguard Worker     if ((!cpi->oxcf.txfm_cfg.enable_tx64 &&
2948*77c1e3ccSAndroid Build Coastguard Worker          txsize_sqr_up_map[tx_size] == TX_64X64) ||
2949*77c1e3ccSAndroid Build Coastguard Worker         (!cpi->oxcf.txfm_cfg.enable_rect_tx &&
2950*77c1e3ccSAndroid Build Coastguard Worker          tx_size_wide[tx_size] != tx_size_high[tx_size])) {
2951*77c1e3ccSAndroid Build Coastguard Worker       continue;
2952*77c1e3ccSAndroid Build Coastguard Worker     }
2953*77c1e3ccSAndroid Build Coastguard Worker 
2954*77c1e3ccSAndroid Build Coastguard Worker #if !CONFIG_REALTIME_ONLY
2955*77c1e3ccSAndroid Build Coastguard Worker     if (txfm_params->nn_prune_depths_for_intra_tx == TX_PRUNE_SPLIT) break;
2956*77c1e3ccSAndroid Build Coastguard Worker 
2957*77c1e3ccSAndroid Build Coastguard Worker     // Set the flag to enable the evaluation of NN classifier to prune transform
2958*77c1e3ccSAndroid Build Coastguard Worker     // depths. As the features are based on intra residual information of
2959*77c1e3ccSAndroid Build Coastguard Worker     // largest transform, the evaluation of NN model is enabled only for this
2960*77c1e3ccSAndroid Build Coastguard Worker     // case.
2961*77c1e3ccSAndroid Build Coastguard Worker     txfm_params->enable_nn_prune_intra_tx_depths =
2962*77c1e3ccSAndroid Build Coastguard Worker         (cpi->sf.tx_sf.prune_intra_tx_depths_using_nn && tx_size == start_tx);
2963*77c1e3ccSAndroid Build Coastguard Worker #endif
2964*77c1e3ccSAndroid Build Coastguard Worker 
2965*77c1e3ccSAndroid Build Coastguard Worker     RD_STATS this_rd_stats;
2966*77c1e3ccSAndroid Build Coastguard Worker     // When the speed feature use_rd_based_breakout_for_intra_tx_search is
2967*77c1e3ccSAndroid Build Coastguard Worker     // enabled, use the known minimum best_rd for early termination.
2968*77c1e3ccSAndroid Build Coastguard Worker     const int64_t rd_thresh =
2969*77c1e3ccSAndroid Build Coastguard Worker         cpi->sf.tx_sf.use_rd_based_breakout_for_intra_tx_search
2970*77c1e3ccSAndroid Build Coastguard Worker             ? AOMMIN(ref_best_rd, best_rd)
2971*77c1e3ccSAndroid Build Coastguard Worker             : ref_best_rd;
2972*77c1e3ccSAndroid Build Coastguard Worker     rd[depth] = uniform_txfm_yrd(cpi, x, &this_rd_stats, rd_thresh, bs, tx_size,
2973*77c1e3ccSAndroid Build Coastguard Worker                                  FTXS_NONE, skip_trellis);
2974*77c1e3ccSAndroid Build Coastguard Worker     if (rd[depth] < best_rd) {
2975*77c1e3ccSAndroid Build Coastguard Worker       av1_copy_array(best_blk_skip, txfm_info->blk_skip, num_blks);
2976*77c1e3ccSAndroid Build Coastguard Worker       av1_copy_array(best_txk_type_map, xd->tx_type_map, num_blks);
2977*77c1e3ccSAndroid Build Coastguard Worker       best_tx_size = tx_size;
2978*77c1e3ccSAndroid Build Coastguard Worker       best_rd = rd[depth];
2979*77c1e3ccSAndroid Build Coastguard Worker       *rd_stats = this_rd_stats;
2980*77c1e3ccSAndroid Build Coastguard Worker     }
2981*77c1e3ccSAndroid Build Coastguard Worker     if (tx_size == TX_4X4) break;
2982*77c1e3ccSAndroid Build Coastguard Worker     // If we are searching three depths, prune the smallest size depending
2983*77c1e3ccSAndroid Build Coastguard Worker     // on rd results for the first two depths for low contrast blocks.
2984*77c1e3ccSAndroid Build Coastguard Worker     if (depth > init_depth && depth != MAX_TX_DEPTH &&
2985*77c1e3ccSAndroid Build Coastguard Worker         x->source_variance < 256) {
2986*77c1e3ccSAndroid Build Coastguard Worker       if (rd[depth - 1] != INT64_MAX && rd[depth] > rd[depth - 1]) break;
2987*77c1e3ccSAndroid Build Coastguard Worker     }
2988*77c1e3ccSAndroid Build Coastguard Worker   }
2989*77c1e3ccSAndroid Build Coastguard Worker 
2990*77c1e3ccSAndroid Build Coastguard Worker   if (rd_stats->rate != INT_MAX) {
2991*77c1e3ccSAndroid Build Coastguard Worker     mbmi->tx_size = best_tx_size;
2992*77c1e3ccSAndroid Build Coastguard Worker     av1_copy_array(xd->tx_type_map, best_txk_type_map, num_blks);
2993*77c1e3ccSAndroid Build Coastguard Worker     av1_copy_array(txfm_info->blk_skip, best_blk_skip, num_blks);
2994*77c1e3ccSAndroid Build Coastguard Worker   }
2995*77c1e3ccSAndroid Build Coastguard Worker 
2996*77c1e3ccSAndroid Build Coastguard Worker #if !CONFIG_REALTIME_ONLY
2997*77c1e3ccSAndroid Build Coastguard Worker   // Reset the flags to avoid any unintentional evaluation of NN model and
2998*77c1e3ccSAndroid Build Coastguard Worker   // consumption of prune depths.
2999*77c1e3ccSAndroid Build Coastguard Worker   txfm_params->enable_nn_prune_intra_tx_depths = false;
3000*77c1e3ccSAndroid Build Coastguard Worker   txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_NONE;
3001*77c1e3ccSAndroid Build Coastguard Worker #endif
3002*77c1e3ccSAndroid Build Coastguard Worker }
3003*77c1e3ccSAndroid Build Coastguard Worker 
3004*77c1e3ccSAndroid Build Coastguard Worker // Search for the best transform type for the given transform block in the
3005*77c1e3ccSAndroid Build Coastguard Worker // given plane/channel, and calculate the corresponding RD cost.
block_rd_txfm(int plane,int block,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,void * arg)3006*77c1e3ccSAndroid Build Coastguard Worker static inline void block_rd_txfm(int plane, int block, int blk_row, int blk_col,
3007*77c1e3ccSAndroid Build Coastguard Worker                                  BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
3008*77c1e3ccSAndroid Build Coastguard Worker                                  void *arg) {
3009*77c1e3ccSAndroid Build Coastguard Worker   struct rdcost_block_args *args = arg;
3010*77c1e3ccSAndroid Build Coastguard Worker   if (args->exit_early) {
3011*77c1e3ccSAndroid Build Coastguard Worker     args->incomplete_exit = 1;
3012*77c1e3ccSAndroid Build Coastguard Worker     return;
3013*77c1e3ccSAndroid Build Coastguard Worker   }
3014*77c1e3ccSAndroid Build Coastguard Worker 
3015*77c1e3ccSAndroid Build Coastguard Worker   MACROBLOCK *const x = args->x;
3016*77c1e3ccSAndroid Build Coastguard Worker   MACROBLOCKD *const xd = &x->e_mbd;
3017*77c1e3ccSAndroid Build Coastguard Worker   const int is_inter = is_inter_block(xd->mi[0]);
3018*77c1e3ccSAndroid Build Coastguard Worker   const AV1_COMP *cpi = args->cpi;
3019*77c1e3ccSAndroid Build Coastguard Worker   ENTROPY_CONTEXT *a = args->t_above + blk_col;
3020*77c1e3ccSAndroid Build Coastguard Worker   ENTROPY_CONTEXT *l = args->t_left + blk_row;
3021*77c1e3ccSAndroid Build Coastguard Worker   const AV1_COMMON *cm = &cpi->common;
3022*77c1e3ccSAndroid Build Coastguard Worker   RD_STATS this_rd_stats;
3023*77c1e3ccSAndroid Build Coastguard Worker   av1_init_rd_stats(&this_rd_stats);
3024*77c1e3ccSAndroid Build Coastguard Worker 
3025*77c1e3ccSAndroid Build Coastguard Worker   if (!is_inter) {
3026*77c1e3ccSAndroid Build Coastguard Worker     av1_predict_intra_block_facade(cm, xd, plane, blk_col, blk_row, tx_size);
3027*77c1e3ccSAndroid Build Coastguard Worker     av1_subtract_txb(x, plane, plane_bsize, blk_col, blk_row, tx_size);
3028*77c1e3ccSAndroid Build Coastguard Worker #if !CONFIG_REALTIME_ONLY
3029*77c1e3ccSAndroid Build Coastguard Worker     const TxfmSearchParams *const txfm_params = &x->txfm_search_params;
3030*77c1e3ccSAndroid Build Coastguard Worker     if (txfm_params->enable_nn_prune_intra_tx_depths) {
3031*77c1e3ccSAndroid Build Coastguard Worker       ml_predict_intra_tx_depth_prune(x, blk_row, blk_col, plane_bsize,
3032*77c1e3ccSAndroid Build Coastguard Worker                                       tx_size);
3033*77c1e3ccSAndroid Build Coastguard Worker       if (txfm_params->nn_prune_depths_for_intra_tx == TX_PRUNE_LARGEST) {
3034*77c1e3ccSAndroid Build Coastguard Worker         av1_invalid_rd_stats(&args->rd_stats);
3035*77c1e3ccSAndroid Build Coastguard Worker         args->exit_early = 1;
3036*77c1e3ccSAndroid Build Coastguard Worker         return;
3037*77c1e3ccSAndroid Build Coastguard Worker       }
3038*77c1e3ccSAndroid Build Coastguard Worker     }
3039*77c1e3ccSAndroid Build Coastguard Worker #endif
3040*77c1e3ccSAndroid Build Coastguard Worker   }
3041*77c1e3ccSAndroid Build Coastguard Worker 
3042*77c1e3ccSAndroid Build Coastguard Worker   TXB_CTX txb_ctx;
3043*77c1e3ccSAndroid Build Coastguard Worker   get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx);
3044*77c1e3ccSAndroid Build Coastguard Worker   search_tx_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
3045*77c1e3ccSAndroid Build Coastguard Worker                  &txb_ctx, args->ftxs_mode, args->skip_trellis,
3046*77c1e3ccSAndroid Build Coastguard Worker                  args->best_rd - args->current_rd, &this_rd_stats);
3047*77c1e3ccSAndroid Build Coastguard Worker 
3048*77c1e3ccSAndroid Build Coastguard Worker #if !CONFIG_REALTIME_ONLY
3049*77c1e3ccSAndroid Build Coastguard Worker   if (plane == AOM_PLANE_Y && xd->cfl.store_y) {
3050*77c1e3ccSAndroid Build Coastguard Worker     assert(!is_inter || plane_bsize < BLOCK_8X8);
3051*77c1e3ccSAndroid Build Coastguard Worker     cfl_store_tx(xd, blk_row, blk_col, tx_size, plane_bsize);
3052*77c1e3ccSAndroid Build Coastguard Worker   }
3053*77c1e3ccSAndroid Build Coastguard Worker #endif
3054*77c1e3ccSAndroid Build Coastguard Worker 
3055*77c1e3ccSAndroid Build Coastguard Worker #if CONFIG_RD_DEBUG
3056*77c1e3ccSAndroid Build Coastguard Worker   update_txb_coeff_cost(&this_rd_stats, plane, this_rd_stats.rate);
3057*77c1e3ccSAndroid Build Coastguard Worker #endif  // CONFIG_RD_DEBUG
3058*77c1e3ccSAndroid Build Coastguard Worker   av1_set_txb_context(x, plane, block, tx_size, a, l);
3059*77c1e3ccSAndroid Build Coastguard Worker 
3060*77c1e3ccSAndroid Build Coastguard Worker   const int blk_idx =
3061*77c1e3ccSAndroid Build Coastguard Worker       blk_row * (block_size_wide[plane_bsize] >> MI_SIZE_LOG2) + blk_col;
3062*77c1e3ccSAndroid Build Coastguard Worker 
3063*77c1e3ccSAndroid Build Coastguard Worker   TxfmSearchInfo *txfm_info = &x->txfm_search_info;
3064*77c1e3ccSAndroid Build Coastguard Worker   if (plane == 0)
3065*77c1e3ccSAndroid Build Coastguard Worker     set_blk_skip(txfm_info->blk_skip, plane, blk_idx,
3066*77c1e3ccSAndroid Build Coastguard Worker                  x->plane[plane].eobs[block] == 0);
3067*77c1e3ccSAndroid Build Coastguard Worker   else
3068*77c1e3ccSAndroid Build Coastguard Worker     set_blk_skip(txfm_info->blk_skip, plane, blk_idx, 0);
3069*77c1e3ccSAndroid Build Coastguard Worker 
3070*77c1e3ccSAndroid Build Coastguard Worker   int64_t rd;
3071*77c1e3ccSAndroid Build Coastguard Worker   if (is_inter) {
3072*77c1e3ccSAndroid Build Coastguard Worker     const int64_t no_skip_txfm_rd =
3073*77c1e3ccSAndroid Build Coastguard Worker         RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3074*77c1e3ccSAndroid Build Coastguard Worker     const int64_t skip_txfm_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);
3075*77c1e3ccSAndroid Build Coastguard Worker     rd = AOMMIN(no_skip_txfm_rd, skip_txfm_rd);
3076*77c1e3ccSAndroid Build Coastguard Worker     this_rd_stats.skip_txfm &= !x->plane[plane].eobs[block];
3077*77c1e3ccSAndroid Build Coastguard Worker   } else {
3078*77c1e3ccSAndroid Build Coastguard Worker     // Signal non-skip_txfm for Intra blocks
3079*77c1e3ccSAndroid Build Coastguard Worker     rd = RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3080*77c1e3ccSAndroid Build Coastguard Worker     this_rd_stats.skip_txfm = 0;
3081*77c1e3ccSAndroid Build Coastguard Worker   }
3082*77c1e3ccSAndroid Build Coastguard Worker 
3083*77c1e3ccSAndroid Build Coastguard Worker   av1_merge_rd_stats(&args->rd_stats, &this_rd_stats);
3084*77c1e3ccSAndroid Build Coastguard Worker 
3085*77c1e3ccSAndroid Build Coastguard Worker   args->current_rd += rd;
3086*77c1e3ccSAndroid Build Coastguard Worker   if (args->current_rd > args->best_rd) args->exit_early = 1;
3087*77c1e3ccSAndroid Build Coastguard Worker }
3088*77c1e3ccSAndroid Build Coastguard Worker 
// Estimates the luma RD cost of coding block `bs` with one uniform transform
// size `tx_size`.
//
// Each transform block is transformed with DCT_DCT, quantized with
// AV1_XFORM_QUANT_B and costed with transform-domain distortion; no transform
// type search is performed. Sets mbmi->tx_size to `tx_size`.
//
// On success the accumulated stats are written to `rd_stats`, and the return
// value is the RD cost including the skip/non-skip decision. The stored rate
// excludes the skip/non-skip flag cost, which callers add themselves. It does
// include the tx-size cost when the block is coded as non-skip.
//
// If the running cost exceeds `ref_best_rd` while transform blocks remain
// unevaluated, `rd_stats` is invalidated and INT64_MAX is returned.
int64_t av1_estimate_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
                              RD_STATS *rd_stats, int64_t ref_best_rd,
                              BLOCK_SIZE bs, TX_SIZE tx_size) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const ModeCosts *mode_costs = &x->mode_costs;
  const int is_inter = is_inter_block(mbmi);
  const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
                        block_signals_txsize(mbmi->bsize);
  int tx_size_rate = 0;
  if (tx_select) {
    const int ctx = txfm_partition_context(
        xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size);
    tx_size_rate = mode_costs->txfm_partition_cost[ctx][0];
  }
  const int skip_ctx = av1_get_skip_txfm_context(xd);
  const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0];
  const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1];
  // The cheaper of the two header-only costs seeds the running RD estimate.
  const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, 0);
  const int64_t no_this_rd =
      RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0);
  mbmi->tx_size = tx_size;

  const uint8_t txw_unit = tx_size_wide_unit[tx_size];
  const uint8_t txh_unit = tx_size_high_unit[tx_size];
  const int step = txw_unit * txh_unit;
  const int max_blocks_wide = max_block_wide(xd, bs, 0);
  const int max_blocks_high = max_block_high(xd, bs, 0);

  struct rdcost_block_args args;
  av1_zero(args);
  args.x = x;
  args.cpi = cpi;
  args.best_rd = ref_best_rd;
  args.current_rd = AOMMIN(no_this_rd, skip_txfm_rd);
  av1_init_rd_stats(&args.rd_stats);
  av1_get_entropy_contexts(bs, &xd->plane[0], args.t_above, args.t_left);
  int i = 0;  // Transform block index; advanced by `step` per block.
  for (int blk_row = 0; blk_row < max_blocks_high && !args.incomplete_exit;
       blk_row += txh_unit) {
    for (int blk_col = 0; blk_col < max_blocks_wide; blk_col += txw_unit) {
      if (args.exit_early) {
        // An earlier block pushed the cost over budget and this block will
        // not be evaluated, so the accumulated stats are incomplete.
        args.incomplete_exit = 1;
        break;
      }

      RD_STATS this_rd_stats;
      av1_init_rd_stats(&this_rd_stats);

      ENTROPY_CONTEXT *a = args.t_above + blk_col;
      ENTROPY_CONTEXT *l = args.t_left + blk_row;
      TXB_CTX txb_ctx;
      get_txb_ctx(bs, tx_size, 0, a, l, &txb_ctx);

      TxfmParam txfm_param;
      QUANT_PARAM quant_param;
      av1_setup_xform(&cpi->common, x, tx_size, DCT_DCT, &txfm_param);
      av1_setup_quant(tx_size, 0, AV1_XFORM_QUANT_B, 0, &quant_param);

      av1_xform(x, 0, i, blk_row, blk_col, bs, &txfm_param);
      av1_quant(x, 0, i, &txfm_param, &quant_param);

      this_rd_stats.rate =
          cost_coeffs(x, 0, i, tx_size, txfm_param.tx_type, &txb_ctx, 0);

      const SCAN_ORDER *const scan_order =
          get_scan(txfm_param.tx_size, txfm_param.tx_type);
      dist_block_tx_domain(x, 0, i, tx_size, quant_param.qmatrix,
                           scan_order->scan, &this_rd_stats.dist,
                           &this_rd_stats.sse);

      const int64_t no_skip_txfm_rd =
          RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
      const int64_t skip_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);

      this_rd_stats.skip_txfm &= !x->plane[0].eobs[i];

      av1_merge_rd_stats(&args.rd_stats, &this_rd_stats);
      args.current_rd += AOMMIN(no_skip_txfm_rd, skip_rd);

      av1_set_txb_context(x, 0, i, tx_size, a, l);
      i += step;

      // Only flag the overrun here; do not break. Breaking immediately would
      // hide the overrun when it happens mid-way through the last row: no
      // later iteration would set incomplete_exit, and partial stats would
      // be returned as if complete. Now the next visited block (if any)
      // records the incomplete exit, the same way block_rd_txfm does.
      if (args.current_rd > ref_best_rd) args.exit_early = 1;
    }
  }

  if (args.incomplete_exit) av1_invalid_rd_stats(&args.rd_stats);

  *rd_stats = args.rd_stats;
  if (rd_stats->rate == INT_MAX) return INT64_MAX;

  int64_t rd;
  // rd_stats->rate should include all the rate except the skip/non-skip cost,
  // because that cost is accounted for in the caller after all planes have
  // been evaluated. The decisions here, however, must take the skip/non-skip
  // header cost into account.
  if (rd_stats->skip_txfm && is_inter) {
    rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
  } else {
    // Intra blocks are always signalled as non-skip.
    rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate,
                rd_stats->dist);
    rd_stats->rate += tx_size_rate;
  }
  // Check whether forcing the block to skip the transform gives a lower cost.
  if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) {
    int64_t temp_skip_txfm_rd =
        RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
    if (temp_skip_txfm_rd <= rd) {
      rd = temp_skip_txfm_rd;
      rd_stats->rate = 0;
      rd_stats->dist = rd_stats->sse;
      rd_stats->skip_txfm = 1;
    }
  }

  return rd;
}
3212*77c1e3ccSAndroid Build Coastguard Worker 
3213*77c1e3ccSAndroid Build Coastguard Worker // Search for the best transform type for a luma inter-predicted block, given
3214*77c1e3ccSAndroid Build Coastguard Worker // the transform block partitions.
3215*77c1e3ccSAndroid Build Coastguard Worker // This function is used only when some speed features are enabled.
tx_block_yrd(const AV1_COMP * cpi,MACROBLOCK * x,int blk_row,int blk_col,int block,TX_SIZE tx_size,BLOCK_SIZE plane_bsize,int depth,ENTROPY_CONTEXT * above_ctx,ENTROPY_CONTEXT * left_ctx,TXFM_CONTEXT * tx_above,TXFM_CONTEXT * tx_left,int64_t ref_best_rd,RD_STATS * rd_stats,FAST_TX_SEARCH_MODE ftxs_mode)3216*77c1e3ccSAndroid Build Coastguard Worker static inline void tx_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
3217*77c1e3ccSAndroid Build Coastguard Worker                                 int blk_col, int block, TX_SIZE tx_size,
3218*77c1e3ccSAndroid Build Coastguard Worker                                 BLOCK_SIZE plane_bsize, int depth,
3219*77c1e3ccSAndroid Build Coastguard Worker                                 ENTROPY_CONTEXT *above_ctx,
3220*77c1e3ccSAndroid Build Coastguard Worker                                 ENTROPY_CONTEXT *left_ctx,
3221*77c1e3ccSAndroid Build Coastguard Worker                                 TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
3222*77c1e3ccSAndroid Build Coastguard Worker                                 int64_t ref_best_rd, RD_STATS *rd_stats,
3223*77c1e3ccSAndroid Build Coastguard Worker                                 FAST_TX_SEARCH_MODE ftxs_mode) {
3224*77c1e3ccSAndroid Build Coastguard Worker   assert(tx_size < TX_SIZES_ALL);
3225*77c1e3ccSAndroid Build Coastguard Worker   MACROBLOCKD *const xd = &x->e_mbd;
3226*77c1e3ccSAndroid Build Coastguard Worker   MB_MODE_INFO *const mbmi = xd->mi[0];
3227*77c1e3ccSAndroid Build Coastguard Worker   assert(is_inter_block(mbmi));
3228*77c1e3ccSAndroid Build Coastguard Worker   const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
3229*77c1e3ccSAndroid Build Coastguard Worker   const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
3230*77c1e3ccSAndroid Build Coastguard Worker 
3231*77c1e3ccSAndroid Build Coastguard Worker   if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
3232*77c1e3ccSAndroid Build Coastguard Worker 
3233*77c1e3ccSAndroid Build Coastguard Worker   const TX_SIZE plane_tx_size = mbmi->inter_tx_size[av1_get_txb_size_index(
3234*77c1e3ccSAndroid Build Coastguard Worker       plane_bsize, blk_row, blk_col)];
3235*77c1e3ccSAndroid Build Coastguard Worker   const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
3236*77c1e3ccSAndroid Build Coastguard Worker                                          mbmi->bsize, tx_size);
3237*77c1e3ccSAndroid Build Coastguard Worker 
3238*77c1e3ccSAndroid Build Coastguard Worker   av1_init_rd_stats(rd_stats);
3239*77c1e3ccSAndroid Build Coastguard Worker   if (tx_size == plane_tx_size) {
3240*77c1e3ccSAndroid Build Coastguard Worker     ENTROPY_CONTEXT *ta = above_ctx + blk_col;
3241*77c1e3ccSAndroid Build Coastguard Worker     ENTROPY_CONTEXT *tl = left_ctx + blk_row;
3242*77c1e3ccSAndroid Build Coastguard Worker     const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
3243*77c1e3ccSAndroid Build Coastguard Worker     TXB_CTX txb_ctx;
3244*77c1e3ccSAndroid Build Coastguard Worker     get_txb_ctx(plane_bsize, tx_size, 0, ta, tl, &txb_ctx);
3245*77c1e3ccSAndroid Build Coastguard Worker 
3246*77c1e3ccSAndroid Build Coastguard Worker     const int zero_blk_rate =
3247*77c1e3ccSAndroid Build Coastguard Worker         x->coeff_costs.coeff_costs[txs_ctx][get_plane_type(0)]
3248*77c1e3ccSAndroid Build Coastguard Worker             .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
3249*77c1e3ccSAndroid Build Coastguard Worker     rd_stats->zero_rate = zero_blk_rate;
3250*77c1e3ccSAndroid Build Coastguard Worker     tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
3251*77c1e3ccSAndroid Build Coastguard Worker                rd_stats, ftxs_mode, ref_best_rd);
3252*77c1e3ccSAndroid Build Coastguard Worker     const int mi_width = mi_size_wide[plane_bsize];
3253*77c1e3ccSAndroid Build Coastguard Worker     TxfmSearchInfo *txfm_info = &x->txfm_search_info;
3254*77c1e3ccSAndroid Build Coastguard Worker     if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
3255*77c1e3ccSAndroid Build Coastguard Worker             RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
3256*77c1e3ccSAndroid Build Coastguard Worker         rd_stats->skip_txfm == 1) {
3257*77c1e3ccSAndroid Build Coastguard Worker       rd_stats->rate = zero_blk_rate;
3258*77c1e3ccSAndroid Build Coastguard Worker       rd_stats->dist = rd_stats->sse;
3259*77c1e3ccSAndroid Build Coastguard Worker       rd_stats->skip_txfm = 1;
3260*77c1e3ccSAndroid Build Coastguard Worker       set_blk_skip(txfm_info->blk_skip, 0, blk_row * mi_width + blk_col, 1);
3261*77c1e3ccSAndroid Build Coastguard Worker       x->plane[0].eobs[block] = 0;
3262*77c1e3ccSAndroid Build Coastguard Worker       x->plane[0].txb_entropy_ctx[block] = 0;
3263*77c1e3ccSAndroid Build Coastguard Worker       update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
3264*77c1e3ccSAndroid Build Coastguard Worker     } else {
3265*77c1e3ccSAndroid Build Coastguard Worker       rd_stats->skip_txfm = 0;
3266*77c1e3ccSAndroid Build Coastguard Worker       set_blk_skip(txfm_info->blk_skip, 0, blk_row * mi_width + blk_col, 0);
3267*77c1e3ccSAndroid Build Coastguard Worker     }
3268*77c1e3ccSAndroid Build Coastguard Worker     if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
3269*77c1e3ccSAndroid Build Coastguard Worker       rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][0];
3270*77c1e3ccSAndroid Build Coastguard Worker     av1_set_txb_context(x, 0, block, tx_size, ta, tl);
3271*77c1e3ccSAndroid Build Coastguard Worker     txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
3272*77c1e3ccSAndroid Build Coastguard Worker                           tx_size);
3273*77c1e3ccSAndroid Build Coastguard Worker   } else {
3274*77c1e3ccSAndroid Build Coastguard Worker     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
3275*77c1e3ccSAndroid Build Coastguard Worker     const int txb_width = tx_size_wide_unit[sub_txs];
3276*77c1e3ccSAndroid Build Coastguard Worker     const int txb_height = tx_size_high_unit[sub_txs];
3277*77c1e3ccSAndroid Build Coastguard Worker     const int step = txb_height * txb_width;
3278*77c1e3ccSAndroid Build Coastguard Worker     const int row_end =
3279*77c1e3ccSAndroid Build Coastguard Worker         AOMMIN(tx_size_high_unit[tx_size], max_blocks_high - blk_row);
3280*77c1e3ccSAndroid Build Coastguard Worker     const int col_end =
3281*77c1e3ccSAndroid Build Coastguard Worker         AOMMIN(tx_size_wide_unit[tx_size], max_blocks_wide - blk_col);
3282*77c1e3ccSAndroid Build Coastguard Worker     RD_STATS pn_rd_stats;
3283*77c1e3ccSAndroid Build Coastguard Worker     int64_t this_rd = 0;
3284*77c1e3ccSAndroid Build Coastguard Worker     assert(txb_width > 0 && txb_height > 0);
3285*77c1e3ccSAndroid Build Coastguard Worker 
3286*77c1e3ccSAndroid Build Coastguard Worker     for (int row = 0; row < row_end; row += txb_height) {
3287*77c1e3ccSAndroid Build Coastguard Worker       const int offsetr = blk_row + row;
3288*77c1e3ccSAndroid Build Coastguard Worker       for (int col = 0; col < col_end; col += txb_width) {
3289*77c1e3ccSAndroid Build Coastguard Worker         const int offsetc = blk_col + col;
3290*77c1e3ccSAndroid Build Coastguard Worker 
3291*77c1e3ccSAndroid Build Coastguard Worker         av1_init_rd_stats(&pn_rd_stats);
3292*77c1e3ccSAndroid Build Coastguard Worker         tx_block_yrd(cpi, x, offsetr, offsetc, block, sub_txs, plane_bsize,
3293*77c1e3ccSAndroid Build Coastguard Worker                      depth + 1, above_ctx, left_ctx, tx_above, tx_left,
3294*77c1e3ccSAndroid Build Coastguard Worker                      ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
3295*77c1e3ccSAndroid Build Coastguard Worker         if (pn_rd_stats.rate == INT_MAX) {
3296*77c1e3ccSAndroid Build Coastguard Worker           av1_invalid_rd_stats(rd_stats);
3297*77c1e3ccSAndroid Build Coastguard Worker           return;
3298*77c1e3ccSAndroid Build Coastguard Worker         }
3299*77c1e3ccSAndroid Build Coastguard Worker         av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3300*77c1e3ccSAndroid Build Coastguard Worker         this_rd += RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist);
3301*77c1e3ccSAndroid Build Coastguard Worker         block += step;
3302*77c1e3ccSAndroid Build Coastguard Worker       }
3303*77c1e3ccSAndroid Build Coastguard Worker     }
3304*77c1e3ccSAndroid Build Coastguard Worker 
3305*77c1e3ccSAndroid Build Coastguard Worker     if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
3306*77c1e3ccSAndroid Build Coastguard Worker       rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][1];
3307*77c1e3ccSAndroid Build Coastguard Worker   }
3308*77c1e3ccSAndroid Build Coastguard Worker }
3309*77c1e3ccSAndroid Build Coastguard Worker 
3310*77c1e3ccSAndroid Build Coastguard Worker // search for tx type with tx sizes already decided for a inter-predicted luma
3311*77c1e3ccSAndroid Build Coastguard Worker // partition block. It's used only when some speed features are enabled.
3312*77c1e3ccSAndroid Build Coastguard Worker // Return value 0: early termination triggered, no valid rd cost available;
3313*77c1e3ccSAndroid Build Coastguard Worker //              1: rd cost values are valid.
inter_block_yrd(const AV1_COMP * cpi,MACROBLOCK * x,RD_STATS * rd_stats,BLOCK_SIZE bsize,int64_t ref_best_rd,FAST_TX_SEARCH_MODE ftxs_mode)3314*77c1e3ccSAndroid Build Coastguard Worker static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
3315*77c1e3ccSAndroid Build Coastguard Worker                            RD_STATS *rd_stats, BLOCK_SIZE bsize,
3316*77c1e3ccSAndroid Build Coastguard Worker                            int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode) {
3317*77c1e3ccSAndroid Build Coastguard Worker   if (ref_best_rd < 0) {
3318*77c1e3ccSAndroid Build Coastguard Worker     av1_invalid_rd_stats(rd_stats);
3319*77c1e3ccSAndroid Build Coastguard Worker     return 0;
3320*77c1e3ccSAndroid Build Coastguard Worker   }
3321*77c1e3ccSAndroid Build Coastguard Worker 
3322*77c1e3ccSAndroid Build Coastguard Worker   av1_init_rd_stats(rd_stats);
3323*77c1e3ccSAndroid Build Coastguard Worker 
3324*77c1e3ccSAndroid Build Coastguard Worker   MACROBLOCKD *const xd = &x->e_mbd;
3325*77c1e3ccSAndroid Build Coastguard Worker   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3326*77c1e3ccSAndroid Build Coastguard Worker   const struct macroblockd_plane *const pd = &xd->plane[0];
3327*77c1e3ccSAndroid Build Coastguard Worker   const int mi_width = mi_size_wide[bsize];
3328*77c1e3ccSAndroid Build Coastguard Worker   const int mi_height = mi_size_high[bsize];
3329*77c1e3ccSAndroid Build Coastguard Worker   const TX_SIZE max_tx_size = get_vartx_max_txsize(xd, bsize, 0);
3330*77c1e3ccSAndroid Build Coastguard Worker   const int bh = tx_size_high_unit[max_tx_size];
3331*77c1e3ccSAndroid Build Coastguard Worker   const int bw = tx_size_wide_unit[max_tx_size];
3332*77c1e3ccSAndroid Build Coastguard Worker   const int step = bw * bh;
3333*77c1e3ccSAndroid Build Coastguard Worker   const int init_depth = get_search_init_depth(
3334*77c1e3ccSAndroid Build Coastguard Worker       mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method);
3335*77c1e3ccSAndroid Build Coastguard Worker   ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
3336*77c1e3ccSAndroid Build Coastguard Worker   ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
3337*77c1e3ccSAndroid Build Coastguard Worker   TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
3338*77c1e3ccSAndroid Build Coastguard Worker   TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
3339*77c1e3ccSAndroid Build Coastguard Worker   av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
3340*77c1e3ccSAndroid Build Coastguard Worker   memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
3341*77c1e3ccSAndroid Build Coastguard Worker   memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
3342*77c1e3ccSAndroid Build Coastguard Worker 
3343*77c1e3ccSAndroid Build Coastguard Worker   int64_t this_rd = 0;
3344*77c1e3ccSAndroid Build Coastguard Worker   for (int idy = 0, block = 0; idy < mi_height; idy += bh) {
3345*77c1e3ccSAndroid Build Coastguard Worker     for (int idx = 0; idx < mi_width; idx += bw) {
3346*77c1e3ccSAndroid Build Coastguard Worker       RD_STATS pn_rd_stats;
3347*77c1e3ccSAndroid Build Coastguard Worker       av1_init_rd_stats(&pn_rd_stats);
3348*77c1e3ccSAndroid Build Coastguard Worker       tx_block_yrd(cpi, x, idy, idx, block, max_tx_size, bsize, init_depth,
3349*77c1e3ccSAndroid Build Coastguard Worker                    ctxa, ctxl, tx_above, tx_left, ref_best_rd - this_rd,
3350*77c1e3ccSAndroid Build Coastguard Worker                    &pn_rd_stats, ftxs_mode);
3351*77c1e3ccSAndroid Build Coastguard Worker       if (pn_rd_stats.rate == INT_MAX) {
3352*77c1e3ccSAndroid Build Coastguard Worker         av1_invalid_rd_stats(rd_stats);
3353*77c1e3ccSAndroid Build Coastguard Worker         return 0;
3354*77c1e3ccSAndroid Build Coastguard Worker       }
3355*77c1e3ccSAndroid Build Coastguard Worker       av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3356*77c1e3ccSAndroid Build Coastguard Worker       this_rd +=
3357*77c1e3ccSAndroid Build Coastguard Worker           AOMMIN(RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist),
3358*77c1e3ccSAndroid Build Coastguard Worker                  RDCOST(x->rdmult, pn_rd_stats.zero_rate, pn_rd_stats.sse));
3359*77c1e3ccSAndroid Build Coastguard Worker       block += step;
3360*77c1e3ccSAndroid Build Coastguard Worker     }
3361*77c1e3ccSAndroid Build Coastguard Worker   }
3362*77c1e3ccSAndroid Build Coastguard Worker 
3363*77c1e3ccSAndroid Build Coastguard Worker   const int skip_ctx = av1_get_skip_txfm_context(xd);
3364*77c1e3ccSAndroid Build Coastguard Worker   const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0];
3365*77c1e3ccSAndroid Build Coastguard Worker   const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1];
3366*77c1e3ccSAndroid Build Coastguard Worker   const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
3367*77c1e3ccSAndroid Build Coastguard Worker   this_rd =
3368*77c1e3ccSAndroid Build Coastguard Worker       RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate, rd_stats->dist);
3369*77c1e3ccSAndroid Build Coastguard Worker   if (skip_txfm_rd < this_rd) {
3370*77c1e3ccSAndroid Build Coastguard Worker     this_rd = skip_txfm_rd;
3371*77c1e3ccSAndroid Build Coastguard Worker     rd_stats->rate = 0;
3372*77c1e3ccSAndroid Build Coastguard Worker     rd_stats->dist = rd_stats->sse;
3373*77c1e3ccSAndroid Build Coastguard Worker     rd_stats->skip_txfm = 1;
3374*77c1e3ccSAndroid Build Coastguard Worker   }
3375*77c1e3ccSAndroid Build Coastguard Worker 
3376*77c1e3ccSAndroid Build Coastguard Worker   const int is_cost_valid = this_rd > ref_best_rd;
3377*77c1e3ccSAndroid Build Coastguard Worker   if (!is_cost_valid) {
3378*77c1e3ccSAndroid Build Coastguard Worker     // reset cost value
3379*77c1e3ccSAndroid Build Coastguard Worker     av1_invalid_rd_stats(rd_stats);
3380*77c1e3ccSAndroid Build Coastguard Worker   }
3381*77c1e3ccSAndroid Build Coastguard Worker   return is_cost_valid;
3382*77c1e3ccSAndroid Build Coastguard Worker }
3383*77c1e3ccSAndroid Build Coastguard Worker 
3384*77c1e3ccSAndroid Build Coastguard Worker // Search for the best transform size and type for current inter-predicted
3385*77c1e3ccSAndroid Build Coastguard Worker // luma block with recursive transform block partitioning. The obtained
3386*77c1e3ccSAndroid Build Coastguard Worker // transform selection will be saved in xd->mi[0], the corresponding RD stats
3387*77c1e3ccSAndroid Build Coastguard Worker // will be saved in rd_stats. The returned value is the corresponding RD cost.
select_tx_size_and_type(const AV1_COMP * cpi,MACROBLOCK * x,RD_STATS * rd_stats,BLOCK_SIZE bsize,int64_t ref_best_rd)3388*77c1e3ccSAndroid Build Coastguard Worker static int64_t select_tx_size_and_type(const AV1_COMP *cpi, MACROBLOCK *x,
3389*77c1e3ccSAndroid Build Coastguard Worker                                        RD_STATS *rd_stats, BLOCK_SIZE bsize,
3390*77c1e3ccSAndroid Build Coastguard Worker                                        int64_t ref_best_rd) {
3391*77c1e3ccSAndroid Build Coastguard Worker   MACROBLOCKD *const xd = &x->e_mbd;
3392*77c1e3ccSAndroid Build Coastguard Worker   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3393*77c1e3ccSAndroid Build Coastguard Worker   assert(is_inter_block(xd->mi[0]));
3394*77c1e3ccSAndroid Build Coastguard Worker   assert(bsize < BLOCK_SIZES_ALL);
3395*77c1e3ccSAndroid Build Coastguard Worker   const int fast_tx_search = txfm_params->tx_size_search_method > USE_FULL_RD;
3396*77c1e3ccSAndroid Build Coastguard Worker   int64_t rd_thresh = ref_best_rd;
3397*77c1e3ccSAndroid Build Coastguard Worker   if (rd_thresh == 0) {
3398*77c1e3ccSAndroid Build Coastguard Worker     av1_invalid_rd_stats(rd_stats);
3399*77c1e3ccSAndroid Build Coastguard Worker     return INT64_MAX;
3400*77c1e3ccSAndroid Build Coastguard Worker   }
3401*77c1e3ccSAndroid Build Coastguard Worker   if (fast_tx_search && rd_thresh < INT64_MAX) {
3402*77c1e3ccSAndroid Build Coastguard Worker     if (INT64_MAX - rd_thresh > (rd_thresh >> 3)) rd_thresh += (rd_thresh >> 3);
3403*77c1e3ccSAndroid Build Coastguard Worker   }
3404*77c1e3ccSAndroid Build Coastguard Worker   assert(rd_thresh > 0);
3405*77c1e3ccSAndroid Build Coastguard Worker   const FAST_TX_SEARCH_MODE ftxs_mode =
3406*77c1e3ccSAndroid Build Coastguard Worker       fast_tx_search ? FTXS_DCT_AND_1D_DCT_ONLY : FTXS_NONE;
3407*77c1e3ccSAndroid Build Coastguard Worker   const struct macroblockd_plane *const pd = &xd->plane[0];
3408*77c1e3ccSAndroid Build Coastguard Worker   assert(bsize < BLOCK_SIZES_ALL);
3409*77c1e3ccSAndroid Build Coastguard Worker   const int mi_width = mi_size_wide[bsize];
3410*77c1e3ccSAndroid Build Coastguard Worker   const int mi_height = mi_size_high[bsize];
3411*77c1e3ccSAndroid Build Coastguard Worker   ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
3412*77c1e3ccSAndroid Build Coastguard Worker   ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
3413*77c1e3ccSAndroid Build Coastguard Worker   TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
3414*77c1e3ccSAndroid Build Coastguard Worker   TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
3415*77c1e3ccSAndroid Build Coastguard Worker   av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
3416*77c1e3ccSAndroid Build Coastguard Worker   memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
3417*77c1e3ccSAndroid Build Coastguard Worker   memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
3418*77c1e3ccSAndroid Build Coastguard Worker   const int init_depth = get_search_init_depth(
3419*77c1e3ccSAndroid Build Coastguard Worker       mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method);
3420*77c1e3ccSAndroid Build Coastguard Worker   const TX_SIZE max_tx_size = max_txsize_rect_lookup[bsize];
3421*77c1e3ccSAndroid Build Coastguard Worker   const int bh = tx_size_high_unit[max_tx_size];
3422*77c1e3ccSAndroid Build Coastguard Worker   const int bw = tx_size_wide_unit[max_tx_size];
3423*77c1e3ccSAndroid Build Coastguard Worker   const int step = bw * bh;
3424*77c1e3ccSAndroid Build Coastguard Worker   const int skip_ctx = av1_get_skip_txfm_context(xd);
3425*77c1e3ccSAndroid Build Coastguard Worker   const int no_skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][0];
3426*77c1e3ccSAndroid Build Coastguard Worker   const int skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][1];
3427*77c1e3ccSAndroid Build Coastguard Worker   int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, 0);
3428*77c1e3ccSAndroid Build Coastguard Worker   int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_cost, 0);
3429*77c1e3ccSAndroid Build Coastguard Worker   int block = 0;
3430*77c1e3ccSAndroid Build Coastguard Worker 
3431*77c1e3ccSAndroid Build Coastguard Worker   av1_init_rd_stats(rd_stats);
3432*77c1e3ccSAndroid Build Coastguard Worker   for (int idy = 0; idy < max_block_high(xd, bsize, 0); idy += bh) {
3433*77c1e3ccSAndroid Build Coastguard Worker     for (int idx = 0; idx < max_block_wide(xd, bsize, 0); idx += bw) {
3434*77c1e3ccSAndroid Build Coastguard Worker       const int64_t best_rd_sofar =
3435*77c1e3ccSAndroid Build Coastguard Worker           (rd_thresh == INT64_MAX)
3436*77c1e3ccSAndroid Build Coastguard Worker               ? INT64_MAX
3437*77c1e3ccSAndroid Build Coastguard Worker               : (rd_thresh - (AOMMIN(skip_txfm_rd, no_skip_txfm_rd)));
3438*77c1e3ccSAndroid Build Coastguard Worker       int is_cost_valid = 1;
3439*77c1e3ccSAndroid Build Coastguard Worker       RD_STATS pn_rd_stats;
3440*77c1e3ccSAndroid Build Coastguard Worker       // Search for the best transform block size and type for the sub-block.
3441*77c1e3ccSAndroid Build Coastguard Worker       select_tx_block(cpi, x, idy, idx, block, max_tx_size, init_depth, bsize,
3442*77c1e3ccSAndroid Build Coastguard Worker                       ctxa, ctxl, tx_above, tx_left, &pn_rd_stats, INT64_MAX,
3443*77c1e3ccSAndroid Build Coastguard Worker                       best_rd_sofar, &is_cost_valid, ftxs_mode);
3444*77c1e3ccSAndroid Build Coastguard Worker       if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
3445*77c1e3ccSAndroid Build Coastguard Worker         av1_invalid_rd_stats(rd_stats);
3446*77c1e3ccSAndroid Build Coastguard Worker         return INT64_MAX;
3447*77c1e3ccSAndroid Build Coastguard Worker       }
3448*77c1e3ccSAndroid Build Coastguard Worker       av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3449*77c1e3ccSAndroid Build Coastguard Worker       skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse);
3450*77c1e3ccSAndroid Build Coastguard Worker       no_skip_txfm_rd =
3451*77c1e3ccSAndroid Build Coastguard Worker           RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist);
3452*77c1e3ccSAndroid Build Coastguard Worker       block += step;
3453*77c1e3ccSAndroid Build Coastguard Worker     }
3454*77c1e3ccSAndroid Build Coastguard Worker   }
3455*77c1e3ccSAndroid Build Coastguard Worker 
3456*77c1e3ccSAndroid Build Coastguard Worker   if (rd_stats->rate == INT_MAX) return INT64_MAX;
3457*77c1e3ccSAndroid Build Coastguard Worker 
3458*77c1e3ccSAndroid Build Coastguard Worker   rd_stats->skip_txfm = (skip_txfm_rd <= no_skip_txfm_rd);
3459*77c1e3ccSAndroid Build Coastguard Worker 
3460*77c1e3ccSAndroid Build Coastguard Worker   // If fast_tx_search is true, only DCT and 1D DCT were tested in
3461*77c1e3ccSAndroid Build Coastguard Worker   // select_inter_block_yrd() above. Do a better search for tx type with
3462*77c1e3ccSAndroid Build Coastguard Worker   // tx sizes already decided.
3463*77c1e3ccSAndroid Build Coastguard Worker   if (fast_tx_search && cpi->sf.tx_sf.refine_fast_tx_search_results) {
3464*77c1e3ccSAndroid Build Coastguard Worker     if (!inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, FTXS_NONE))
3465*77c1e3ccSAndroid Build Coastguard Worker       return INT64_MAX;
3466*77c1e3ccSAndroid Build Coastguard Worker   }
3467*77c1e3ccSAndroid Build Coastguard Worker 
3468*77c1e3ccSAndroid Build Coastguard Worker   int64_t final_rd;
3469*77c1e3ccSAndroid Build Coastguard Worker   if (rd_stats->skip_txfm) {
3470*77c1e3ccSAndroid Build Coastguard Worker     final_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse);
3471*77c1e3ccSAndroid Build Coastguard Worker   } else {
3472*77c1e3ccSAndroid Build Coastguard Worker     final_rd =
3473*77c1e3ccSAndroid Build Coastguard Worker         RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist);
3474*77c1e3ccSAndroid Build Coastguard Worker     if (!xd->lossless[xd->mi[0]->segment_id]) {
3475*77c1e3ccSAndroid Build Coastguard Worker       final_rd =
3476*77c1e3ccSAndroid Build Coastguard Worker           AOMMIN(final_rd, RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse));
3477*77c1e3ccSAndroid Build Coastguard Worker     }
3478*77c1e3ccSAndroid Build Coastguard Worker   }
3479*77c1e3ccSAndroid Build Coastguard Worker 
3480*77c1e3ccSAndroid Build Coastguard Worker   return final_rd;
3481*77c1e3ccSAndroid Build Coastguard Worker }
3482*77c1e3ccSAndroid Build Coastguard Worker 
3483*77c1e3ccSAndroid Build Coastguard Worker // Return 1 to terminate transform search early. The decision is made based on
3484*77c1e3ccSAndroid Build Coastguard Worker // the comparison with the reference RD cost and the model-estimated RD cost.
model_based_tx_search_prune(const AV1_COMP * cpi,MACROBLOCK * x,BLOCK_SIZE bsize,int64_t ref_best_rd)3485*77c1e3ccSAndroid Build Coastguard Worker static inline int model_based_tx_search_prune(const AV1_COMP *cpi,
3486*77c1e3ccSAndroid Build Coastguard Worker                                               MACROBLOCK *x, BLOCK_SIZE bsize,
3487*77c1e3ccSAndroid Build Coastguard Worker                                               int64_t ref_best_rd) {
3488*77c1e3ccSAndroid Build Coastguard Worker   const int level = cpi->sf.tx_sf.model_based_prune_tx_search_level;
3489*77c1e3ccSAndroid Build Coastguard Worker   assert(level >= 0 && level <= 2);
3490*77c1e3ccSAndroid Build Coastguard Worker   int model_rate;
3491*77c1e3ccSAndroid Build Coastguard Worker   int64_t model_dist;
3492*77c1e3ccSAndroid Build Coastguard Worker   uint8_t model_skip;
3493*77c1e3ccSAndroid Build Coastguard Worker   MACROBLOCKD *const xd = &x->e_mbd;
3494*77c1e3ccSAndroid Build Coastguard Worker   model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
3495*77c1e3ccSAndroid Build Coastguard Worker       cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip, NULL,
3496*77c1e3ccSAndroid Build Coastguard Worker       NULL, NULL, NULL);
3497*77c1e3ccSAndroid Build Coastguard Worker   if (model_skip) return 0;
3498*77c1e3ccSAndroid Build Coastguard Worker   const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
3499*77c1e3ccSAndroid Build Coastguard Worker   // TODO(debargha, urvang): Improve the model and make the check below
3500*77c1e3ccSAndroid Build Coastguard Worker   // tighter.
3501*77c1e3ccSAndroid Build Coastguard Worker   static const int prune_factor_by8[] = { 3, 5 };
3502*77c1e3ccSAndroid Build Coastguard Worker   const int factor = prune_factor_by8[level - 1];
3503*77c1e3ccSAndroid Build Coastguard Worker   return ((model_rd * factor) >> 3) > ref_best_rd;
3504*77c1e3ccSAndroid Build Coastguard Worker }
3505*77c1e3ccSAndroid Build Coastguard Worker 
// Pick the transform partitioning (recursive tx sizes) and tx types for the
// luma plane of an inter-predicted block, storing the RD stats of the chosen
// configuration in rd_stats. rd_stats is left invalid whenever the search is
// pruned or no candidate beats ref_best_rd.
//
// Shortcuts that may end the search early:
//   - model-based pruning against a finite ref_best_rd (speed feature);
//   - reuse of a cached result matched by hashing the prediction residue;
//   - a predicted "skip transform" decision.
void av1_pick_recursive_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
                                         RD_STATS *rd_stats, BLOCK_SIZE bsize,
                                         int64_t ref_best_rd) {
  MACROBLOCKD *const xd = &x->e_mbd;
  assert(is_inter_block(xd->mi[0]));

  // Early exits that do not fill rd_stats report an invalid result.
  av1_invalid_rd_stats(rd_stats);

  // Bail out when the modeled RD cost is far worse than the best so far; this
  // needs a finite reference cost to compare against.
  if (cpi->sf.tx_sf.model_based_prune_tx_search_level &&
      ref_best_rd != INT64_MAX &&
      model_based_tx_search_prune(cpi, x, bsize, ref_best_rd)) {
    return;
  }

  // The residue-hash cache is only used for blocks that lie inside the tile
  // without reaching its bottom or right edge.
  const int mi_row = xd->mi_row;
  const int mi_col = xd->mi_col;
  const int rows_inside_tile =
      mi_row >= xd->tile.mi_row_start &&
      (mi_row + mi_size_high[bsize] < xd->tile.mi_row_end);
  const int cols_inside_tile =
      mi_col >= xd->tile.mi_col_start &&
      (mi_col + mi_size_wide[bsize] < xd->tile.mi_col_end);
  const int use_rd_hash =
      rows_inside_tile && cols_inside_tile && cpi->sf.rd_sf.use_mb_rd_hash;
  const int num_blks = bsize_to_num_blk(bsize);

  uint32_t hash = 0;
  MB_RD_RECORD *mb_rd_record = NULL;
  if (use_rd_hash) {
    hash = get_block_residue_hash(x, bsize);
    mb_rd_record = x->txfm_search_info.mb_rd_record;
    const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
    if (match_index != -1) {
      // Cache hit: reuse the previously stored search results.
      fetch_mb_rd_info(num_blks, &mb_rd_record->mb_rd_info[match_index],
                       rd_stats, x);
      return;
    }
  }

  // When skipping the transform is predicted to be the best RD choice,
  // populate the skip stats, cache them, and stop.
  int64_t skip_dist;
  if (x->txfm_search_params.skip_txfm_level &&
      predict_skip_txfm(x, bsize, &skip_dist,
                        cpi->common.features.reduced_tx_set_used)) {
    set_skip_txfm(x, rd_stats, bsize, skip_dist);
    if (use_rd_hash) save_mb_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
    return;
  }

#if CONFIG_SPEED_STATS
  ++x->txfm_search_info.tx_search_count;
#endif  // CONFIG_SPEED_STATS

  if (select_tx_size_and_type(cpi, x, rd_stats, bsize, ref_best_rd) ==
      INT64_MAX) {
    // With no RD bound a candidate must always be found; the search can only
    // fail when ref_best_rd is finite and nothing beat it.
    assert(ref_best_rd != INT64_MAX);
    av1_invalid_rd_stats(rd_stats);
    return;
  }

  // Cache the search result for later blocks with an identical residue.
  if (use_rd_hash) {
    assert(mb_rd_record != NULL);
    save_mb_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
  }
}
3580*77c1e3ccSAndroid Build Coastguard Worker 
av1_pick_uniform_tx_size_type_yrd(const AV1_COMP * const cpi,MACROBLOCK * x,RD_STATS * rd_stats,BLOCK_SIZE bs,int64_t ref_best_rd)3581*77c1e3ccSAndroid Build Coastguard Worker void av1_pick_uniform_tx_size_type_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
3582*77c1e3ccSAndroid Build Coastguard Worker                                        RD_STATS *rd_stats, BLOCK_SIZE bs,
3583*77c1e3ccSAndroid Build Coastguard Worker                                        int64_t ref_best_rd) {
3584*77c1e3ccSAndroid Build Coastguard Worker   MACROBLOCKD *const xd = &x->e_mbd;
3585*77c1e3ccSAndroid Build Coastguard Worker   MB_MODE_INFO *const mbmi = xd->mi[0];
3586*77c1e3ccSAndroid Build Coastguard Worker   const TxfmSearchParams *tx_params = &x->txfm_search_params;
3587*77c1e3ccSAndroid Build Coastguard Worker   assert(bs == mbmi->bsize);
3588*77c1e3ccSAndroid Build Coastguard Worker   const int is_inter = is_inter_block(mbmi);
3589*77c1e3ccSAndroid Build Coastguard Worker   const int mi_row = xd->mi_row;
3590*77c1e3ccSAndroid Build Coastguard Worker   const int mi_col = xd->mi_col;
3591*77c1e3ccSAndroid Build Coastguard Worker 
3592*77c1e3ccSAndroid Build Coastguard Worker   av1_init_rd_stats(rd_stats);
3593*77c1e3ccSAndroid Build Coastguard Worker 
3594*77c1e3ccSAndroid Build Coastguard Worker   // Hashing based speed feature for inter blocks. If the hash of the residue
3595*77c1e3ccSAndroid Build Coastguard Worker   // block is found in the table, use previously saved search results and
3596*77c1e3ccSAndroid Build Coastguard Worker   // terminate early.
3597*77c1e3ccSAndroid Build Coastguard Worker   uint32_t hash = 0;
3598*77c1e3ccSAndroid Build Coastguard Worker   MB_RD_RECORD *mb_rd_record = NULL;
3599*77c1e3ccSAndroid Build Coastguard Worker   const int num_blks = bsize_to_num_blk(bs);
3600*77c1e3ccSAndroid Build Coastguard Worker   if (is_inter && cpi->sf.rd_sf.use_mb_rd_hash) {
3601*77c1e3ccSAndroid Build Coastguard Worker     const int within_border =
3602*77c1e3ccSAndroid Build Coastguard Worker         mi_row >= xd->tile.mi_row_start &&
3603*77c1e3ccSAndroid Build Coastguard Worker         (mi_row + mi_size_high[bs] < xd->tile.mi_row_end) &&
3604*77c1e3ccSAndroid Build Coastguard Worker         mi_col >= xd->tile.mi_col_start &&
3605*77c1e3ccSAndroid Build Coastguard Worker         (mi_col + mi_size_wide[bs] < xd->tile.mi_col_end);
3606*77c1e3ccSAndroid Build Coastguard Worker     if (within_border) {
3607*77c1e3ccSAndroid Build Coastguard Worker       hash = get_block_residue_hash(x, bs);
3608*77c1e3ccSAndroid Build Coastguard Worker       mb_rd_record = x->txfm_search_info.mb_rd_record;
3609*77c1e3ccSAndroid Build Coastguard Worker       const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
3610*77c1e3ccSAndroid Build Coastguard Worker       if (match_index != -1) {
3611*77c1e3ccSAndroid Build Coastguard Worker         MB_RD_INFO *mb_rd_info = &mb_rd_record->mb_rd_info[match_index];
3612*77c1e3ccSAndroid Build Coastguard Worker         fetch_mb_rd_info(num_blks, mb_rd_info, rd_stats, x);
3613*77c1e3ccSAndroid Build Coastguard Worker         return;
3614*77c1e3ccSAndroid Build Coastguard Worker       }
3615*77c1e3ccSAndroid Build Coastguard Worker     }
3616*77c1e3ccSAndroid Build Coastguard Worker   }
3617*77c1e3ccSAndroid Build Coastguard Worker 
3618*77c1e3ccSAndroid Build Coastguard Worker   // If we predict that skip is the optimal RD decision - set the respective
3619*77c1e3ccSAndroid Build Coastguard Worker   // context and terminate early.
3620*77c1e3ccSAndroid Build Coastguard Worker   int64_t dist;
3621*77c1e3ccSAndroid Build Coastguard Worker   if (tx_params->skip_txfm_level && is_inter &&
3622*77c1e3ccSAndroid Build Coastguard Worker       !xd->lossless[mbmi->segment_id] &&
3623*77c1e3ccSAndroid Build Coastguard Worker       predict_skip_txfm(x, bs, &dist,
3624*77c1e3ccSAndroid Build Coastguard Worker                         cpi->common.features.reduced_tx_set_used)) {
3625*77c1e3ccSAndroid Build Coastguard Worker     // Populate rdstats as per skip decision
3626*77c1e3ccSAndroid Build Coastguard Worker     set_skip_txfm(x, rd_stats, bs, dist);
3627*77c1e3ccSAndroid Build Coastguard Worker     // Save the RD search results into mb_rd_record.
3628*77c1e3ccSAndroid Build Coastguard Worker     if (mb_rd_record) {
3629*77c1e3ccSAndroid Build Coastguard Worker       save_mb_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
3630*77c1e3ccSAndroid Build Coastguard Worker     }
3631*77c1e3ccSAndroid Build Coastguard Worker     return;
3632*77c1e3ccSAndroid Build Coastguard Worker   }
3633*77c1e3ccSAndroid Build Coastguard Worker 
3634*77c1e3ccSAndroid Build Coastguard Worker   if (xd->lossless[mbmi->segment_id]) {
3635*77c1e3ccSAndroid Build Coastguard Worker     // Lossless mode can only pick the smallest (4x4) transform size.
3636*77c1e3ccSAndroid Build Coastguard Worker     choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
3637*77c1e3ccSAndroid Build Coastguard Worker   } else if (tx_params->tx_size_search_method == USE_LARGESTALL) {
3638*77c1e3ccSAndroid Build Coastguard Worker     choose_largest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
3639*77c1e3ccSAndroid Build Coastguard Worker   } else {
3640*77c1e3ccSAndroid Build Coastguard Worker     choose_tx_size_type_from_rd(cpi, x, rd_stats, ref_best_rd, bs);
3641*77c1e3ccSAndroid Build Coastguard Worker   }
3642*77c1e3ccSAndroid Build Coastguard Worker 
3643*77c1e3ccSAndroid Build Coastguard Worker   // Save the RD search results into mb_rd_record for possible reuse in future.
3644*77c1e3ccSAndroid Build Coastguard Worker   if (mb_rd_record) {
3645*77c1e3ccSAndroid Build Coastguard Worker     save_mb_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
3646*77c1e3ccSAndroid Build Coastguard Worker   }
3647*77c1e3ccSAndroid Build Coastguard Worker }
3648*77c1e3ccSAndroid Build Coastguard Worker 
// Estimates the transform RD cost of both chroma planes (U, then V) of the
// current block at the chroma transform size and accumulates it into
// |rd_stats|.
//
// Return value:
//   0 - |ref_best_rd| is negative (|rd_stats| is left freshly initialized), or
//       the search was abandoned because a plane was uncodable or the running
//       cost already exceeds |ref_best_rd| (|rd_stats| is invalidated).
//   1 - success, including the trivial case where the block is not a chroma
//       reference (xd->is_chroma_ref == 0) and nothing is searched.
int av1_txfm_uvrd(const AV1_COMP *const cpi, MACROBLOCK *x, RD_STATS *rd_stats,
                  BLOCK_SIZE bsize, int64_t ref_best_rd) {
  av1_init_rd_stats(rd_stats);
  if (ref_best_rd < 0) return 0;

  MACROBLOCKD *const xd = &x->e_mbd;
  if (!xd->is_chroma_ref) return 1;

  const int is_inter = is_inter_block(xd->mi[0]);
  const struct macroblockd_plane *const u_plane = &xd->plane[AOM_PLANE_U];
  const BLOCK_SIZE uv_bsize = get_plane_block_size(
      bsize, u_plane->subsampling_x, u_plane->subsampling_y);

  // Inter blocks need their chroma residuals formed before the search.
  if (is_inter) {
    for (int p = AOM_PLANE_U; p < MAX_MB_PLANE; ++p)
      av1_subtract_plane(x, uv_bsize, p);
  }

  // With the gating speed feature, an inter block's per-plane budget shrinks
  // by what the earlier chroma plane already cost. Intra blocks keep the full
  // budget because their RD is also used to gate subsequent modes (e.g.
  // angular modes).
  // TODO(any): Extend the early exit mechanism for intra modes as well
  const int shrink_budget =
      cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma && is_inter &&
      ref_best_rd != INT64_MAX;
  const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);

  int64_t sum_rd = 0;       // RD cost of coding the planes seen so far.
  int64_t sum_skip_rd = 0;  // RD cost of skipping them (SSE as distortion).
  for (int p = AOM_PLANE_U; p < MAX_MB_PLANE; ++p) {
    const int64_t plane_budget =
        shrink_budget ? ref_best_rd - AOMMIN(sum_rd, sum_skip_rd)
                      : ref_best_rd;
    RD_STATS plane_stats;
    av1_txfm_rd_in_plane(x, cpi, &plane_stats, plane_budget, /*current_rd=*/0,
                         p, uv_bsize, uv_tx_size, FTXS_NONE,
                         /*skip_trellis=*/0);
    if (plane_stats.rate == INT_MAX) {
      av1_invalid_rd_stats(rd_stats);
      return 0;
    }

    av1_merge_rd_stats(rd_stats, &plane_stats);
    sum_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
    sum_skip_rd = RDCOST(x->rdmult, 0, rd_stats->sse);
    if (AOMMIN(sum_rd, sum_skip_rd) > ref_best_rd) {
      av1_invalid_rd_stats(rd_stats);
      return 0;
    }
  }

  return 1;
}
3704*77c1e3ccSAndroid Build Coastguard Worker 
// Runs the transform RD search over every transform block of |plane| at the
// fixed size |tx_size| and stores the summed result in |rd_stats|.
//
// |current_rd| is the RD cost the caller has already accumulated. Nothing is
// searched and |rd_stats| is invalidated when either:
//   - a 64-point transform size is requested while enable_tx64 is off, or
//   - |current_rd| already exceeds |ref_best_rd|.
// The per-block work is done by block_rd_txfm() through
// av1_foreach_transformed_block_in_plane(). The result is invalidated when
// that walk reports an early exit: args.incomplete_exit for inter blocks,
// args.exit_early for intra blocks.
void av1_txfm_rd_in_plane(MACROBLOCK *x, const AV1_COMP *cpi,
                          RD_STATS *rd_stats, int64_t ref_best_rd,
                          int64_t current_rd, int plane, BLOCK_SIZE plane_bsize,
                          TX_SIZE tx_size, FAST_TX_SEARCH_MODE ftxs_mode,
                          int skip_trellis) {
  MACROBLOCKD *const xd = &x->e_mbd;
  // For luma, the caller must already have set the block's tx_size.
  assert(IMPLIES(plane == 0, xd->mi[0]->tx_size == tx_size));

  const int tx64_blocked = !cpi->oxcf.txfm_cfg.enable_tx64 &&
                           txsize_sqr_up_map[tx_size] == TX_64X64;
  if (tx64_blocked || current_rd > ref_best_rd) {
    av1_invalid_rd_stats(rd_stats);
    return;
  }

  struct rdcost_block_args args;
  av1_zero(args);
  args.cpi = cpi;
  args.x = x;
  args.current_rd = current_rd;
  args.best_rd = ref_best_rd;
  args.skip_trellis = skip_trellis;
  args.ftxs_mode = ftxs_mode;
  av1_init_rd_stats(&args.rd_stats);

  // Fill args.t_above / args.t_left from this plane's entropy contexts before
  // visiting the transform blocks.
  av1_get_entropy_contexts(plane_bsize, &xd->plane[plane], args.t_above,
                           args.t_left);
  av1_foreach_transformed_block_in_plane(xd, plane_bsize, plane, block_rd_txfm,
                                         &args);

  if (is_inter_block(xd->mi[0]) ? args.incomplete_exit : args.exit_early) {
    av1_invalid_rd_stats(rd_stats);
    return;
  }
  *rd_stats = args.rd_stats;
}
3749*77c1e3ccSAndroid Build Coastguard Worker 
// Performs the complete transform search for the current prediction block:
// luma first, then chroma (when the stream has more than one plane). It then
// decides whether signaling skip_txfm = 1 is at least as cheap as coding the
// coefficients.
//
// |mode_rate| is the rate already spent signaling the prediction mode; it is
// folded into |rd_stats|.
//
// On success returns 1, with:
//   - |rd_stats_y| / |rd_stats_uv| holding the luma / chroma transform stats
//     (rates zeroed and distortion = SSE if skip_txfm is chosen);
//   - |rd_stats| holding mode_rate + luma + chroma + the chosen skip_txfm flag
//     cost;
//   - mbmi->skip_txfm set to the chosen flag.
// Returns 0 as soon as the block provably cannot beat |ref_best_rd| (or luma
// is uncodable). Callers should not rely on the RD_STATS outputs in that case.
int av1_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
                    RD_STATS *rd_stats, RD_STATS *rd_stats_y,
                    RD_STATS *rd_stats_uv, int mode_rate, int64_t ref_best_rd) {
  MACROBLOCKD *const xd = &x->e_mbd;
  const TxfmSearchParams *const txfm_params = &x->txfm_search_params;
  const int skip_ctx = av1_get_skip_txfm_context(xd);
  // [0]: cost of signaling skip_txfm = 0; [1]: cost of signaling skip_txfm = 1.
  const int skip_txfm_cost[2] = { x->mode_costs.skip_txfm_cost[skip_ctx][0],
                                  x->mode_costs.skip_txfm_cost[skip_ctx][1] };
  // Whatever is decided below, the header costs at least mode_rate plus the
  // cheaper skip_txfm flag. Widen before adding so the sum is formed in 64
  // bits, matching the declared type, instead of in int.
  const int64_t min_header_rate =
      (int64_t)mode_rate + AOMMIN(skip_txfm_cost[0], skip_txfm_cost[1]);
  const int64_t min_header_rd_possible = RDCOST(x->rdmult, min_header_rate, 0);
  if (min_header_rd_possible > ref_best_rd) {
    av1_invalid_rd_stats(rd_stats_y);
    return 0;
  }

  const AV1_COMMON *const cm = &cpi->common;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int64_t mode_rd = RDCOST(x->rdmult, mode_rate, 0);
  // RD budget left for the luma search once the mode cost is paid.
  const int64_t rd_thresh =
      ref_best_rd == INT64_MAX ? INT64_MAX : ref_best_rd - mode_rd;
  av1_init_rd_stats(rd_stats);
  av1_init_rd_stats(rd_stats_y);
  rd_stats->rate = mode_rate;

  // Luma: form the residual, then search transform sizes and types.
  av1_subtract_plane(x, bsize, 0);
  if (txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
      !xd->lossless[mbmi->segment_id]) {
    av1_pick_recursive_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
    // TODO(review): CONFIG_COLLECT_RD_STATS == 2 per-prediction-unit stats
    // (PrintPredictionUnitStats) are not collected here. That hook needs the
    // TileDataEnc, which this function does not receive; calling it with an
    // undeclared `tile_data` prevented that configuration from compiling.
  } else {
    // Uniform search: a single tx_size covers the block. Record it in every
    // mbmi->inter_tx_size entry, and copy the luma skip decision into all
    // xd->height * xd->width blk_skip entries.
    av1_pick_uniform_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
    memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
    for (int i = 0; i < xd->height * xd->width; ++i)
      set_blk_skip(x->txfm_search_info.blk_skip, 0, i, rd_stats_y->skip_txfm);
  }

  if (rd_stats_y->rate == INT_MAX) return 0;

  av1_merge_rd_stats(rd_stats, rd_stats_y);

  // Best-case RD after luma: either code the luma coefficients
  // (skip_txfm = 0), or skip them and pay the luma SSE as distortion
  // (skip_txfm = 1).
  const int64_t non_skip_txfm_rdcosty =
      RDCOST(x->rdmult, rd_stats->rate + skip_txfm_cost[0], rd_stats->dist);
  const int64_t skip_txfm_rdcosty =
      RDCOST(x->rdmult, mode_rate + skip_txfm_cost[1], rd_stats->sse);
  const int64_t min_rdcosty = AOMMIN(non_skip_txfm_rdcosty, skip_txfm_rdcosty);
  if (min_rdcosty > ref_best_rd) return 0;

  // Chroma.
  av1_init_rd_stats(rd_stats_uv);
  const int num_planes = av1_num_planes(cm);
  if (num_planes > 1) {
    // With the gating speed feature, chroma only gets the budget left after
    // the best-case cost so far.
    int64_t ref_best_chroma_rd = ref_best_rd;
    if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma &&
        (ref_best_chroma_rd != INT64_MAX)) {
      ref_best_chroma_rd -= min_rdcosty;
    }
    const int is_cost_valid_uv =
        av1_txfm_uvrd(cpi, x, rd_stats_uv, bsize, ref_best_chroma_rd);
    if (!is_cost_valid_uv) return 0;
    av1_merge_rd_stats(rd_stats, rd_stats_uv);
  }

  // Choose skip_txfm = 1 when it is at least as cheap as coding the
  // coefficients. mode_rate appears on both sides, so it is left out of the
  // comparison. Lossless blocks take the skip path only if
  // rd_stats->skip_txfm is already set.
  int choose_skip_txfm = rd_stats->skip_txfm;
  if (!choose_skip_txfm && !xd->lossless[mbmi->segment_id]) {
    const int64_t rdcost_no_skip_txfm = RDCOST(
        x->rdmult, rd_stats_y->rate + rd_stats_uv->rate + skip_txfm_cost[0],
        rd_stats->dist);
    const int64_t rdcost_skip_txfm =
        RDCOST(x->rdmult, skip_txfm_cost[1], rd_stats->sse);
    if (rdcost_no_skip_txfm >= rdcost_skip_txfm) choose_skip_txfm = 1;
  }
  if (choose_skip_txfm) {
    rd_stats_y->rate = 0;
    rd_stats_uv->rate = 0;
    rd_stats->rate = mode_rate + skip_txfm_cost[1];
    rd_stats->dist = rd_stats->sse;
    rd_stats_y->dist = rd_stats_y->sse;
    rd_stats_uv->dist = rd_stats_uv->sse;
    mbmi->skip_txfm = 1;
    // Re-check the budget only when the skip was forced by
    // rd_stats->skip_txfm rather than chosen by the comparison above.
    if (rd_stats->skip_txfm) {
      const int64_t tmprd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
      if (tmprd > ref_best_rd) return 0;
    }
  } else {
    rd_stats->rate += skip_txfm_cost[0];
    mbmi->skip_txfm = 0;
  }

  return 1;
}
3847