xref: /aosp_15_r20/external/libvpx/vpx_dsp/x86/highbd_variance_sse2.c (revision fb1b10ab9aebc7c7068eedab379b749d7e3900be)
1 /*
2  *  Copyright (c) 2014 The WebM project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 #include <emmintrin.h>  // SSE2
11 
12 #include "./vpx_config.h"
13 #include "./vpx_dsp_rtcd.h"
14 #include "vpx_ports/mem.h"
15 
/* Signature of a per-block kernel: computes, for one block of uint16_t
 * pixels, the sum of squared differences (written to *sse) and the signed
 * sum of differences (written to *sum) between src and ref. */
typedef uint32_t (*high_variance_fn_t)(const uint16_t *src, int src_stride,
                                       const uint16_t *ref, int ref_stride,
                                       uint32_t *sse, int *sum);

/* 8x8 and 16x16 kernels matching high_variance_fn_t; defined outside this
 * file (presumably in the SSE2 assembly implementation — not visible here). */
uint32_t vpx_highbd_calc8x8var_sse2(const uint16_t *src, int src_stride,
                                    const uint16_t *ref, int ref_stride,
                                    uint32_t *sse, int *sum);

uint32_t vpx_highbd_calc16x16var_sse2(const uint16_t *src, int src_stride,
                                      const uint16_t *ref, int ref_stride,
                                      uint32_t *sse, int *sum);
27 
/* Tile a w x h region into block_size x block_size blocks, run var_fn on
 * each, and write the aggregated sse/sum for the whole region.  8-bit depth:
 * results are stored unscaled. */
static void highbd_8_variance_sse2(const uint16_t *src, int src_stride,
                                   const uint16_t *ref, int ref_stride, int w,
                                   int h, uint32_t *sse, int *sum,
                                   high_variance_fn_t var_fn, int block_size) {
  int row, col;
  uint32_t total_sse = 0;
  int total_sum = 0;

  for (row = 0; row < h; row += block_size) {
    for (col = 0; col < w; col += block_size) {
      uint32_t tile_sse;
      int tile_sum;
      var_fn(src + src_stride * row + col, src_stride,
             ref + ref_stride * row + col, ref_stride, &tile_sse, &tile_sum);
      total_sse += tile_sse;
      total_sum += tile_sum;
    }
  }
  *sse = total_sse;
  *sum = total_sum;
}
48 
/* Same tiling as highbd_8_variance_sse2, but for 10-bit pixels: sse is
 * accumulated in 64 bits to avoid overflow, and the final sum/sse are
 * rounded down to the 8-bit scale (sum / 4, sse / 16, round-to-nearest). */
static void highbd_10_variance_sse2(const uint16_t *src, int src_stride,
                                    const uint16_t *ref, int ref_stride, int w,
                                    int h, uint32_t *sse, int *sum,
                                    high_variance_fn_t var_fn, int block_size) {
  int row, col;
  uint64_t total_sse = 0;
  int32_t total_sum = 0;

  for (row = 0; row < h; row += block_size) {
    for (col = 0; col < w; col += block_size) {
      uint32_t tile_sse;
      int tile_sum;
      var_fn(src + src_stride * row + col, src_stride,
             ref + ref_stride * row + col, ref_stride, &tile_sse, &tile_sum);
      total_sse += tile_sse;
      total_sum += tile_sum;
    }
  }
  *sum = ROUND_POWER_OF_TWO(total_sum, 2);
  *sse = (uint32_t)ROUND_POWER_OF_TWO(total_sse, 4);
}
70 
/* Same tiling as highbd_8_variance_sse2, but for 12-bit pixels: sse is
 * accumulated in 64 bits, and the final sum/sse are rounded down to the
 * 8-bit scale (sum / 16, sse / 256, round-to-nearest). */
