xref: /aosp_15_r20/external/libaom/aom_dsp/x86/highbd_variance_sse2.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <assert.h>
13 #include <emmintrin.h>  // SSE2
14 
15 #include "config/aom_config.h"
16 #include "config/aom_dsp_rtcd.h"
17 
18 #include "aom_dsp/x86/synonyms.h"
19 #include "aom_ports/mem.h"
20 
21 #include "av1/common/filter.h"
22 #include "av1/common/reconinter.h"
23 
// Signature shared by the per-block variance kernels: each call computes the
// sum of squared errors (*sse) and the signed sum of pixel differences (*sum)
// for one fixed-size block of 16-bit pixels.
typedef uint32_t (*high_variance_fn_t)(const uint16_t *src, int src_stride,
                                       const uint16_t *ref, int ref_stride,
                                       uint32_t *sse, int *sum);

// 8x8 block kernel — defined outside this file (presumably in the SSE2
// variance implementation .asm/.c; confirm against the build).
uint32_t aom_highbd_calc8x8var_sse2(const uint16_t *src, int src_stride,
                                    const uint16_t *ref, int ref_stride,
                                    uint32_t *sse, int *sum);

// 16x16 block kernel — same contract as the 8x8 version.
uint32_t aom_highbd_calc16x16var_sse2(const uint16_t *src, int src_stride,
                                      const uint16_t *ref, int ref_stride,
                                      uint32_t *sse, int *sum);
35 
// Accumulates 8-bit-depth SSE and sum over a w x h region by tiling it with
// block_size x block_size invocations of var_fn.  Totals are written through
// *sse and *sum; no scaling is applied for 8-bit content.
static void highbd_8_variance_sse2(const uint16_t *src, int src_stride,
                                   const uint16_t *ref, int ref_stride, int w,
                                   int h, uint32_t *sse, int *sum,
                                   high_variance_fn_t var_fn, int block_size) {
  *sse = 0;
  *sum = 0;

  for (int row = 0; row < h; row += block_size) {
    for (int col = 0; col < w; col += block_size) {
      uint32_t block_sse;
      int block_sum;
      var_fn(src + src_stride * row + col, src_stride,
             ref + ref_stride * row + col, ref_stride, &block_sse, &block_sum);
      *sse += block_sse;
      *sum += block_sum;
    }
  }
}
56 
// 10-bit-depth variant of the tiled accumulator.  SSE is accumulated in 64
// bits to avoid overflow, then both totals are rounded back to an 8-bit
// basis: the sum by 2 bits, the sum of squares by 4 bits.
static void highbd_10_variance_sse2(const uint16_t *src, int src_stride,
                                    const uint16_t *ref, int ref_stride, int w,
                                    int h, uint32_t *sse, int *sum,
                                    high_variance_fn_t var_fn, int block_size) {
  uint64_t total_sse = 0;
  int32_t total_sum = 0;

  for (int row = 0; row < h; row += block_size) {
    for (int col = 0; col < w; col += block_size) {
      uint32_t block_sse;
      int block_sum;
      var_fn(src + src_stride * row + col, src_stride,
             ref + ref_stride * row + col, ref_stride, &block_sse, &block_sum);
      total_sse += block_sse;
      total_sum += block_sum;
    }
  }
  *sum = ROUND_POWER_OF_TWO(total_sum, 2);
  *sse = (uint32_t)ROUND_POWER_OF_TWO(total_sse, 4);
}
78 
// 12-bit-depth variant of the tiled accumulator.  Identical structure to the
// 10-bit version, but the final rounding is 4 bits for the sum and 8 bits
// for the sum of squares.
static void highbd_12_variance_sse2(const uint16_t *src, int src_stride,
                                    const uint16_t *ref, int ref_stride, int w,
                                    int h, uint32_t *sse, int *sum,
                                    high_variance_fn_t var_fn, int block_size) {
  uint64_t total_sse = 0;
  int32_t total_sum = 0;

  for (int row = 0; row < h; row += block_size) {
    for (int col = 0; col < w; col += block_size) {
      uint32_t block_sse;
      int block_sum;
      var_fn(src + src_stride * row + col, src_stride,
             ref + ref_stride * row + col, ref_stride, &block_sse, &block_sum);
      total_sse += block_sse;
      total_sum += block_sum;
    }
  }
  *sum = ROUND_POWER_OF_TWO(total_sum, 4);
  *sse = (uint32_t)ROUND_POWER_OF_TWO(total_sse, 8);
}
100 
/* Generates the 8-, 10- and 12-bit-depth variance functions for a w x h
 * block.  Each tiles the region with block_size x block_size kernel calls
 * and then subtracts the squared-mean term (sum * sum) >> shift, where
 * shift == log2(w * h).  Because the 10-/12-bit variants operate on rounded
 * accumulator values, the difference can come out slightly negative; those
 * variants clamp the result to zero. */
