xref: /aosp_15_r20/external/libaom/aom_dsp/arm/highbd_variance_neon_dotprod.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2023 The WebM project authors. All rights reserved.
3  * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
4  *
5  * This source code is subject to the terms of the BSD 2 Clause License and
6  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
7  * was not distributed with this source code in the LICENSE file, you can
8  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
9  * Media Patent License 1.0 was not distributed with this source code in the
10  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
11  */
12 
13 #include <arm_neon.h>
14 
15 #include "aom_dsp/arm/sum_neon.h"
16 #include "config/aom_config.h"
17 #include "config/aom_dsp_rtcd.h"
18 
// Sum of squared differences between two 8-pixel-wide blocks whose samples are
// stored in 16-bit containers, accumulated with the dot-product intrinsic
// vdotq_u32.
//
// Two rows are consumed per loop iteration so that the narrowed samples fill a
// full 16-byte vector. The loop runs h / 2 times and always executes at least
// once, so h must be even and at least 2.
//
// Samples are narrowed with vmovn_u16, which keeps only the low 8 bits of each
// 16-bit lane. NOTE(review): presumably callers only pass 8-bit-depth content
// so nothing is lost by the narrowing - confirm against the call sites.
//
// The total is stored to *sse and also returned.
static inline uint32_t highbd_mse8_8xh_neon_dotprod(const uint16_t *src_ptr,
                                                    int src_stride,
                                                    const uint16_t *ref_ptr,
                                                    int ref_stride, int h,
                                                    unsigned int *sse) {
  uint32x4_t acc = vdupq_n_u32(0);
  int row_pairs = h / 2;

  do {
    const uint16_t *const src_next = src_ptr + src_stride;
    const uint16_t *const ref_next = ref_ptr + ref_stride;

    // Narrow each 8-sample row to bytes: the first row of the pair lands in
    // the low half of the 16-byte vector, the second row in the high half.
    const uint8x8_t src_lo = vmovn_u16(vld1q_u16(src_ptr));
    const uint8x8_t src_hi = vmovn_u16(vld1q_u16(src_next));
    const uint8x8_t ref_lo = vmovn_u16(vld1q_u16(ref_ptr));
    const uint8x8_t ref_hi = vmovn_u16(vld1q_u16(ref_next));

    // |s - r| fits in a byte; the dot product of the difference vector with
    // itself adds the squared differences into the 32-bit accumulator lanes.
    const uint8x16_t abs_diff = vabdq_u8(vcombine_u8(src_lo, src_hi),
                                         vcombine_u8(ref_lo, ref_hi));
    acc = vdotq_u32(acc, abs_diff, abs_diff);

    src_ptr = src_next + src_stride;
    ref_ptr = ref_next + ref_stride;
  } while (--row_pairs != 0);

  const unsigned int total = horizontal_add_u32x4(acc);
  *sse = total;
  return total;
}
47 
// Sum of squared differences between two 16-pixel-wide blocks whose samples
// are stored in 16-bit containers, accumulated with the dot-product intrinsic
// vdotq_u32.
//
// One row (16 samples) is consumed per loop iteration. The loop always
// executes at least once, so h must be at least 1.
//
// Samples are narrowed with vmovn_u16, which keeps only the low 8 bits of each
// 16-bit lane. NOTE(review): presumably callers only pass 8-bit-depth content
// so nothing is lost by the narrowing - confirm against the call sites.
//
// The total is stored to *sse and also returned.
static inline uint32_t highbd_mse8_16xh_neon_dotprod(const uint16_t *src_ptr,
                                                     int src_stride,
                                                     const uint16_t *ref_ptr,
                                                     int ref_stride, int h,
                                                     unsigned int *sse) {
  uint32x4_t acc = vdupq_n_u32(0);
  int rows_left = h;

  do {
    // Narrow the left and right 8-sample halves of the row to bytes and join
    // them into one 16-byte vector per input.
    const uint8x8_t src_left = vmovn_u16(vld1q_u16(src_ptr));
    const uint8x8_t src_right = vmovn_u16(vld1q_u16(src_ptr + 8));
    const uint8x8_t ref_left = vmovn_u16(vld1q_u16(ref_ptr));
    const uint8x8_t ref_right = vmovn_u16(vld1q_u16(ref_ptr + 8));

    // The dot product of the absolute-difference vector with itself adds the
    // squared differences into the 32-bit accumulator lanes.
    const uint8x16_t abs_diff = vabdq_u8(vcombine_u8(src_left, src_right),
                                         vcombine_u8(ref_left, ref_right));
    acc = vdotq_u32(acc, abs_diff, abs_diff);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--rows_left != 0);

  const unsigned int total = horizontal_add_u32x4(acc);
  *sse = total;
  return total;
}
75 
// Defines aom_highbd_8_mse<w>x<h>_neon_dotprod(). Each generated function
// turns the uint8_t* arguments into uint16_t sample pointers with
// CONVERT_TO_SHORTPTR (defined elsewhere), forwards to the width-specific
// helper with the block height h, and returns the sum of squared errors that
// the helper also stores to *sse.
#define HIGHBD_MSE_WXH_NEON_DOTPROD(w, h)                             \
  uint32_t aom_highbd_8_mse##w##x##h##_neon_dotprod(                  \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    return highbd_mse8_##w##xh_neon_dotprod(                          \
        CONVERT_TO_SHORTPTR(src_ptr), src_stride,                     \
        CONVERT_TO_SHORTPTR(ref_ptr), ref_stride, h, sse);            \
  }

// Instantiate the block sizes provided by this file.
HIGHBD_MSE_WXH_NEON_DOTPROD(16, 16)
HIGHBD_MSE_WXH_NEON_DOTPROD(16, 8)
HIGHBD_MSE_WXH_NEON_DOTPROD(8, 16)
HIGHBD_MSE_WXH_NEON_DOTPROD(8, 8)

// The generator macro is local to this file.
#undef HIGHBD_MSE_WXH_NEON_DOTPROD
93