static void highbd_12_variance_sse2(const uint16_t *src, int src_stride,
                                    const uint16_t *ref, int ref_stride, int w,
                                    int h, uint32_t *sse, int *sum,
                                    high_variance_fn_t var_fn, int block_size) {
  int row, col;
  uint64_t total_sse = 0;
  int32_t total_sum = 0;

  for (row = 0; row < h; row += block_size) {
    for (col = 0; col < w; col += block_size) {
      uint32_t tile_sse;
      int tile_sum;
      var_fn(src + src_stride * row + col, src_stride,
             ref + ref_stride * row + col, ref_stride, &tile_sse, &tile_sum);
      total_sse += tile_sse;
      total_sum += tile_sum;
    }
  }
  *sum = ROUND_POWER_OF_TWO(total_sum, 4);
  *sse = (uint32_t)ROUND_POWER_OF_TWO(total_sse, 8);
}
92 
/* HIGH_GET_VAR(S) instantiates the SxS "get var" entry points for the 8-,
 * 10- and 12-bit depths.  Each wrapper unpacks the uint8_t* handles into
 * real uint16_t* pixel pointers (CONVERT_TO_SHORTPTR) and calls the SSE2
 * kernel.  The 10-bit variant rounds sum down by 1/4 and sse by 1/16; the
 * 12-bit variant by 1/16 and 1/256 (ROUND_POWER_OF_TWO rounds to nearest),
 * normalizing both back to the 8-bit scale. */
#define HIGH_GET_VAR(S)                                                       \
  void vpx_highbd_8_get##S##x##S##var_sse2(                                   \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,               \
      int ref_stride, uint32_t *sse, int *sum) {                              \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                                \
    vpx_highbd_calc##S##x##S##var_sse2(src, src_stride, ref, ref_stride, sse, \
                                       sum);                                  \
  }                                                                           \
                                                                              \
  void vpx_highbd_10_get##S##x##S##var_sse2(                                  \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,               \
      int ref_stride, uint32_t *sse, int *sum) {                              \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                                \
    vpx_highbd_calc##S##x##S##var_sse2(src, src_stride, ref, ref_stride, sse, \
                                       sum);                                  \
    *sum = ROUND_POWER_OF_TWO(*sum, 2);                                       \
    *sse = ROUND_POWER_OF_TWO(*sse, 4);                                       \
  }                                                                           \
                                                                              \
  void vpx_highbd_12_get##S##x##S##var_sse2(                                  \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,               \
      int ref_stride, uint32_t *sse, int *sum) {                              \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                                \
    vpx_highbd_calc##S##x##S##var_sse2(src, src_stride, ref, ref_stride, sse, \
                                       sum);                                  \
    *sum = ROUND_POWER_OF_TWO(*sum, 4);                                       \
    *sse = ROUND_POWER_OF_TWO(*sse, 8);                                       \
  }

/* Instantiate the 16x16 and 8x8 variants. */
HIGH_GET_VAR(16)
HIGH_GET_VAR(8)

#undef HIGH_GET_VAR
129 
/* VAR_FN(w, h, block_size, shift) instantiates vpx_highbd_{8,10,12}_
 * variance<w>x<h>_sse2().  shift == log2(w * h), so
 * ((int64_t)sum * sum) >> shift is the mean-compensation term and the
 * returned value is variance = sse - sum^2 / (w*h).  The 10- and 12-bit
 * variants work on the depth-normalized sse/sum from the helpers above and
 * clamp a (rounding-induced) negative variance to 0. */