#define VAR_FN(w, h, block_size, shift)                                    \
  uint32_t aom_highbd_8_variance##w##x##h##_sse2(                          \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,            \
      int ref_stride, uint32_t *sse) {                                     \
    int total_sum;                                                         \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                             \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                             \
    highbd_8_variance_sse2(                                                \
        src, src_stride, ref, ref_stride, w, h, sse, &total_sum,           \
        aom_highbd_calc##block_size##x##block_size##var_sse2, block_size); \
    return *sse -                                                          \
           (uint32_t)(((int64_t)total_sum * total_sum) >> shift);          \
  }                                                                        \
                                                                           \
  uint32_t aom_highbd_10_variance##w##x##h##_sse2(                         \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,            \
      int ref_stride, uint32_t *sse) {                                     \
    int total_sum;                                                         \
    int64_t diff;                                                          \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                             \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                             \
    highbd_10_variance_sse2(                                               \
        src, src_stride, ref, ref_stride, w, h, sse, &total_sum,           \
        aom_highbd_calc##block_size##x##block_size##var_sse2, block_size); \
    diff = (int64_t)(*sse) -                                               \
           (((int64_t)total_sum * total_sum) >> shift);                    \
    return (diff < 0) ? 0 : (uint32_t)diff;                                \
  }                                                                        \
                                                                           \
  uint32_t aom_highbd_12_variance##w##x##h##_sse2(                         \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,            \
      int ref_stride, uint32_t *sse) {                                     \
    int total_sum;                                                         \
    int64_t diff;                                                          \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                             \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                             \
    highbd_12_variance_sse2(                                               \
        src, src_stride, ref, ref_stride, w, h, sse, &total_sum,           \
        aom_highbd_calc##block_size##x##block_size##var_sse2, block_size); \
    diff = (int64_t)(*sse) -                                               \
           (((int64_t)total_sum * total_sum) >> shift);                    \
    return (diff < 0) ? 0 : (uint32_t)diff;                                \
  }
141 
// Instantiate the variance functions for every supported block geometry.
// The final argument is log2(w * h); blocks at least 16 wide use the 16x16
// kernel, narrower ones use the 8x8 kernel.
VAR_FN(128, 128, 16, 14)
VAR_FN(128, 64, 16, 13)
VAR_FN(64, 128, 16, 13)
VAR_FN(64, 64, 16, 12)
VAR_FN(64, 32, 16, 11)
VAR_FN(32, 64, 16, 11)
VAR_FN(32, 32, 16, 10)
VAR_FN(32, 16, 16, 9)
VAR_FN(16, 32, 16, 9)
VAR_FN(16, 16, 16, 8)
VAR_FN(16, 8, 8, 7)
VAR_FN(8, 16, 8, 7)
VAR_FN(8, 8, 8, 6)

// The extreme aspect-ratio sizes are only needed outside realtime-only
// builds.
#if !CONFIG_REALTIME_ONLY
VAR_FN(8, 32, 8, 8)
VAR_FN(32, 8, 8, 8)
VAR_FN(16, 64, 16, 10)
VAR_FN(64, 16, 16, 10)
#endif  // !CONFIG_REALTIME_ONLY

#undef VAR_FN
164 
aom_highbd_8_mse16x16_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)165 unsigned int aom_highbd_8_mse16x16_sse2(const uint8_t *src8, int src_stride,
166                                         const uint8_t *ref8, int ref_stride,
167                                         unsigned int *sse) {
168   int sum;
169   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
170   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
171   highbd_8_variance_sse2(src, src_stride, ref, ref_stride, 16, 16, sse, &sum,
172                          aom_highbd_calc16x16var_sse2, 16);
173   return *sse;
174 }
175 
aom_highbd_10_mse16x16_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)176 unsigned int aom_highbd_10_mse16x16_sse2(const uint8_t *src8, int src_stride,
177                                          const uint8_t *ref8, int ref_stride,
178                                          unsigned int *sse) {
179   int sum;
180   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
181   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
182   highbd_10_variance_sse2(src, src_stride, ref, ref_stride, 16, 16, sse, &sum,
183                           aom_highbd_calc16x16var_sse2, 16);
184   return *sse;
185 }
186 
aom_highbd_12_mse16x16_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)187 unsigned int aom_highbd_12_mse16x16_sse2(const uint8_t *src8, int src_stride,
188                                          const uint8_t *ref8, int ref_stride,
189                                          unsigned int *sse) {
190   int sum;
191   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
192   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
193   highbd_12_variance_sse2(src, src_stride, ref, ref_stride, 16, 16, sse, &sum,
194                           aom_highbd_calc16x16var_sse2, 16);
195   return *sse;
196 }
197 
aom_highbd_8_mse8x8_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)198 unsigned int aom_highbd_8_mse8x8_sse2(const uint8_t *src8, int src_stride,
199                                       const uint8_t *ref8, int ref_stride,
200                                       unsigned int *sse) {
201   int sum;
202   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
203   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
204   highbd_8_variance_sse2(src, src_stride, ref, ref_stride, 8, 8, sse, &sum,
205                          aom_highbd_calc8x8var_sse2, 8);
206   return *sse;
207 }
208 
aom_highbd_10_mse8x8_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)209 unsigned int aom_highbd_10_mse8x8_sse2(const uint8_t *src8, int src_stride,
210                                        const uint8_t *ref8, int ref_stride,
211                                        unsigned int *sse) {
212   int sum;
213   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
214   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
215   highbd_10_variance_sse2(src, src_stride, ref, ref_stride, 8, 8, sse, &sum,
216                           aom_highbd_calc8x8var_sse2, 8);
217   return *sse;
218 }
219 
aom_highbd_12_mse8x8_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)220 unsigned int aom_highbd_12_mse8x8_sse2(const uint8_t *src8, int src_stride,
221                                        const uint8_t *ref8, int ref_stride,
222                                        unsigned int *sse) {
223   int sum;
224   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
225   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
226   highbd_12_variance_sse2(src, src_stride, ref, ref_stride, 8, 8, sse, &sum,
227                           aom_highbd_calc8x8var_sse2, 8);
228   return *sse;
229 }
230 
// The 2 unused parameters are place holders for PIC enabled build.
// These definitions are for functions defined in
// highbd_subpel_variance_impl_sse2.asm
// Each helper processes a w-wide column strip of `height` rows at the given
// (x_offset, y_offset) sub-pixel position, returning the sum of differences
// and writing the sum of squared errors through *sse.
#define DECL(w, opt)                                                         \
  int aom_highbd_sub_pixel_variance##w##xh_##opt(                            \
      const uint16_t *src, ptrdiff_t src_stride, int x_offset, int y_offset, \
      const uint16_t *dst, ptrdiff_t dst_stride, int height,                 \
      unsigned int *sse, void *unused0, void *unused);
