xref: /aosp_15_r20/external/libaom/av1/common/arm/reconinter_neon.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  *
3  * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
4  *
5  * This source code is subject to the terms of the BSD 2 Clause License and
6  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
7  * was not distributed with this source code in the LICENSE file, you can
8  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
9  * Media Patent License 1.0 was not distributed with this source code in the
10  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
11  */
12 
13 #include <arm_neon.h>
14 #include <assert.h>
15 #include <stdbool.h>
16 
17 #include "aom/aom_integer.h"
18 #include "aom_dsp/blend.h"
19 #include "aom_dsp/arm/mem_neon.h"
20 #include "aom_ports/mem.h"
21 #include "av1/common/blockd.h"
22 #include "config/av1_rtcd.h"
23 
// Builds a DIFFWTD compound mask from two 16-bit intermediate (convolve
// domain) prediction buffers. For each pixel the mask value is
//   m = clamp(38 + |src0 - src1| / DIFF_FACTOR, 0, 64)      when !inverse
//   m = clamp(64 - 38 - |src0 - src1| / DIFF_FACTOR, 0, 64) when inverse
// where the absolute difference is first rounded down from the convolve
// domain back to the (bit-depth adjusted) pixel domain.
// 38 is the DIFFWTD_38 mask offset and 64 is presumably
// AOM_BLEND_A64_MAX_ALPHA from aom_dsp/blend.h — TODO confirm.
// One mask byte is written per pixel; `mask` is advanced contiguously
// (w * h bytes total, no mask stride).
static inline void diffwtd_mask_d16_neon(uint8_t *mask, const bool inverse,
                                         const CONV_BUF_TYPE *src0,
                                         int src0_stride,
                                         const CONV_BUF_TYPE *src1,
                                         int src1_stride, int h, int w,
                                         ConvolveParams *conv_params, int bd) {
  // Shift needed to undo the two convolve rounding stages and the bit-depth
  // scaling, bringing |src0 - src1| back to an 8-bit-equivalent range.
  const int round =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1 + (bd - 8);
  // vrshlq_u16 with a negative shift count performs a *rounding* right shift,
  // hence the negated round value.
  const int16x8_t round_vec = vdupq_n_s16((int16_t)(-round));

  if (w >= 16) {
    // Wide blocks: process 16 pixels per inner iteration (two 8-lane halves).
    int i = 0;
    do {
      int j = 0;
      do {
        uint16x8_t s0_lo = vld1q_u16(src0 + j);
        uint16x8_t s1_lo = vld1q_u16(src1 + j);
        uint16x8_t s0_hi = vld1q_u16(src0 + j + 8);
        uint16x8_t s1_hi = vld1q_u16(src1 + j + 8);

        // |s0 - s1|, rounded down to the pixel domain.
        uint16x8_t diff_lo_u16 = vrshlq_u16(vabdq_u16(s0_lo, s1_lo), round_vec);
        uint16x8_t diff_hi_u16 = vrshlq_u16(vabdq_u16(s0_hi, s1_hi), round_vec);
        // Divide by DIFF_FACTOR and narrow to 8 bits in one instruction.
        uint8x8_t diff_lo_u8 = vshrn_n_u16(diff_lo_u16, DIFF_FACTOR_LOG2);
        uint8x8_t diff_hi_u8 = vshrn_n_u16(diff_hi_u16, DIFF_FACTOR_LOG2);
        uint8x16_t diff = vcombine_u8(diff_lo_u8, diff_hi_u8);

        uint8x16_t m;
        if (inverse) {
          m = vqsubq_u8(vdupq_n_u8(64 - 38), diff);  // Saturating to 0
        } else {
          // min(38 + diff, 64) clamps the mask to the maximum blend weight.
          m = vminq_u8(vaddq_u8(diff, vdupq_n_u8(38)), vdupq_n_u8(64));
        }

        vst1q_u8(mask, m);

        mask += 16;
        j += 16;
      } while (j < w);
      src0 += src0_stride;
      src1 += src1_stride;
    } while (++i < h);
  } else if (w == 8) {
    // One full 8-pixel row per iteration.
    int i = 0;
    do {
      uint16x8_t s0 = vld1q_u16(src0);
      uint16x8_t s1 = vld1q_u16(src1);

      uint16x8_t diff_u16 = vrshlq_u16(vabdq_u16(s0, s1), round_vec);
      uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, DIFF_FACTOR_LOG2);
      uint8x8_t m;
      if (inverse) {
        m = vqsub_u8(vdup_n_u8(64 - 38), diff_u8);  // Saturating to 0
      } else {
        m = vmin_u8(vadd_u8(diff_u8, vdup_n_u8(38)), vdup_n_u8(64));
      }

      vst1_u8(mask, m);

      mask += 8;
      src0 += src0_stride;
      src1 += src1_stride;
    } while (++i < h);
  } else if (w == 4) {
    // Narrow blocks: pack two 4-pixel rows into one 8-lane vector, so each
    // iteration consumes two rows (h is asserted >= 4 by the caller).
    int i = 0;
    do {
      uint16x8_t s0 =
          vcombine_u16(vld1_u16(src0), vld1_u16(src0 + src0_stride));
      uint16x8_t s1 =
          vcombine_u16(vld1_u16(src1), vld1_u16(src1 + src1_stride));

      uint16x8_t diff_u16 = vrshlq_u16(vabdq_u16(s0, s1), round_vec);
      uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, DIFF_FACTOR_LOG2);
      uint8x8_t m;
      if (inverse) {
        m = vqsub_u8(vdup_n_u8(64 - 38), diff_u8);  // Saturating to 0
      } else {
        m = vmin_u8(vadd_u8(diff_u8, vdup_n_u8(38)), vdup_n_u8(64));
      }

      vst1_u8(mask, m);

      mask += 8;
      src0 += 2 * src0_stride;
      src1 += 2 * src1_stride;
      i += 2;
    } while (i < h);
  }
}
112 
// Public entry point for the 16-bit (convolve domain) DIFFWTD mask builder.
// Validates the block dimensions and mask type, then forwards to the shared
// NEON kernel with the inverse flag derived from the mask type.
void av1_build_compound_diffwtd_mask_d16_neon(
    uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const CONV_BUF_TYPE *src0,
    int src0_stride, const CONV_BUF_TYPE *src1, int src1_stride, int h, int w,
    ConvolveParams *conv_params, int bd) {
  assert(h >= 4);
  assert(w >= 4);
  assert((mask_type == DIFFWTD_38_INV) || (mask_type == DIFFWTD_38));

  // DIFFWTD_38_INV selects the inverted mask; DIFFWTD_38 the direct one.
  const bool invert = (mask_type == DIFFWTD_38_INV);
  diffwtd_mask_d16_neon(mask, invert, src0, src0_stride, src1, src1_stride, h,
                        w, conv_params, bd);
}
129 
// Builds a DIFFWTD compound mask directly from two 8-bit prediction buffers.
// For each pixel:
//   m = clamp(38 + |src0 - src1| / DIFF_FACTOR, 0, 64)      when !inverse
//   m = clamp(64 - 38 - |src0 - src1| / DIFF_FACTOR, 0, 64) when inverse
// 38 is the DIFFWTD_38 mask offset and 64 is presumably
// AOM_BLEND_A64_MAX_ALPHA from aom_dsp/blend.h — TODO confirm.
// One mask byte is written per pixel; `mask` is advanced contiguously
// (w * h bytes total, no mask stride).
static inline void diffwtd_mask_neon(uint8_t *mask, const bool inverse,
                                     const uint8_t *src0, int src0_stride,
                                     const uint8_t *src1, int src1_stride,
                                     int h, int w) {
  if (w >= 16) {
    // Wide blocks: 16 pixels per inner iteration, one row per outer loop.
    int i = 0;
    do {
      int j = 0;
      do {
        uint8x16_t s0 = vld1q_u8(src0 + j);
        uint8x16_t s1 = vld1q_u8(src1 + j);

        // |s0 - s1| / DIFF_FACTOR.
        uint8x16_t diff = vshrq_n_u8(vabdq_u8(s0, s1), DIFF_FACTOR_LOG2);
        uint8x16_t m;
        if (inverse) {
          m = vqsubq_u8(vdupq_n_u8(64 - 38), diff);  // Saturating to 0
        } else {
          // min(38 + diff, 64) clamps the mask to the maximum blend weight.
          m = vminq_u8(vaddq_u8(diff, vdupq_n_u8(38)), vdupq_n_u8(64));
        }

        vst1q_u8(mask, m);

        mask += 16;
        j += 16;
      } while (j < w);
      src0 += src0_stride;
      src1 += src1_stride;
    } while (++i < h);
  } else if (w == 8) {
    // Pack two 8-pixel rows into one 16-lane vector per iteration.
    int i = 0;
    do {
      uint8x16_t s0 = vcombine_u8(vld1_u8(src0), vld1_u8(src0 + src0_stride));
      // Bug fix: the second row of src1 must be addressed with src1_stride;
      // the previous code used src0_stride, giving wrong masks whenever the
      // two source strides differ.
      uint8x16_t s1 = vcombine_u8(vld1_u8(src1), vld1_u8(src1 + src1_stride));

      uint8x16_t diff = vshrq_n_u8(vabdq_u8(s0, s1), DIFF_FACTOR_LOG2);
      uint8x16_t m;
      if (inverse) {
        m = vqsubq_u8(vdupq_n_u8(64 - 38), diff);  // Saturating to 0
      } else {
        m = vminq_u8(vaddq_u8(diff, vdupq_n_u8(38)), vdupq_n_u8(64));
      }

      vst1q_u8(mask, m);

      mask += 16;
      src0 += 2 * src0_stride;
      src1 += 2 * src1_stride;
      i += 2;
    } while (i < h);
  } else if (w == 4) {
    // Pack four 4-pixel rows into one 16-lane vector per iteration
    // (h is asserted to be a multiple of 4 by the caller).
    int i = 0;
    do {
      uint8x16_t s0 = load_unaligned_u8q(src0, src0_stride);
      uint8x16_t s1 = load_unaligned_u8q(src1, src1_stride);

      uint8x16_t diff = vshrq_n_u8(vabdq_u8(s0, s1), DIFF_FACTOR_LOG2);
      uint8x16_t m;
      if (inverse) {
        m = vqsubq_u8(vdupq_n_u8(64 - 38), diff);  // Saturating to 0
      } else {
        m = vminq_u8(vaddq_u8(diff, vdupq_n_u8(38)), vdupq_n_u8(64));
      }

      vst1q_u8(mask, m);

      mask += 16;
      src0 += 4 * src0_stride;
      src1 += 4 * src1_stride;
      i += 4;
    } while (i < h);
  }
}
202 
// Public entry point for the 8-bit DIFFWTD mask builder. Validates the block
// dimensions and mask type, then forwards to the shared NEON kernel with the
// inverse flag derived from the mask type.
void av1_build_compound_diffwtd_mask_neon(uint8_t *mask,
                                          DIFFWTD_MASK_TYPE mask_type,
                                          const uint8_t *src0, int src0_stride,
                                          const uint8_t *src1, int src1_stride,
                                          int h, int w) {
  assert(h % 4 == 0);
  assert(w % 4 == 0);
  assert(mask_type == DIFFWTD_38_INV || mask_type == DIFFWTD_38);

  // DIFFWTD_38_INV selects the inverted mask; DIFFWTD_38 the direct one.
  const bool invert = (mask_type == DIFFWTD_38_INV);
  diffwtd_mask_neon(mask, invert, src0, src0_stride, src1, src1_stride, h, w);
}
220