#define VAR_FN(w, h, block_size, shift)                                    \
  uint32_t vpx_highbd_8_variance##w##x##h##_sse2(                          \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,            \
      int ref_stride, uint32_t *sse) {                                     \
    int sum;                                                               \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                             \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                             \
    highbd_8_variance_sse2(                                                \
        src, src_stride, ref, ref_stride, w, h, sse, &sum,                 \
        vpx_highbd_calc##block_size##x##block_size##var_sse2, block_size); \
    return *sse - (uint32_t)(((int64_t)sum * sum) >> (shift));             \
  }                                                                        \
                                                                           \
  uint32_t vpx_highbd_10_variance##w##x##h##_sse2(                         \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,            \
      int ref_stride, uint32_t *sse) {                                     \
    int sum;                                                               \
    int64_t var;                                                           \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                             \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                             \
    highbd_10_variance_sse2(                                               \
        src, src_stride, ref, ref_stride, w, h, sse, &sum,                 \
        vpx_highbd_calc##block_size##x##block_size##var_sse2, block_size); \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) >> (shift));             \
    return (var >= 0) ? (uint32_t)var : 0;                                 \
  }                                                                        \
                                                                           \
  uint32_t vpx_highbd_12_variance##w##x##h##_sse2(                         \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,            \
      int ref_stride, uint32_t *sse) {                                     \
    int sum;                                                               \
    int64_t var;                                                           \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                             \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                             \
    highbd_12_variance_sse2(                                               \
        src, src_stride, ref, ref_stride, w, h, sse, &sum,                 \
        vpx_highbd_calc##block_size##x##block_size##var_sse2, block_size); \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) >> (shift));             \
    return (var >= 0) ? (uint32_t)var : 0;                                 \
  }

/* All block sizes 8x8 .. 64x64; shift is log2(w*h) in every row. */
VAR_FN(64, 64, 16, 12)
VAR_FN(64, 32, 16, 11)
VAR_FN(32, 64, 16, 11)
VAR_FN(32, 32, 16, 10)
VAR_FN(32, 16, 16, 9)
VAR_FN(16, 32, 16, 9)
VAR_FN(16, 16, 16, 8)
VAR_FN(16, 8, 8, 7)
VAR_FN(8, 16, 8, 7)
VAR_FN(8, 8, 8, 6)

#undef VAR_FN
183 
vpx_highbd_8_mse16x16_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)184 unsigned int vpx_highbd_8_mse16x16_sse2(const uint8_t *src8, int src_stride,
185                                         const uint8_t *ref8, int ref_stride,
186                                         unsigned int *sse) {
187   int sum;
188   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
189   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
190   highbd_8_variance_sse2(src, src_stride, ref, ref_stride, 16, 16, sse, &sum,
191                          vpx_highbd_calc16x16var_sse2, 16);
192   return *sse;
193 }
194 
vpx_highbd_10_mse16x16_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)195 unsigned int vpx_highbd_10_mse16x16_sse2(const uint8_t *src8, int src_stride,
196                                          const uint8_t *ref8, int ref_stride,
197                                          unsigned int *sse) {
198   int sum;
199   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
200   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
201   highbd_10_variance_sse2(src, src_stride, ref, ref_stride, 16, 16, sse, &sum,
202                           vpx_highbd_calc16x16var_sse2, 16);
203   return *sse;
204 }
205 
vpx_highbd_12_mse16x16_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)206 unsigned int vpx_highbd_12_mse16x16_sse2(const uint8_t *src8, int src_stride,
207                                          const uint8_t *ref8, int ref_stride,
208                                          unsigned int *sse) {
209   int sum;
210   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
211   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
212   highbd_12_variance_sse2(src, src_stride, ref, ref_stride, 16, 16, sse, &sum,
213                           vpx_highbd_calc16x16var_sse2, 16);
214   return *sse;
215 }
216 
vpx_highbd_8_mse8x8_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)217 unsigned int vpx_highbd_8_mse8x8_sse2(const uint8_t *src8, int src_stride,
218                                       const uint8_t *ref8, int ref_stride,
219                                       unsigned int *sse) {
220   int sum;
221   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
222   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
223   highbd_8_variance_sse2(src, src_stride, ref, ref_stride, 8, 8, sse, &sum,
224                          vpx_highbd_calc8x8var_sse2, 8);
225   return *sse;
226 }
227 
vpx_highbd_10_mse8x8_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)228 unsigned int vpx_highbd_10_mse8x8_sse2(const uint8_t *src8, int src_stride,
229                                        const uint8_t *ref8, int ref_stride,
230                                        unsigned int *sse) {
231   int sum;
232   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
233   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
234   highbd_10_variance_sse2(src, src_stride, ref, ref_stride, 8, 8, sse, &sum,
235                           vpx_highbd_calc8x8var_sse2, 8);
236   return *sse;
237 }
238 
vpx_highbd_12_mse8x8_sse2(const uint8_t * src8,int src_stride,const uint8_t * ref8,int ref_stride,unsigned int * sse)239 unsigned int vpx_highbd_12_mse8x8_sse2(const uint8_t *src8, int src_stride,
240                                        const uint8_t *ref8, int ref_stride,
241                                        unsigned int *sse) {
242   int sum;
243   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
244   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
245   highbd_12_variance_sse2(src, src_stride, ref, ref_stride, 8, 8, sse, &sum,
246                           vpx_highbd_calc8x8var_sse2, 8);
247   return *sse;
248 }
249 
// The 2 unused parameters are place holders for PIC enabled build.
// These definitions are for functions defined in
// highbd_subpel_variance_impl_sse2.asm
/* Each asm routine filters a w-wide column strip of `height` rows at the
 * given subpel offsets, writes the sum of squared differences to *sse, and
 * returns the signed difference sum (fed into `se` by the callers below). */