#define DECLS(opt) \
  DECL(8, opt)     \
  DECL(16, opt)

DECLS(sse2)

#undef DECLS
#undef DECL
247 
// Emits the 8-, 10- and 12-bit-depth sub-pixel variance functions for a
// w x h block on top of the wf-wide assembly helpers declared above.
// - Widths greater than wf are covered by running the helper on 2 or 4
//   horizontal strips; widths greater than 64 (i.e. 128) additionally
//   repeat the whole strip pass per 64-wide half (row_rep).
// - wlog2/hlog2 are log2(w)/log2(h), so (cast se * se) >> (wlog2 + hlog2)
//   is the squared-mean term sum^2 / (w * h).
// - The 10-bit variant accumulates SSE in 64 bits and rounds sum/SSE by
//   2/4 bits; the 12-bit variant rounds by 4/8 bits and also processes at
//   most 16 rows per helper call — presumably to keep the helper's internal
//   accumulators from overflowing at 12-bit depth; confirm against the .asm.
// - The 10-/12-bit variants clamp a negative rounded result to zero; the
//   8-bit variant subtracts directly, as in the plain variance functions.
#define FN(w, h, wf, wlog2, hlog2, opt, cast)                                  \
  uint32_t aom_highbd_8_sub_pixel_variance##w##x##h##_##opt(                   \
      const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
      const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr) {                \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
    uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                                 \
    int se = 0;                                                                \
    unsigned int sse = 0;                                                      \
    unsigned int sse2;                                                         \
    int row_rep = (w > 64) ? 2 : 1;                                            \
    for (int wd_64 = 0; wd_64 < row_rep; wd_64++) {                            \
      src += wd_64 * 64;                                                       \
      dst += wd_64 * 64;                                                       \
      int se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                   \
          src, src_stride, x_offset, y_offset, dst, dst_stride, h, &sse2,      \
          NULL, NULL);                                                         \
      se += se2;                                                               \
      sse += sse2;                                                             \
      if (w > wf) {                                                            \
        se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                     \
            src + wf, src_stride, x_offset, y_offset, dst + wf, dst_stride, h, \
            &sse2, NULL, NULL);                                                \
        se += se2;                                                             \
        sse += sse2;                                                           \
        if (w > wf * 2) {                                                      \
          se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                   \
              src + 2 * wf, src_stride, x_offset, y_offset, dst + 2 * wf,      \
              dst_stride, h, &sse2, NULL, NULL);                               \
          se += se2;                                                           \
          sse += sse2;                                                         \
          se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                   \
              src + 3 * wf, src_stride, x_offset, y_offset, dst + 3 * wf,      \
              dst_stride, h, &sse2, NULL, NULL);                               \
          se += se2;                                                           \
          sse += sse2;                                                         \
        }                                                                      \
      }                                                                        \
    }                                                                          \
    *sse_ptr = sse;                                                            \
    return sse - (uint32_t)((cast se * se) >> (wlog2 + hlog2));                \
  }                                                                            \
                                                                               \
  uint32_t aom_highbd_10_sub_pixel_variance##w##x##h##_##opt(                  \
      const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
      const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr) {                \
    int64_t var;                                                               \
    uint32_t sse;                                                              \
    uint64_t long_sse = 0;                                                     \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
    uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                                 \
    int se = 0;                                                                \
    int row_rep = (w > 64) ? 2 : 1;                                            \
    for (int wd_64 = 0; wd_64 < row_rep; wd_64++) {                            \
      src += wd_64 * 64;                                                       \
      dst += wd_64 * 64;                                                       \
      int se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                   \
          src, src_stride, x_offset, y_offset, dst, dst_stride, h, &sse, NULL, \
          NULL);                                                               \
      se += se2;                                                               \
      long_sse += sse;                                                         \
      if (w > wf) {                                                            \
        uint32_t sse2;                                                         \
        se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                     \
            src + wf, src_stride, x_offset, y_offset, dst + wf, dst_stride, h, \
            &sse2, NULL, NULL);                                                \
        se += se2;                                                             \
        long_sse += sse2;                                                      \
        if (w > wf * 2) {                                                      \
          se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                   \
              src + 2 * wf, src_stride, x_offset, y_offset, dst + 2 * wf,      \
              dst_stride, h, &sse2, NULL, NULL);                               \
          se += se2;                                                           \
          long_sse += sse2;                                                    \
          se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                   \
              src + 3 * wf, src_stride, x_offset, y_offset, dst + 3 * wf,      \
              dst_stride, h, &sse2, NULL, NULL);                               \
          se += se2;                                                           \
          long_sse += sse2;                                                    \
        }                                                                      \
      }                                                                        \
    }                                                                          \
    se = ROUND_POWER_OF_TWO(se, 2);                                            \
    sse = (uint32_t)ROUND_POWER_OF_TWO(long_sse, 4);                           \
    *sse_ptr = sse;                                                            \
    var = (int64_t)(sse) - ((cast se * se) >> (wlog2 + hlog2));                \
    return (var >= 0) ? (uint32_t)var : 0;                                     \
  }                                                                            \
                                                                               \
  uint32_t aom_highbd_12_sub_pixel_variance##w##x##h##_##opt(                  \
      const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
      const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr) {                \
    int start_row;                                                             \
    uint32_t sse;                                                              \
    int se = 0;                                                                \
    int64_t var;                                                               \
    uint64_t long_sse = 0;                                                     \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
    uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                                 \
    int row_rep = (w > 64) ? 2 : 1;                                            \
    for (start_row = 0; start_row < h; start_row += 16) {                      \
      uint32_t sse2;                                                           \
      int height = h - start_row < 16 ? h - start_row : 16;                    \
      uint16_t *src_tmp = src + (start_row * src_stride);                      \
      uint16_t *dst_tmp = dst + (start_row * dst_stride);                      \
      for (int wd_64 = 0; wd_64 < row_rep; wd_64++) {                          \
        src_tmp += wd_64 * 64;                                                 \
        dst_tmp += wd_64 * 64;                                                 \
        int se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                 \
            src_tmp, src_stride, x_offset, y_offset, dst_tmp, dst_stride,      \
            height, &sse2, NULL, NULL);                                        \
        se += se2;                                                             \
        long_sse += sse2;                                                      \
        if (w > wf) {                                                          \
          se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                   \
              src_tmp + wf, src_stride, x_offset, y_offset, dst_tmp + wf,      \
              dst_stride, height, &sse2, NULL, NULL);                          \
          se += se2;                                                           \
          long_sse += sse2;                                                    \
          if (w > wf * 2) {                                                    \
            se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                 \
                src_tmp + 2 * wf, src_stride, x_offset, y_offset,              \
                dst_tmp + 2 * wf, dst_stride, height, &sse2, NULL, NULL);      \
            se += se2;                                                         \
            long_sse += sse2;                                                  \
            se2 = aom_highbd_sub_pixel_variance##wf##xh_##opt(                 \
                src_tmp + 3 * wf, src_stride, x_offset, y_offset,              \
                dst_tmp + 3 * wf, dst_stride, height, &sse2, NULL, NULL);      \
            se += se2;                                                         \
            long_sse += sse2;                                                  \
          }                                                                    \
        }                                                                      \
      }                                                                        \
    }                                                                          \
    se = ROUND_POWER_OF_TWO(se, 4);                                            \
    sse = (uint32_t)ROUND_POWER_OF_TWO(long_sse, 8);                           \
    *sse_ptr = sse;                                                            \
    var = (int64_t)(sse) - ((cast se * se) >> (wlog2 + hlog2));                \
    return (var >= 0) ? (uint32_t)var : 0;                                     \
  }
