xref: /aosp_15_r20/external/libgav1/src/dsp/weight_mask.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "src/dsp/weight_mask.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <string>
#include <type_traits>

#include "src/dsp/dsp.h"
#include "src/utils/common.h"
26 
27 namespace libgav1 {
28 namespace dsp {
29 namespace {
30 
31 template <int width, int height, int bitdepth, bool mask_is_inverse>
WeightMask_C(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)32 void WeightMask_C(const void* LIBGAV1_RESTRICT prediction_0,
33                   const void* LIBGAV1_RESTRICT prediction_1,
34                   uint8_t* LIBGAV1_RESTRICT mask, ptrdiff_t mask_stride) {
35   using PredType =
36       typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type;
37   const auto* pred_0 = static_cast<const PredType*>(prediction_0);
38   const auto* pred_1 = static_cast<const PredType*>(prediction_1);
39   static_assert(width >= 8, "");
40   static_assert(height >= 8, "");
41   constexpr int rounding_bits = bitdepth - 8 + ((bitdepth == 12) ? 2 : 4);
42   for (int y = 0; y < height; ++y) {
43     for (int x = 0; x < width; ++x) {
44       const int difference = RightShiftWithRounding(
45           std::abs(pred_0[x] - pred_1[x]), rounding_bits);
46       const auto mask_value =
47           static_cast<uint8_t>(std::min(DivideBy16(difference) + 38, 64));
48       mask[x] = mask_is_inverse ? 64 - mask_value : mask_value;
49     }
50     pred_0 += width;
51     pred_1 += width;
52     mask += mask_stride;
53   }
54 }
55 
// Installs the C weight-mask functions for one block size into the local
// |dsp| table (a |dsp| pointer must be in scope at the expansion site):
//   dsp->weight_mask[w_index][h_index][0] -> regular mask
//   dsp->weight_mask[w_index][h_index][1] -> inverse mask (64 - value)
// In this file w_index/h_index are 0..4 for sizes 8, 16, 32, 64, 128.
// The inverse flag is passed as a bool literal to match the template's
// |bool mask_is_inverse| parameter.
#define INIT_WEIGHT_MASK(width, height, bitdepth, w_index, h_index) \
  dsp->weight_mask[w_index][h_index][0] =                           \
      WeightMask_C<width, height, bitdepth, false>;                 \
  dsp->weight_mask[w_index][h_index][1] =                           \
      WeightMask_C<width, height, bitdepth, true>
61 
Init8bpp()62 void Init8bpp() {
63   Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
64   assert(dsp != nullptr);
65 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
66   INIT_WEIGHT_MASK(8, 8, 8, 0, 0);
67   INIT_WEIGHT_MASK(8, 16, 8, 0, 1);
68   INIT_WEIGHT_MASK(8, 32, 8, 0, 2);
69   INIT_WEIGHT_MASK(16, 8, 8, 1, 0);
70   INIT_WEIGHT_MASK(16, 16, 8, 1, 1);
71   INIT_WEIGHT_MASK(16, 32, 8, 1, 2);
72   INIT_WEIGHT_MASK(16, 64, 8, 1, 3);
73   INIT_WEIGHT_MASK(32, 8, 8, 2, 0);
74   INIT_WEIGHT_MASK(32, 16, 8, 2, 1);
75   INIT_WEIGHT_MASK(32, 32, 8, 2, 2);
76   INIT_WEIGHT_MASK(32, 64, 8, 2, 3);
77   INIT_WEIGHT_MASK(64, 16, 8, 3, 1);
78   INIT_WEIGHT_MASK(64, 32, 8, 3, 2);
79   INIT_WEIGHT_MASK(64, 64, 8, 3, 3);
80   INIT_WEIGHT_MASK(64, 128, 8, 3, 4);
81   INIT_WEIGHT_MASK(128, 64, 8, 4, 3);
82   INIT_WEIGHT_MASK(128, 128, 8, 4, 4);
83 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
84   static_cast<void>(dsp);
85 #ifndef LIBGAV1_Dsp8bpp_WeightMask_8x8
86   INIT_WEIGHT_MASK(8, 8, 8, 0, 0);
87 #endif
88 #ifndef LIBGAV1_Dsp8bpp_WeightMask_8x16
89   INIT_WEIGHT_MASK(8, 16, 8, 0, 1);
90 #endif
91 #ifndef LIBGAV1_Dsp8bpp_WeightMask_8x32
92   INIT_WEIGHT_MASK(8, 32, 8, 0, 2);
93 #endif
94 #ifndef LIBGAV1_Dsp8bpp_WeightMask_16x8
95   INIT_WEIGHT_MASK(16, 8, 8, 1, 0);
96 #endif
97 #ifndef LIBGAV1_Dsp8bpp_WeightMask_16x16
98   INIT_WEIGHT_MASK(16, 16, 8, 1, 1);
99 #endif
100 #ifndef LIBGAV1_Dsp8bpp_WeightMask_16x32
101   INIT_WEIGHT_MASK(16, 32, 8, 1, 2);
102 #endif
103 #ifndef LIBGAV1_Dsp8bpp_WeightMask_16x64
104   INIT_WEIGHT_MASK(16, 64, 8, 1, 3);
105 #endif
106 #ifndef LIBGAV1_Dsp8bpp_WeightMask_32x8
107   INIT_WEIGHT_MASK(32, 8, 8, 2, 0);
108 #endif
109 #ifndef LIBGAV1_Dsp8bpp_WeightMask_32x16
110   INIT_WEIGHT_MASK(32, 16, 8, 2, 1);
111 #endif
112 #ifndef LIBGAV1_Dsp8bpp_WeightMask_32x32
113   INIT_WEIGHT_MASK(32, 32, 8, 2, 2);
114 #endif
115 #ifndef LIBGAV1_Dsp8bpp_WeightMask_32x64
116   INIT_WEIGHT_MASK(32, 64, 8, 2, 3);
117 #endif
118 #ifndef LIBGAV1_Dsp8bpp_WeightMask_64x16
119   INIT_WEIGHT_MASK(64, 16, 8, 3, 1);
120 #endif
121 #ifndef LIBGAV1_Dsp8bpp_WeightMask_64x32
122   INIT_WEIGHT_MASK(64, 32, 8, 3, 2);
123 #endif
124 #ifndef LIBGAV1_Dsp8bpp_WeightMask_64x64
125   INIT_WEIGHT_MASK(64, 64, 8, 3, 3);
126 #endif
127 #ifndef LIBGAV1_Dsp8bpp_WeightMask_64x128
128   INIT_WEIGHT_MASK(64, 128, 8, 3, 4);
129 #endif
130 #ifndef LIBGAV1_Dsp8bpp_WeightMask_128x64
131   INIT_WEIGHT_MASK(128, 64, 8, 4, 3);
132 #endif
133 #ifndef LIBGAV1_Dsp8bpp_WeightMask_128x128
134   INIT_WEIGHT_MASK(128, 128, 8, 4, 4);
135 #endif
136 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
137 }
138 
139 #if LIBGAV1_MAX_BITDEPTH >= 10
Init10bpp()140 void Init10bpp() {
141   Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
142   assert(dsp != nullptr);
143 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
144   INIT_WEIGHT_MASK(8, 8, 10, 0, 0);
145   INIT_WEIGHT_MASK(8, 16, 10, 0, 1);
146   INIT_WEIGHT_MASK(8, 32, 10, 0, 2);
147   INIT_WEIGHT_MASK(16, 8, 10, 1, 0);
148   INIT_WEIGHT_MASK(16, 16, 10, 1, 1);
149   INIT_WEIGHT_MASK(16, 32, 10, 1, 2);
150   INIT_WEIGHT_MASK(16, 64, 10, 1, 3);
151   INIT_WEIGHT_MASK(32, 8, 10, 2, 0);
152   INIT_WEIGHT_MASK(32, 16, 10, 2, 1);
153   INIT_WEIGHT_MASK(32, 32, 10, 2, 2);
154   INIT_WEIGHT_MASK(32, 64, 10, 2, 3);
155   INIT_WEIGHT_MASK(64, 16, 10, 3, 1);
156   INIT_WEIGHT_MASK(64, 32, 10, 3, 2);
157   INIT_WEIGHT_MASK(64, 64, 10, 3, 3);
158   INIT_WEIGHT_MASK(64, 128, 10, 3, 4);
159   INIT_WEIGHT_MASK(128, 64, 10, 4, 3);
160   INIT_WEIGHT_MASK(128, 128, 10, 4, 4);
161 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
162   static_cast<void>(dsp);
163 #ifndef LIBGAV1_Dsp10bpp_WeightMask_8x8
164   INIT_WEIGHT_MASK(8, 8, 10, 0, 0);
165 #endif
166 #ifndef LIBGAV1_Dsp10bpp_WeightMask_8x16
167   INIT_WEIGHT_MASK(8, 16, 10, 0, 1);
168 #endif
169 #ifndef LIBGAV1_Dsp10bpp_WeightMask_8x32
170   INIT_WEIGHT_MASK(8, 32, 10, 0, 2);
171 #endif
172 #ifndef LIBGAV1_Dsp10bpp_WeightMask_16x8
173   INIT_WEIGHT_MASK(16, 8, 10, 1, 0);
174 #endif
175 #ifndef LIBGAV1_Dsp10bpp_WeightMask_16x16
176   INIT_WEIGHT_MASK(16, 16, 10, 1, 1);
177 #endif
178 #ifndef LIBGAV1_Dsp10bpp_WeightMask_16x32
179   INIT_WEIGHT_MASK(16, 32, 10, 1, 2);
180 #endif
181 #ifndef LIBGAV1_Dsp10bpp_WeightMask_16x64
182   INIT_WEIGHT_MASK(16, 64, 10, 1, 3);
183 #endif
184 #ifndef LIBGAV1_Dsp10bpp_WeightMask_32x8
185   INIT_WEIGHT_MASK(32, 8, 10, 2, 0);
186 #endif
187 #ifndef LIBGAV1_Dsp10bpp_WeightMask_32x16
188   INIT_WEIGHT_MASK(32, 16, 10, 2, 1);
189 #endif
190 #ifndef LIBGAV1_Dsp10bpp_WeightMask_32x32
191   INIT_WEIGHT_MASK(32, 32, 10, 2, 2);
192 #endif
193 #ifndef LIBGAV1_Dsp10bpp_WeightMask_32x64
194   INIT_WEIGHT_MASK(32, 64, 10, 2, 3);
195 #endif
196 #ifndef LIBGAV1_Dsp10bpp_WeightMask_64x16
197   INIT_WEIGHT_MASK(64, 16, 10, 3, 1);
198 #endif
199 #ifndef LIBGAV1_Dsp10bpp_WeightMask_64x32
200   INIT_WEIGHT_MASK(64, 32, 10, 3, 2);
201 #endif
202 #ifndef LIBGAV1_Dsp10bpp_WeightMask_64x64
203   INIT_WEIGHT_MASK(64, 64, 10, 3, 3);
204 #endif
205 #ifndef LIBGAV1_Dsp10bpp_WeightMask_64x128
206   INIT_WEIGHT_MASK(64, 128, 10, 3, 4);
207 #endif
208 #ifndef LIBGAV1_Dsp10bpp_WeightMask_128x64
209   INIT_WEIGHT_MASK(128, 64, 10, 4, 3);
210 #endif
211 #ifndef LIBGAV1_Dsp10bpp_WeightMask_128x128
212   INIT_WEIGHT_MASK(128, 128, 10, 4, 4);
213 #endif
214 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
215 }
216 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
217 
218 #if LIBGAV1_MAX_BITDEPTH == 12
Init12bpp()219 void Init12bpp() {
220   Dsp* const dsp = dsp_internal::GetWritableDspTable(12);
221   assert(dsp != nullptr);
222 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
223   INIT_WEIGHT_MASK(8, 8, 12, 0, 0);
224   INIT_WEIGHT_MASK(8, 16, 12, 0, 1);
225   INIT_WEIGHT_MASK(8, 32, 12, 0, 2);
226   INIT_WEIGHT_MASK(16, 8, 12, 1, 0);
227   INIT_WEIGHT_MASK(16, 16, 12, 1, 1);
228   INIT_WEIGHT_MASK(16, 32, 12, 1, 2);
229   INIT_WEIGHT_MASK(16, 64, 12, 1, 3);
230   INIT_WEIGHT_MASK(32, 8, 12, 2, 0);
231   INIT_WEIGHT_MASK(32, 16, 12, 2, 1);
232   INIT_WEIGHT_MASK(32, 32, 12, 2, 2);
233   INIT_WEIGHT_MASK(32, 64, 12, 2, 3);
234   INIT_WEIGHT_MASK(64, 16, 12, 3, 1);
235   INIT_WEIGHT_MASK(64, 32, 12, 3, 2);
236   INIT_WEIGHT_MASK(64, 64, 12, 3, 3);
237   INIT_WEIGHT_MASK(64, 128, 12, 3, 4);
238   INIT_WEIGHT_MASK(128, 64, 12, 4, 3);
239   INIT_WEIGHT_MASK(128, 128, 12, 4, 4);
240 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
241   static_cast<void>(dsp);
242 #ifndef LIBGAV1_Dsp12bpp_WeightMask_8x8
243   INIT_WEIGHT_MASK(8, 8, 12, 0, 0);
244 #endif
245 #ifndef LIBGAV1_Dsp12bpp_WeightMask_8x16
246   INIT_WEIGHT_MASK(8, 16, 12, 0, 1);
247 #endif
248 #ifndef LIBGAV1_Dsp12bpp_WeightMask_8x32
249   INIT_WEIGHT_MASK(8, 32, 12, 0, 2);
250 #endif
251 #ifndef LIBGAV1_Dsp12bpp_WeightMask_16x8
252   INIT_WEIGHT_MASK(16, 8, 12, 1, 0);
253 #endif
254 #ifndef LIBGAV1_Dsp12bpp_WeightMask_16x16
255   INIT_WEIGHT_MASK(16, 16, 12, 1, 1);
256 #endif
257 #ifndef LIBGAV1_Dsp12bpp_WeightMask_16x32
258   INIT_WEIGHT_MASK(16, 32, 12, 1, 2);
259 #endif
260 #ifndef LIBGAV1_Dsp12bpp_WeightMask_16x64
261   INIT_WEIGHT_MASK(16, 64, 12, 1, 3);
262 #endif
263 #ifndef LIBGAV1_Dsp12bpp_WeightMask_32x8
264   INIT_WEIGHT_MASK(32, 8, 12, 2, 0);
265 #endif
266 #ifndef LIBGAV1_Dsp12bpp_WeightMask_32x16
267   INIT_WEIGHT_MASK(32, 16, 12, 2, 1);
268 #endif
269 #ifndef LIBGAV1_Dsp12bpp_WeightMask_32x32
270   INIT_WEIGHT_MASK(32, 32, 12, 2, 2);
271 #endif
272 #ifndef LIBGAV1_Dsp12bpp_WeightMask_32x64
273   INIT_WEIGHT_MASK(32, 64, 12, 2, 3);
274 #endif
275 #ifndef LIBGAV1_Dsp12bpp_WeightMask_64x16
276   INIT_WEIGHT_MASK(64, 16, 12, 3, 1);
277 #endif
278 #ifndef LIBGAV1_Dsp12bpp_WeightMask_64x32
279   INIT_WEIGHT_MASK(64, 32, 12, 3, 2);
280 #endif
281 #ifndef LIBGAV1_Dsp12bpp_WeightMask_64x64
282   INIT_WEIGHT_MASK(64, 64, 12, 3, 3);
283 #endif
284 #ifndef LIBGAV1_Dsp12bpp_WeightMask_64x128
285   INIT_WEIGHT_MASK(64, 128, 12, 3, 4);
286 #endif
287 #ifndef LIBGAV1_Dsp12bpp_WeightMask_128x64
288   INIT_WEIGHT_MASK(128, 64, 12, 4, 3);
289 #endif
290 #ifndef LIBGAV1_Dsp12bpp_WeightMask_128x128
291   INIT_WEIGHT_MASK(128, 128, 12, 4, 4);
292 #endif
293 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
294 }
295 #endif  // LIBGAV1_MAX_BITDEPTH == 12
296 
297 }  // namespace
298 
WeightMaskInit_C()299 void WeightMaskInit_C() {
300   Init8bpp();
301 #if LIBGAV1_MAX_BITDEPTH >= 10
302   Init10bpp();
303 #endif
304 #if LIBGAV1_MAX_BITDEPTH == 12
305   Init12bpp();
306 #endif
307 }
308 
309 }  // namespace dsp
310 }  // namespace libgav1
311