#define DECL(w, opt)                                                         \
  int vpx_highbd_sub_pixel_variance##w##xh_##opt(                            \
      const uint16_t *src, ptrdiff_t src_stride, int x_offset, int y_offset, \
      const uint16_t *ref, ptrdiff_t ref_stride, int height,                 \
      unsigned int *sse, void *unused0, void *unused);
#define DECLS(opt) \
  DECL(8, opt)     \
  DECL(16, opt)

DECLS(sse2)

#undef DECLS
#undef DECL
266 
/* FN(w, h, wf, wlog2, hlog2, opt, cast) instantiates the three bit-depth
 * variants of vpx_highbd_*_sub_pixel_variance<w>x<h>_<opt>().
 *
 * wf is the strip width the asm routine handles per call (8 or 16); blocks
 * wider than wf are covered by additional calls at column offsets +16, +32
 * and +48.  wlog2 + hlog2 == log2(w * h), so (se * se) >> (wlog2 + hlog2)
 * is the mean-compensation term of the variance.
 *
 * 8-bit: raw sse/se are used directly.
 * 10-bit: se and sse are rounded down to the 8-bit scale (>>2 and >>4) and
 *   a rounding-induced negative variance is clamped to 0.
 * 12-bit: rows are additionally processed in 16-row tiles with the sse
 *   accumulated in a uint64_t (long_sse) to avoid 32-bit overflow before
 *   the final >>8 rounding; negative variance is clamped to 0. */