387 
// Instantiation tables for FN.  Arguments: (w, h, wf, log2(w), log2(h),
// isa suffix, cast applied to the sum-of-errors before squaring).  The
// non-realtime build adds the extreme aspect-ratio block sizes.
#if CONFIG_REALTIME_ONLY
#define FNS(opt)                         \
  FN(128, 128, 16, 7, 7, opt, (int64_t)) \
  FN(128, 64, 16, 7, 6, opt, (int64_t))  \
  FN(64, 128, 16, 6, 7, opt, (int64_t))  \
  FN(64, 64, 16, 6, 6, opt, (int64_t))   \
  FN(64, 32, 16, 6, 5, opt, (int64_t))   \
  FN(32, 64, 16, 5, 6, opt, (int64_t))   \
  FN(32, 32, 16, 5, 5, opt, (int64_t))   \
  FN(32, 16, 16, 5, 4, opt, (int64_t))   \
  FN(16, 32, 16, 4, 5, opt, (int64_t))   \
  FN(16, 16, 16, 4, 4, opt, (int64_t))   \
  FN(16, 8, 16, 4, 3, opt, (int64_t))    \
  FN(8, 16, 8, 3, 4, opt, (int64_t))     \
  FN(8, 8, 8, 3, 3, opt, (int64_t))      \
  FN(8, 4, 8, 3, 2, opt, (int64_t))
