/*
 *
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>
#include <stdbool.h>

#include "aom/aom_integer.h"
#include "aom_dsp/blend.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_ports/mem.h"
#include "av1/common/blockd.h"
#include "config/av1_rtcd.h"

static inline void diffwtd_mask_d16_neon(uint8_t *mask, const bool inverse,
                                         const CONV_BUF_TYPE *src0,
                                         int src0_stride,
                                         const CONV_BUF_TYPE *src1,
                                         int src1_stride, int h, int w,
                                         ConvolveParams *conv_params, int bd) {
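  // Intermediate (CONV_BUF_TYPE) values still carry the extra precision left
  // over from the two convolution passes (2 * FILTER_BITS minus the rounding
  // already applied), plus (bd - 8) bits for high bit depth, so 'round' is
  // the shift needed to bring |src0 - src1| back to an 8-bit pixel scale.
  // Each mask value is then clamp(38 + (diff >> DIFF_FACTOR_LOG2), 0, 64),
  // i.e. a base weight of 38 plus the scaled absolute difference, capped at
  // AOM_BLEND_A64_MAX_ALPHA (64). The inverse mask (64 - m) is formed with a
  // saturating subtract so that it clamps at 0.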
  const int round =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1 + (bd - 8);
  const int16x8_t round_vec = vdupq_n_s16((int16_t)(-round));

  if (w >= 16) {
    int i = 0;
    do {
      int j = 0;
      do {
        uint16x8_t s0_lo = vld1q_u16(src0 + j);
        uint16x8_t s1_lo = vld1q_u16(src1 + j);
        uint16x8_t s0_hi = vld1q_u16(src0 + j + 8);
        uint16x8_t s1_hi = vld1q_u16(src1 + j + 8);

        uint16x8_t diff_lo_u16 = vrshlq_u16(vabdq_u16(s0_lo, s1_lo), round_vec);
        uint16x8_t diff_hi_u16 = vrshlq_u16(vabdq_u16(s0_hi, s1_hi), round_vec);
        uint8x8_t diff_lo_u8 = vshrn_n_u16(diff_lo_u16, DIFF_FACTOR_LOG2);
        uint8x8_t diff_hi_u8 = vshrn_n_u16(diff_hi_u16, DIFF_FACTOR_LOG2);
        uint8x16_t diff = vcombine_u8(diff_lo_u8, diff_hi_u8);

        uint8x16_t m;
        if (inverse) {
          m = vqsubq_u8(vdupq_n_u8(64 - 38), diff);  // Saturating to 0
        } else {
          m = vminq_u8(vaddq_u8(diff, vdupq_n_u8(38)), vdupq_n_u8(64));
        }

        vst1q_u8(mask, m);

        mask += 16;
        j += 16;
      } while (j < w);
      src0 += src0_stride;
      src1 += src1_stride;
    } while (++i < h);
  } else if (w == 8) {
    int i = 0;
    do {
      uint16x8_t s0 = vld1q_u16(src0);
      uint16x8_t s1 = vld1q_u16(src1);

      uint16x8_t diff_u16 = vrshlq_u16(vabdq_u16(s0, s1), round_vec);
      uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, DIFF_FACTOR_LOG2);
      uint8x8_t m;
      if (inverse) {
        m = vqsub_u8(vdup_n_u8(64 - 38), diff_u8);  // Saturating to 0
      } else {
        m = vmin_u8(vadd_u8(diff_u8, vdup_n_u8(38)), vdup_n_u8(64));
      }

      vst1_u8(mask, m);

      mask += 8;
      src0 += src0_stride;
      src1 += src1_stride;
    } while (++i < h);
  } else if (w == 4) {
    int i = 0;
    do {
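      // Each 8-lane vector holds two 4-wide rows, so two rows are processed
      // per iteration.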
      uint16x8_t s0 =
          vcombine_u16(vld1_u16(src0), vld1_u16(src0 + src0_stride));
      uint16x8_t s1 =
          vcombine_u16(vld1_u16(src1), vld1_u16(src1 + src1_stride));

      uint16x8_t diff_u16 = vrshlq_u16(vabdq_u16(s0, s1), round_vec);
      uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, DIFF_FACTOR_LOG2);
      uint8x8_t m;
      if (inverse) {
        m = vqsub_u8(vdup_n_u8(64 - 38), diff_u8);  // Saturating to 0
      } else {
        m = vmin_u8(vadd_u8(diff_u8, vdup_n_u8(38)), vdup_n_u8(64));
      }

      vst1_u8(mask, m);

      mask += 8;
      src0 += 2 * src0_stride;
      src1 += 2 * src1_stride;
      i += 2;
    } while (i < h);
  }
}

void av1_build_compound_diffwtd_mask_d16_neon(
    uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const CONV_BUF_TYPE *src0,
    int src0_stride, const CONV_BUF_TYPE *src1, int src1_stride, int h, int w,
    ConvolveParams *conv_params, int bd) {
  assert(h >= 4);
  assert(w >= 4);
  assert((mask_type == DIFFWTD_38_INV) || (mask_type == DIFFWTD_38));

  if (mask_type == DIFFWTD_38) {
    diffwtd_mask_d16_neon(mask, /*inverse=*/false, src0, src0_stride, src1,
                          src1_stride, h, w, conv_params, bd);
  } else {  // mask_type == DIFFWTD_38_INV
    diffwtd_mask_d16_neon(mask, /*inverse=*/true, src0, src0_stride, src1,
                          src1_stride, h, w, conv_params, bd);
  }
}

static inline void diffwtd_mask_neon(uint8_t *mask, const bool inverse,
                                     const uint8_t *src0, int src0_stride,
                                     const uint8_t *src1, int src1_stride,
                                     int h, int w) {
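  // 8-bit inputs are already at pixel precision, so no rounding shift is
  // needed here: only the DIFF_FACTOR_LOG2 scaling is applied before the same
  // 38/64 clamp as the d16 path.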
  if (w >= 16) {
    int i = 0;
    do {
      int j = 0;
      do {
        uint8x16_t s0 = vld1q_u8(src0 + j);
        uint8x16_t s1 = vld1q_u8(src1 + j);

        uint8x16_t diff = vshrq_n_u8(vabdq_u8(s0, s1), DIFF_FACTOR_LOG2);
        uint8x16_t m;
        if (inverse) {
          m = vqsubq_u8(vdupq_n_u8(64 - 38), diff);  // Saturating to 0
        } else {
          m = vminq_u8(vaddq_u8(diff, vdupq_n_u8(38)), vdupq_n_u8(64));
        }

        vst1q_u8(mask, m);

        mask += 16;
        j += 16;
      } while (j < w);
      src0 += src0_stride;
      src1 += src1_stride;
    } while (++i < h);
  } else if (w == 8) {
    int i = 0;
    do {
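      // Two 8-wide rows are packed into each 16-byte vector, so two rows are
      // processed per iteration.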
      uint8x16_t s0 = vcombine_u8(vld1_u8(src0), vld1_u8(src0 + src0_stride));
      uint8x16_t s1 = vcombine_u8(vld1_u8(src1), vld1_u8(src1 + src1_stride));

      uint8x16_t diff = vshrq_n_u8(vabdq_u8(s0, s1), DIFF_FACTOR_LOG2);
      uint8x16_t m;
      if (inverse) {
        m = vqsubq_u8(vdupq_n_u8(64 - 38), diff);  // Saturating to 0
      } else {
        m = vminq_u8(vaddq_u8(diff, vdupq_n_u8(38)), vdupq_n_u8(64));
      }

      vst1q_u8(mask, m);

      mask += 16;
      src0 += 2 * src0_stride;
      src1 += 2 * src1_stride;
      i += 2;
    } while (i < h);
  } else if (w == 4) {
    int i = 0;
    do {
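      // load_unaligned_u8q gathers four 4-wide rows into a single 16-byte
      // vector, so four rows are processed per iteration.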
      uint8x16_t s0 = load_unaligned_u8q(src0, src0_stride);
      uint8x16_t s1 = load_unaligned_u8q(src1, src1_stride);

      uint8x16_t diff = vshrq_n_u8(vabdq_u8(s0, s1), DIFF_FACTOR_LOG2);
      uint8x16_t m;
      if (inverse) {
        m = vqsubq_u8(vdupq_n_u8(64 - 38), diff);  // Saturating to 0
      } else {
        m = vminq_u8(vaddq_u8(diff, vdupq_n_u8(38)), vdupq_n_u8(64));
      }

      vst1q_u8(mask, m);

      mask += 16;
      src0 += 4 * src0_stride;
      src1 += 4 * src1_stride;
      i += 4;
    } while (i < h);
  }
}

void av1_build_compound_diffwtd_mask_neon(uint8_t *mask,
                                          DIFFWTD_MASK_TYPE mask_type,
                                          const uint8_t *src0, int src0_stride,
                                          const uint8_t *src1, int src1_stride,
                                          int h, int w) {
  assert(h % 4 == 0);
  assert(w % 4 == 0);
  assert(mask_type == DIFFWTD_38_INV || mask_type == DIFFWTD_38);

  if (mask_type == DIFFWTD_38) {
    diffwtd_mask_neon(mask, /*inverse=*/false, src0, src0_stride, src1,
                      src1_stride, h, w);
  } else {  // mask_type == DIFFWTD_38_INV
    diffwtd_mask_neon(mask, /*inverse=*/true, src0, src0_stride, src1,
                      src1_stride, h, w);
  }
}