#define FN(w, h, wf, wlog2, hlog2, opt, cast)                                  \
  uint32_t vpx_highbd_8_sub_pixel_variance##w##x##h##_##opt(                   \
      const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
      const uint8_t *ref8, int ref_stride, uint32_t *sse_ptr) {                \
    uint32_t sse;                                                              \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                                 \
    int se = vpx_highbd_sub_pixel_variance##wf##xh_##opt(                      \
        src, src_stride, x_offset, y_offset, ref, ref_stride, h, &sse, NULL,   \
        NULL);                                                                 \
    if (w > wf) {                                                              \
      unsigned int sse2;                                                       \
      int se2 = vpx_highbd_sub_pixel_variance##wf##xh_##opt(                   \
          src + 16, src_stride, x_offset, y_offset, ref + 16, ref_stride, h,   \
          &sse2, NULL, NULL);                                                  \
      se += se2;                                                               \
      sse += sse2;                                                             \
      if (w > wf * 2) {                                                        \
        se2 = vpx_highbd_sub_pixel_variance##wf##xh_##opt(                     \
            src + 32, src_stride, x_offset, y_offset, ref + 32, ref_stride, h, \
            &sse2, NULL, NULL);                                                \
        se += se2;                                                             \
        sse += sse2;                                                           \
        se2 = vpx_highbd_sub_pixel_variance##wf##xh_##opt(                     \
            src + 48, src_stride, x_offset, y_offset, ref + 48, ref_stride, h, \
            &sse2, NULL, NULL);                                                \
        se += se2;                                                             \
        sse += sse2;                                                           \
      }                                                                        \
    }                                                                          \
    *sse_ptr = sse;                                                            \
    return sse - (uint32_t)((cast se * se) >> (wlog2 + hlog2));                \
  }                                                                            \
                                                                               \
  uint32_t vpx_highbd_10_sub_pixel_variance##w##x##h##_##opt(                  \
      const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
      const uint8_t *ref8, int ref_stride, uint32_t *sse_ptr) {                \
    int64_t var;                                                               \
    uint32_t sse;                                                              \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                                 \
    int se = vpx_highbd_sub_pixel_variance##wf##xh_##opt(                      \
        src, src_stride, x_offset, y_offset, ref, ref_stride, h, &sse, NULL,   \
        NULL);                                                                 \
    if (w > wf) {                                                              \
      uint32_t sse2;                                                           \
      int se2 = vpx_highbd_sub_pixel_variance##wf##xh_##opt(                   \
          src + 16, src_stride, x_offset, y_offset, ref + 16, ref_stride, h,   \
          &sse2, NULL, NULL);                                                  \
      se += se2;                                                               \
      sse += sse2;                                                             \
      if (w > wf * 2) {                                                        \
        se2 = vpx_highbd_sub_pixel_variance##wf##xh_##opt(                     \
            src + 32, src_stride, x_offset, y_offset, ref + 32, ref_stride, h, \
            &sse2, NULL, NULL);                                                \
        se += se2;                                                             \
        sse += sse2;                                                           \
        se2 = vpx_highbd_sub_pixel_variance##wf##xh_##opt(                     \
            src + 48, src_stride, x_offset, y_offset, ref + 48, ref_stride, h, \
            &sse2, NULL, NULL);                                                \
        se += se2;                                                             \
        sse += sse2;                                                           \
      }                                                                        \
    }                                                                          \
    se = ROUND_POWER_OF_TWO(se, 2);                                            \
    sse = ROUND_POWER_OF_TWO(sse, 4);                                          \
    *sse_ptr = sse;                                                            \
    var = (int64_t)(sse) - ((cast se * se) >> (wlog2 + hlog2));                \
    return (var >= 0) ? (uint32_t)var : 0;                                     \
  }                                                                            \
                                                                               \
  uint32_t vpx_highbd_12_sub_pixel_variance##w##x##h##_##opt(                  \
      const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
      const uint8_t *ref8, int ref_stride, uint32_t *sse_ptr) {                \
    int start_row;                                                             \
    uint32_t sse;                                                              \
    int se = 0;                                                                \
    int64_t var;                                                               \
    uint64_t long_sse = 0;                                                     \
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                                 \
    for (start_row = 0; start_row < h; start_row += 16) {                      \
      uint32_t sse2;                                                           \
      int height = h - start_row < 16 ? h - start_row : 16;                    \
      int se2 = vpx_highbd_sub_pixel_variance##wf##xh_##opt(                   \
          src + (start_row * src_stride), src_stride, x_offset, y_offset,      \
          ref + (start_row * ref_stride), ref_stride, height, &sse2, NULL,     \
          NULL);                                                               \
      se += se2;                                                               \
      long_sse += sse2;                                                        \
      if (w > wf) {                                                            \
        se2 = vpx_highbd_sub_pixel_variance##wf##xh_##opt(                     \
            src + 16 + (start_row * src_stride), src_stride, x_offset,         \
            y_offset, ref + 16 + (start_row * ref_stride), ref_stride, height, \
            &sse2, NULL, NULL);                                                \
        se += se2;                                                             \
        long_sse += sse2;                                                      \
        if (w > wf * 2) {                                                      \
          se2 = vpx_highbd_sub_pixel_variance##wf##xh_##opt(                   \
              src + 32 + (start_row * src_stride), src_stride, x_offset,       \
              y_offset, ref + 32 + (start_row * ref_stride), ref_stride,       \
              height, &sse2, NULL, NULL);                                      \
          se += se2;                                                           \
          long_sse += sse2;                                                    \
          se2 = vpx_highbd_sub_pixel_variance##wf##xh_##opt(                   \
              src + 48 + (start_row * src_stride), src_stride, x_offset,       \
              y_offset, ref + 48 + (start_row * ref_stride), ref_stride,       \
              height, &sse2, NULL, NULL);                                      \
          se += se2;                                                           \
          long_sse += sse2;                                                    \
        }                                                                      \
      }                                                                        \
    }                                                                          \
    se = ROUND_POWER_OF_TWO(se, 4);                                            \
    sse = (uint32_t)ROUND_POWER_OF_TWO(long_sse, 8);                           \
    *sse_ptr = sse;                                                            \
    var = (int64_t)(sse) - ((cast se * se) >> (wlog2 + hlog2));                \
    return (var >= 0) ? (uint32_t)var : 0;                                     \
  }

/* wlog2/hlog2 are log2 of the block dimensions; wf is 16 for all widths
 * >= 16 and 8 for the 8-wide blocks. */