#else  // !CONFIG_REALTIME_ONLY
#define FNS(opt)                         \
  FN(128, 128, 16, 7, 7, opt, (int64_t)) \
  FN(128, 64, 16, 7, 6, opt, (int64_t))  \
  FN(64, 128, 16, 6, 7, opt, (int64_t))  \
  FN(64, 64, 16, 6, 6, opt, (int64_t))   \
  FN(64, 32, 16, 6, 5, opt, (int64_t))   \
  FN(32, 64, 16, 5, 6, opt, (int64_t))   \
  FN(32, 32, 16, 5, 5, opt, (int64_t))   \
  FN(32, 16, 16, 5, 4, opt, (int64_t))   \
  FN(16, 32, 16, 4, 5, opt, (int64_t))   \
  FN(16, 16, 16, 4, 4, opt, (int64_t))   \
  FN(16, 8, 16, 4, 3, opt, (int64_t))    \
  FN(8, 16, 8, 3, 4, opt, (int64_t))     \
  FN(8, 8, 8, 3, 3, opt, (int64_t))      \
  FN(8, 4, 8, 3, 2, opt, (int64_t))      \
  FN(16, 4, 16, 4, 2, opt, (int64_t))    \
  FN(8, 32, 8, 3, 5, opt, (int64_t))     \
  FN(32, 8, 16, 5, 3, opt, (int64_t))    \
  FN(16, 64, 16, 4, 6, opt, (int64_t))   \
  FN(64, 16, 16, 6, 4, opt, (int64_t))
#endif  // CONFIG_REALTIME_ONLY
426 
// Expand the table once for the SSE2 ISA suffix.
FNS(sse2)

#undef FNS
#undef FN
431 
// The 2 unused parameters are place holders for PIC enabled build.
// Declarations for the compound (second-prediction averaging) sub-pixel
// variance helpers — defined outside this file, presumably in the matching
// .asm implementation; confirm against the build.
#define DECL(w, opt)                                                         \
  int aom_highbd_sub_pixel_avg_variance##w##xh_##opt(                        \
      const uint16_t *src, ptrdiff_t src_stride, int x_offset, int y_offset, \
      const uint16_t *dst, ptrdiff_t dst_stride, const uint16_t *sec,        \
      ptrdiff_t sec_stride, int height, unsigned int *sse, void *unused0,    \
      void *unused);
#define DECLS(opt) \
  DECL(16, opt)    \
  DECL(8, opt)

DECLS(sse2)
#undef DECL
#undef DECLS
446 
447 #define FN(w, h, wf, wlog2, hlog2, opt, cast)                                  \
448   uint32_t aom_highbd_8_sub_pixel_avg_variance##w##x##h##_##opt(               \
449       const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
450       const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr,                  \
451       const uint8_t *sec8) {                                                   \
452     uint32_t sse;                                                              \
453     uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
454     uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                                 \
455     uint16_t *sec = CONVERT_TO_SHORTPTR(sec8);                                 \
456     int se = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                  \
457         src, src_stride, x_offset, y_offset, dst, dst_stride, sec, w, h, &sse, \
458         NULL, NULL);                                                           \
459     if (w > wf) {                                                              \
460       uint32_t sse2;                                                           \
461       int se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
462           src + wf, src_stride, x_offset, y_offset, dst + wf, dst_stride,      \
463           sec + wf, w, h, &sse2, NULL, NULL);                                  \
464       se += se2;                                                               \
465       sse += sse2;                                                             \
466       if (w > wf * 2) {                                                        \
467         se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
468             src + 2 * wf, src_stride, x_offset, y_offset, dst + 2 * wf,        \
469             dst_stride, sec + 2 * wf, w, h, &sse2, NULL, NULL);                \
470         se += se2;                                                             \
471         sse += sse2;                                                           \
472         se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
473             src + 3 * wf, src_stride, x_offset, y_offset, dst + 3 * wf,        \
474             dst_stride, sec + 3 * wf, w, h, &sse2, NULL, NULL);                \
475         se += se2;                                                             \
476         sse += sse2;                                                           \
477       }                                                                        \
478     }                                                                          \
479     *sse_ptr = sse;                                                            \
480     return sse - (uint32_t)((cast se * se) >> (wlog2 + hlog2));                \
481   }                                                                            \
482                                                                                \
483   uint32_t aom_highbd_10_sub_pixel_avg_variance##w##x##h##_##opt(              \
484       const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
485       const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr,                  \
486       const uint8_t *sec8) {                                                   \
487     int64_t var;                                                               \
488     uint32_t sse;                                                              \
489     uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
490     uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                                 \
491     uint16_t *sec = CONVERT_TO_SHORTPTR(sec8);                                 \
492     int se = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                  \
493         src, src_stride, x_offset, y_offset, dst, dst_stride, sec, w, h, &sse, \
494         NULL, NULL);                                                           \
495     if (w > wf) {                                                              \
496       uint32_t sse2;                                                           \
497       int se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
498           src + wf, src_stride, x_offset, y_offset, dst + wf, dst_stride,      \
499           sec + wf, w, h, &sse2, NULL, NULL);                                  \
500       se += se2;                                                               \
501       sse += sse2;                                                             \
502       if (w > wf * 2) {                                                        \
503         se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
504             src + 2 * wf, src_stride, x_offset, y_offset, dst + 2 * wf,        \
505             dst_stride, sec + 2 * wf, w, h, &sse2, NULL, NULL);                \
506         se += se2;                                                             \
507         sse += sse2;                                                           \
508         se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
509             src + 3 * wf, src_stride, x_offset, y_offset, dst + 3 * wf,        \
510             dst_stride, sec + 3 * wf, w, h, &sse2, NULL, NULL);                \
511         se += se2;                                                             \
512         sse += sse2;                                                           \
513       }                                                                        \
514     }                                                                          \
515     se = ROUND_POWER_OF_TWO(se, 2);                                            \
516     sse = ROUND_POWER_OF_TWO(sse, 4);                                          \
517     *sse_ptr = sse;                                                            \
518     var = (int64_t)(sse) - ((cast se * se) >> (wlog2 + hlog2));                \
519     return (var >= 0) ? (uint32_t)var : 0;                                     \
520   }                                                                            \
521                                                                                \
522   uint32_t aom_highbd_12_sub_pixel_avg_variance##w##x##h##_##opt(              \
523       const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
524       const uint8_t *dst8, int dst_stride, uint32_t *sse_ptr,                  \
525       const uint8_t *sec8) {                                                   \
526     int start_row;                                                             \
527     int64_t var;                                                               \
528     uint32_t sse;                                                              \
529     int se = 0;                                                                \
530     uint64_t long_sse = 0;                                                     \
531     uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
532     uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);                                 \
533     uint16_t *sec = CONVERT_TO_SHORTPTR(sec8);                                 \
534     for (start_row = 0; start_row < h; start_row += 16) {                      \
535       uint32_t sse2;                                                           \
536       int height = h - start_row < 16 ? h - start_row : 16;                    \
537       int se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
538           src + (start_row * src_stride), src_stride, x_offset, y_offset,      \
539           dst + (start_row * dst_stride), dst_stride, sec + (start_row * w),   \
540           w, height, &sse2, NULL, NULL);                                       \
541       se += se2;                                                               \
542       long_sse += sse2;                                                        \
543       if (w > wf) {                                                            \
544         se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
545             src + wf + (start_row * src_stride), src_stride, x_offset,         \
546             y_offset, dst + wf + (start_row * dst_stride), dst_stride,         \
547             sec + wf + (start_row * w), w, height, &sse2, NULL, NULL);         \
548         se += se2;                                                             \
549         long_sse += sse2;                                                      \
550         if (w > wf * 2) {                                                      \
551           se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
552               src + 2 * wf + (start_row * src_stride), src_stride, x_offset,   \
553               y_offset, dst + 2 * wf + (start_row * dst_stride), dst_stride,   \
554               sec + 2 * wf + (start_row * w), w, height, &sse2, NULL, NULL);   \
555           se += se2;                                                           \
556           long_sse += sse2;                                                    \
557           se2 = aom_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
558               src + 3 * wf + (start_row * src_stride), src_stride, x_offset,   \
559               y_offset, dst + 3 * wf + (start_row * dst_stride), dst_stride,   \
560               sec + 3 * wf + (start_row * w), w, height, &sse2, NULL, NULL);   \
561           se += se2;                                                           \
562           long_sse += sse2;                                                    \
563         }                                                                      \
564       }                                                                        \
565     }                                                                          \
566     se = ROUND_POWER_OF_TWO(se, 4);                                            \
567     sse = (uint32_t)ROUND_POWER_OF_TWO(long_sse, 8);                           \
568     *sse_ptr = sse;                                                            \
569     var = (int64_t)(sse) - ((cast se * se) >> (wlog2 + hlog2));                \
570     return (var >= 0) ? (uint32_t)var : 0;                                     \
571   }
572 
// Instantiate the highbd sub-pixel average variance functions for every
// supported block size via the FN macro defined above.  FN's arguments:
// block width, block height, the column width each kernel call covers (wf),
// log2(width), log2(height), the ISA suffix, and the cast applied to the
// se * se product before it is shifted down by (wlog2 + hlog2).
#if CONFIG_REALTIME_ONLY
// Realtime-only builds omit the 4:1 / 1:4 rectangular block sizes listed in
// the #else branch below.
#define FNS(opt)                       \
  FN(64, 64, 16, 6, 6, opt, (int64_t)) \
  FN(64, 32, 16, 6, 5, opt, (int64_t)) \
  FN(32, 64, 16, 5, 6, opt, (int64_t)) \
  FN(32, 32, 16, 5, 5, opt, (int64_t)) \
  FN(32, 16, 16, 5, 4, opt, (int64_t)) \
  FN(16, 32, 16, 4, 5, opt, (int64_t)) \
  FN(16, 16, 16, 4, 4, opt, (int64_t)) \
  FN(16, 8, 16, 4, 3, opt, (int64_t))  \
  FN(8, 16, 8, 3, 4, opt, (int64_t))   \
  FN(8, 8, 8, 3, 3, opt, (int64_t))    \
  FN(8, 4, 8, 3, 2, opt, (int64_t))