#define FNS(opt)                       \
  FN(64, 64, 16, 6, 6, opt, (int64_t)) \
  FN(64, 32, 16, 6, 5, opt, (int64_t)) \
  FN(32, 64, 16, 5, 6, opt, (int64_t)) \
  FN(32, 32, 16, 5, 5, opt, (int64_t)) \
  FN(32, 16, 16, 5, 4, opt, (int64_t)) \
  FN(16, 32, 16, 4, 5, opt, (int64_t)) \
  FN(16, 16, 16, 4, 4, opt, (int64_t)) \
  FN(16, 8, 16, 4, 3, opt, (int64_t))  \
  FN(8, 16, 8, 3, 4, opt, (int64_t))   \
  FN(8, 8, 8, 3, 3, opt, (int64_t))    \
  FN(8, 4, 8, 3, 2, opt, (int64_t))

FNS(sse2)

#undef FNS
#undef FN
404 
// The 2 unused parameters are place holders for PIC enabled build.
/* Averaging counterpart of the declarations above: the asm routine averages
 * the filtered prediction with `second` before computing the difference
 * statistics.  Writes the sum of squared differences to *sse and returns
 * the signed difference sum. */
#define DECL(w, opt)                                                         \
  int vpx_highbd_sub_pixel_avg_variance##w##xh_##opt(                        \
      const uint16_t *src, ptrdiff_t src_stride, int x_offset, int y_offset, \
      const uint16_t *ref, ptrdiff_t ref_stride, const uint16_t *second,     \
      ptrdiff_t second_stride, int height, unsigned int *sse, void *unused0, \
      void *unused);
#define DECLS(opt1) \
  DECL(16, opt1)    \
  DECL(8, opt1)