#else  // !CONFIG_REALTIME_ONLY
#define FNS(opt)                       \
  FN(64, 64, 16, 6, 6, opt, (int64_t)) \
  FN(64, 32, 16, 6, 5, opt, (int64_t)) \
  FN(32, 64, 16, 5, 6, opt, (int64_t)) \
  FN(32, 32, 16, 5, 5, opt, (int64_t)) \
  FN(32, 16, 16, 5, 4, opt, (int64_t)) \
  FN(16, 32, 16, 4, 5, opt, (int64_t)) \
  FN(16, 16, 16, 4, 4, opt, (int64_t)) \
  FN(16, 8, 16, 4, 3, opt, (int64_t))  \
  FN(8, 16, 8, 3, 4, opt, (int64_t))   \
  FN(8, 8, 8, 3, 3, opt, (int64_t))    \
  FN(8, 4, 8, 3, 2, opt, (int64_t))    \
  FN(16, 4, 16, 4, 2, opt, (int64_t))  \
  FN(8, 32, 8, 3, 5, opt, (int64_t))   \
  FN(32, 8, 16, 5, 3, opt, (int64_t))  \
  FN(16, 64, 16, 4, 6, opt, (int64_t)) \
  FN(64, 16, 16, 6, 4, opt, (int64_t))
#endif  // CONFIG_REALTIME_ONLY

// Emit the SSE2 variants.
FNS(sse2)