DECLS(sse2)
#undef DECL
#undef DECLS
419 
420 #define FN(w, h, wf, wlog2, hlog2, opt, cast)                                  \
421   uint32_t vpx_highbd_8_sub_pixel_avg_variance##w##x##h##_##opt(               \
422       const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
423       const uint8_t *ref8, int ref_stride, uint32_t *sse_ptr,                  \
424       const uint8_t *sec8) {                                                   \
425     uint32_t sse;                                                              \
426     uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
427     uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                                 \
428     uint16_t *sec = CONVERT_TO_SHORTPTR(sec8);                                 \
429     int se = vpx_highbd_sub_pixel_avg_variance##wf##xh_##opt(                  \
430         src, src_stride, x_offset, y_offset, ref, ref_stride, sec, w, h, &sse, \
431         NULL, NULL);                                                           \
432     if (w > wf) {                                                              \
433       uint32_t sse2;                                                           \
434       int se2 = vpx_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
435           src + 16, src_stride, x_offset, y_offset, ref + 16, ref_stride,      \
436           sec + 16, w, h, &sse2, NULL, NULL);                                  \
437       se += se2;                                                               \
438       sse += sse2;                                                             \
439       if (w > wf * 2) {                                                        \
440         se2 = vpx_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
441             src + 32, src_stride, x_offset, y_offset, ref + 32, ref_stride,    \
442             sec + 32, w, h, &sse2, NULL, NULL);                                \
443         se += se2;                                                             \
444         sse += sse2;                                                           \
445         se2 = vpx_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
446             src + 48, src_stride, x_offset, y_offset, ref + 48, ref_stride,    \
447             sec + 48, w, h, &sse2, NULL, NULL);                                \
448         se += se2;                                                             \
449         sse += sse2;                                                           \
450       }                                                                        \
451     }                                                                          \
452     *sse_ptr = sse;                                                            \
453     return sse - (uint32_t)((cast se * se) >> (wlog2 + hlog2));                \
454   }                                                                            \
455                                                                                \
456   uint32_t vpx_highbd_10_sub_pixel_avg_variance##w##x##h##_##opt(              \
457       const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
458       const uint8_t *ref8, int ref_stride, uint32_t *sse_ptr,                  \
459       const uint8_t *sec8) {                                                   \
460     int64_t var;                                                               \
461     uint32_t sse;                                                              \
462     uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
463     uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                                 \
464     uint16_t *sec = CONVERT_TO_SHORTPTR(sec8);                                 \
465     int se = vpx_highbd_sub_pixel_avg_variance##wf##xh_##opt(                  \
466         src, src_stride, x_offset, y_offset, ref, ref_stride, sec, w, h, &sse, \
467         NULL, NULL);                                                           \
468     if (w > wf) {                                                              \
469       uint32_t sse2;                                                           \
470       int se2 = vpx_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
471           src + 16, src_stride, x_offset, y_offset, ref + 16, ref_stride,      \
472           sec + 16, w, h, &sse2, NULL, NULL);                                  \
473       se += se2;                                                               \
474       sse += sse2;                                                             \
475       if (w > wf * 2) {                                                        \
476         se2 = vpx_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
477             src + 32, src_stride, x_offset, y_offset, ref + 32, ref_stride,    \
478             sec + 32, w, h, &sse2, NULL, NULL);                                \
479         se += se2;                                                             \
480         sse += sse2;                                                           \
481         se2 = vpx_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
482             src + 48, src_stride, x_offset, y_offset, ref + 48, ref_stride,    \
483             sec + 48, w, h, &sse2, NULL, NULL);                                \
484         se += se2;                                                             \
485         sse += sse2;                                                           \
486       }                                                                        \
487     }                                                                          \
488     se = ROUND_POWER_OF_TWO(se, 2);                                            \
489     sse = ROUND_POWER_OF_TWO(sse, 4);                                          \
490     *sse_ptr = sse;                                                            \
491     var = (int64_t)(sse) - ((cast se * se) >> (wlog2 + hlog2));                \
492     return (var >= 0) ? (uint32_t)var : 0;                                     \
493   }                                                                            \
494                                                                                \
495   uint32_t vpx_highbd_12_sub_pixel_avg_variance##w##x##h##_##opt(              \
496       const uint8_t *src8, int src_stride, int x_offset, int y_offset,         \
497       const uint8_t *ref8, int ref_stride, uint32_t *sse_ptr,                  \
498       const uint8_t *sec8) {                                                   \
499     int start_row;                                                             \
500     int64_t var;                                                               \
501     uint32_t sse;                                                              \
502     int se = 0;                                                                \
503     uint64_t long_sse = 0;                                                     \
504     uint16_t *src = CONVERT_TO_SHORTPTR(src8);                                 \
505     uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);                                 \
506     uint16_t *sec = CONVERT_TO_SHORTPTR(sec8);                                 \
507     for (start_row = 0; start_row < h; start_row += 16) {                      \
508       uint32_t sse2;                                                           \
509       int height = h - start_row < 16 ? h - start_row : 16;                    \
510       int se2 = vpx_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
511           src + (start_row * src_stride), src_stride, x_offset, y_offset,      \
512           ref + (start_row * ref_stride), ref_stride, sec + (start_row * w),   \
513           w, height, &sse2, NULL, NULL);                                       \
514       se += se2;                                                               \
515       long_sse += sse2;                                                        \
516       if (w > wf) {                                                            \
517         se2 = vpx_highbd_sub_pixel_avg_variance##wf##xh_##opt(                 \
518             src + 16 + (start_row * src_stride), src_stride, x_offset,         \
519             y_offset, ref + 16 + (start_row * ref_stride), ref_stride,         \
520             sec + 16 + (start_row * w), w, height, &sse2, NULL, NULL);         \
521         se += se2;                                                             \
522         long_sse += sse2;                                                      \
523         if (w > wf * 2) {                                                      \
524           se2 = vpx_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
525               src + 32 + (start_row * src_stride), src_stride, x_offset,       \
526               y_offset, ref + 32 + (start_row * ref_stride), ref_stride,       \
527               sec + 32 + (start_row * w), w, height, &sse2, NULL, NULL);       \
528           se += se2;                                                           \
529           long_sse += sse2;                                                    \
530           se2 = vpx_highbd_sub_pixel_avg_variance##wf##xh_##opt(               \
531               src + 48 + (start_row * src_stride), src_stride, x_offset,       \
532               y_offset, ref + 48 + (start_row * ref_stride), ref_stride,       \
533               sec + 48 + (start_row * w), w, height, &sse2, NULL, NULL);       \
534           se += se2;                                                           \
535           long_sse += sse2;                                                    \
536         }                                                                      \
537       }                                                                        \
538     }                                                                          \
539     se = ROUND_POWER_OF_TWO(se, 4);                                            \
540     sse = (uint32_t)ROUND_POWER_OF_TWO(long_sse, 8);                           \
541     *sse_ptr = sse;                                                            \
542     var = (int64_t)(sse) - ((cast se * se) >> (wlog2 + hlog2));                \
543     return (var >= 0) ? (uint32_t)var : 0;                                     \
544   }
545 
// Instantiate every block-size variant of the sub-pixel avg variance
// functions generated by the FN() macro above.
// FN arguments: (w, h, wf, wlog2, hlog2, opt, cast) where
//   w, h   - block width and height in pixels,
//   wf     - width (in pixels) of the SIMD helper kernel used per call
//            (16 or 8; blocks wider than wf are tiled in 16-pixel strips),
//   wlog2,
//   hlog2  - log2(w) and log2(h); (se * se) >> (wlog2 + hlog2) divides the
//            squared sum by the pixel count via a shift,
//   opt    - ISA suffix appended to the generated symbol names,
//   cast   - cast applied to the sum before squaring to avoid overflow.
#define FNS(opt1)                       \
  FN(64, 64, 16, 6, 6, opt1, (int64_t)) \
  FN(64, 32, 16, 6, 5, opt1, (int64_t)) \
  FN(32, 64, 16, 5, 6, opt1, (int64_t)) \
  FN(32, 32, 16, 5, 5, opt1, (int64_t)) \
  FN(32, 16, 16, 5, 4, opt1, (int64_t)) \
  FN(16, 32, 16, 4, 5, opt1, (int64_t)) \
  FN(16, 16, 16, 4, 4, opt1, (int64_t)) \
  FN(16, 8, 16, 4, 3, opt1, (int64_t))  \
  FN(8, 16, 8, 4, 3, opt1, (int64_t))   \
  FN(8, 8, 8, 3, 3, opt1, (int64_t))    \
  FN(8, 4, 8, 3, 2, opt1, (int64_t))

FNS(sse2)

#undef FNS
#undef FN
563 
564 void vpx_highbd_comp_avg_pred_sse2(uint16_t *comp_pred, const uint16_t *pred,
565                                    int width, int height, const uint16_t *ref,
566                                    int ref_stride) {
567   int i, j;
568   if (width > 8) {
569     for (i = 0; i < height; ++i) {
570       for (j = 0; j < width; j += 16) {
571         const __m128i p0 = _mm_loadu_si128((const __m128i *)&pred[j]);
572         const __m128i p1 = _mm_loadu_si128((const __m128i *)&pred[j + 8]);
573         const __m128i r0 = _mm_loadu_si128((const __m128i *)&ref[j]);
574         const __m128i r1 = _mm_loadu_si128((const __m128i *)&ref[j + 8]);
575         _mm_storeu_si128((__m128i *)&comp_pred[j], _mm_avg_epu16(p0, r0));
576         _mm_storeu_si128((__m128i *)&comp_pred[j + 8], _mm_avg_epu16(p1, r1));
577       }
578       comp_pred += width;
579       pred += width;
580       ref += ref_stride;
581     }
582   } else if (width == 8) {
583     for (i = 0; i < height; i += 2) {
584       const __m128i p0 = _mm_loadu_si128((const __m128i *)&pred[0]);
585       const __m128i p1 = _mm_loadu_si128((const __m128i *)&pred[8]);
586       const __m128i r0 = _mm_loadu_si128((const __m128i *)&ref[0]);
587       const __m128i r1 = _mm_loadu_si128((const __m128i *)&ref[ref_stride]);
588       _mm_storeu_si128((__m128i *)&comp_pred[0], _mm_avg_epu16(p0, r0));
589       _mm_storeu_si128((__m128i *)&comp_pred[8], _mm_avg_epu16(p1, r1));
590       comp_pred += 8 << 1;
591       pred += 8 << 1;
592       ref += ref_stride << 1;
593     }
594   } else {
595     assert(width == 4);
596     for (i = 0; i < height; i += 2) {
597       const __m128i p0 = _mm_loadl_epi64((const __m128i *)&pred[0]);
598       const __m128i p1 = _mm_loadl_epi64((const __m128i *)&pred[4]);
599       const __m128i r0 = _mm_loadl_epi64((const __m128i *)&ref[0]);
600       const __m128i r1 = _mm_loadl_epi64((const __m128i *)&ref[ref_stride]);
601       _mm_storel_epi64((__m128i *)&comp_pred[0], _mm_avg_epu16(p0, r0));
602       _mm_storel_epi64((__m128i *)&comp_pred[4], _mm_avg_epu16(p1, r1));
603       comp_pred += 4 << 1;
604       pred += 4 << 1;
605       ref += ref_stride << 1;
606     }
607   }
608 }
609