#undef FNS
#undef FN
610 
611 static inline void highbd_compute_dist_wtd_comp_avg(__m128i *p0, __m128i *p1,
612                                                     const __m128i *w0,
613                                                     const __m128i *w1,
614                                                     const __m128i *r,
615                                                     void *const result) {
616   assert(DIST_PRECISION_BITS <= 4);
617   __m128i mult0 = _mm_mullo_epi16(*p0, *w0);
618   __m128i mult1 = _mm_mullo_epi16(*p1, *w1);
619   __m128i sum = _mm_adds_epu16(mult0, mult1);
620   __m128i round = _mm_adds_epu16(sum, *r);
621   __m128i shift = _mm_srli_epi16(round, DIST_PRECISION_BITS);
622 
623   xx_storeu_128(result, shift);
624 }
625 
aom_highbd_dist_wtd_comp_avg_pred_sse2(uint8_t * comp_pred8,const uint8_t * pred8,int width,int height,const uint8_t * ref8,int ref_stride,const DIST_WTD_COMP_PARAMS * jcp_param)626 void aom_highbd_dist_wtd_comp_avg_pred_sse2(
627     uint8_t *comp_pred8, const uint8_t *pred8, int width, int height,
628     const uint8_t *ref8, int ref_stride,
629     const DIST_WTD_COMP_PARAMS *jcp_param) {
630   int i;
631   const int16_t wt0 = (int16_t)jcp_param->fwd_offset;
632   const int16_t wt1 = (int16_t)jcp_param->bck_offset;
633   const __m128i w0 = _mm_set1_epi16(wt0);
634   const __m128i w1 = _mm_set1_epi16(wt1);
635   const int16_t round = (int16_t)((1 << DIST_PRECISION_BITS) >> 1);
636   const __m128i r = _mm_set1_epi16(round);
637   uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
638   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
639   uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
640 
641   if (width >= 8) {
642     // Read 8 pixels one row at a time
643     assert(!(width & 7));
644     for (i = 0; i < height; ++i) {
645       int j;
646       for (j = 0; j < width; j += 8) {
647         __m128i p0 = xx_loadu_128(ref);
648         __m128i p1 = xx_loadu_128(pred);
649 
650         highbd_compute_dist_wtd_comp_avg(&p0, &p1, &w0, &w1, &r, comp_pred);
651 
652         comp_pred += 8;
653         pred += 8;
654         ref += 8;
655       }
656       ref += ref_stride - width;
657     }
658   } else {
659     // Read 4 pixels two rows at a time
660     assert(!(width & 3));
661     for (i = 0; i < height; i += 2) {
662       __m128i p0_0 = xx_loadl_64(ref + 0 * ref_stride);
663       __m128i p0_1 = xx_loadl_64(ref + 1 * ref_stride);
664       __m128i p0 = _mm_unpacklo_epi64(p0_0, p0_1);
665       __m128i p1 = xx_loadu_128(pred);
666 
667       highbd_compute_dist_wtd_comp_avg(&p0, &p1, &w0, &w1, &r, comp_pred);
668 
669       comp_pred += 8;
670       pred += 8;
671       ref += 2 * ref_stride;
672     }
673   }
674 }
675 
// Sum of squared differences between two 4-wide, h-tall blocks of 16-bit
// samples.  Squares are widened to 64-bit lanes before accumulation so the
// running total cannot overflow.
static uint64_t mse_4xh_16bit_highbd_sse2(uint16_t *dst, int dstride,
                                          uint16_t *src, int sstride, int h) {
  const __m128i zero = _mm_setzero_si128();
  __m128i acc_2x64 = _mm_setzero_si128();
  uint64_t total = 0;

  // Two 4-pixel rows are packed into one 128-bit register per pass.
  for (int row = 0; row < h; row += 2) {
    const __m128i d_lo =
        _mm_loadl_epi64((const __m128i *)(dst + (row + 0) * dstride));
    const __m128i d_hi =
        _mm_loadl_epi64((const __m128i *)(dst + (row + 1) * dstride));
    const __m128i s_lo =
        _mm_loadl_epi64((const __m128i *)(src + (row + 0) * sstride));
    const __m128i s_hi =
        _mm_loadl_epi64((const __m128i *)(src + (row + 1) * sstride));
    const __m128i diff = _mm_sub_epi16(_mm_unpacklo_epi64(s_lo, s_hi),
                                       _mm_unpacklo_epi64(d_lo, d_hi));

    // Pair each 16-bit difference with a zero lane; madd then produces the
    // 32-bit square of every difference.
    __m128i sq_lo = _mm_unpacklo_epi16(diff, zero);
    __m128i sq_hi = _mm_unpackhi_epi16(diff, zero);
    sq_lo = _mm_madd_epi16(sq_lo, sq_lo);
    sq_hi = _mm_madd_epi16(sq_hi, sq_hi);

    // Widen the eight 32-bit squares to 64 bits and accumulate.
    acc_2x64 = _mm_add_epi64(
        acc_2x64, _mm_add_epi64(_mm_unpacklo_epi32(sq_lo, zero),
                                _mm_unpackhi_epi32(sq_lo, zero)));
    acc_2x64 = _mm_add_epi64(
        acc_2x64, _mm_add_epi64(_mm_unpacklo_epi32(sq_hi, zero),
                                _mm_unpackhi_epi32(sq_hi, zero)));
  }

  // Fold the two 64-bit lanes into a scalar.
  const __m128i folded =
      _mm_add_epi64(acc_2x64, _mm_srli_si128(acc_2x64, 8));
  _mm_storel_epi64((__m128i *)&total, folded);
  return total;
}
720 
// Sum of squared differences between two 8-wide, h-tall blocks of 16-bit
// samples.  Same scheme as the 4-wide kernel, but one full row fills a
// register, so rows are consumed one at a time.
static uint64_t mse_8xh_16bit_highbd_sse2(uint16_t *dst, int dstride,
                                          uint16_t *src, int sstride, int h) {
  const __m128i zero = _mm_setzero_si128();
  __m128i acc_2x64 = _mm_setzero_si128();
  uint64_t total = 0;

  for (int row = 0; row < h; ++row) {
    const __m128i d_8x16 =
        _mm_loadu_si128((const __m128i *)(dst + row * dstride));
    const __m128i s_8x16 =
        _mm_loadu_si128((const __m128i *)(src + row * sstride));
    const __m128i diff = _mm_sub_epi16(s_8x16, d_8x16);

    // Pair each 16-bit difference with a zero lane; madd then produces the
    // 32-bit square of every difference.
    __m128i sq_lo = _mm_unpacklo_epi16(diff, zero);
    __m128i sq_hi = _mm_unpackhi_epi16(diff, zero);
    sq_lo = _mm_madd_epi16(sq_lo, sq_lo);
    sq_hi = _mm_madd_epi16(sq_hi, sq_hi);

    // Widen the eight 32-bit squares to 64 bits and accumulate.
    acc_2x64 = _mm_add_epi64(
        acc_2x64, _mm_add_epi64(_mm_unpacklo_epi32(sq_lo, zero),
                                _mm_unpackhi_epi32(sq_lo, zero)));
    acc_2x64 = _mm_add_epi64(
        acc_2x64, _mm_add_epi64(_mm_unpacklo_epi32(sq_hi, zero),
                                _mm_unpackhi_epi32(sq_hi, zero)));
  }

  // Fold the two 64-bit lanes into a scalar.
  const __m128i folded =
      _mm_add_epi64(acc_2x64, _mm_srli_si128(acc_2x64, 8));
  _mm_storel_epi64((__m128i *)&total, folded);
  return total;
}
760 
// Dispatch to the 4- or 8-pixel-wide SSD kernel.  Despite the "mse" name,
// the kernels return the raw sum of squared differences (no division).
// Only w,h in {4, 8} are supported, as the assert enforces.
uint64_t aom_mse_wxh_16bit_highbd_sse2(uint16_t *dst, int dstride,
                                       uint16_t *src, int sstride, int w,
                                       int h) {
  assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
         "w=8/4 and h=8/4 must satisfy");
  if (w == 4) {
    return mse_4xh_16bit_highbd_sse2(dst, dstride, src, sstride, h);
  }
  if (w == 8) {
    return mse_8xh_16bit_highbd_sse2(dst, dstride, src, sstride, h);
  }
  assert(0 && "unsupported width");
  return -1;
}
772