1 // Copyright 2020 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/dsp/weight_mask.h"
16
17 #include <algorithm>
18 #include <cstdint>
19 #include <ostream>
20 #include <string>
21 #include <type_traits>
22
23 #include "absl/strings/match.h"
24 #include "absl/strings/str_format.h"
25 #include "absl/time/clock.h"
26 #include "absl/time/time.h"
27 #include "gtest/gtest.h"
28 #include "src/dsp/dsp.h"
29 #include "src/utils/common.h"
30 #include "src/utils/constants.h"
31 #include "src/utils/cpu.h"
32 #include "src/utils/memory.h"
33 #include "tests/third_party/libvpx/acm_random.h"
34 #include "tests/utils.h"
35
36 namespace libgav1 {
37 namespace dsp {
38 namespace {
39
// Number of iterations run by the DISABLED_Speed tests.
constexpr int kNumSpeedTests = 50000;
// Largest block dimension; the test buffers hold
// kMaxPredictionSize * kMaxPredictionSize samples.
constexpr int kMaxPredictionSize = 128;
// {min, max} compound prediction values, indexed by (bitdepth - 8) >> 1.
// weight_mask is only used with kCompoundPredictionTypeDiffWeighted, and
// convolve produces the most extreme ranges. The 10bpp and 12bpp ranges
// include kCompoundOffset.
// See src/dsp/convolve.cc and src/dsp/warp.cc.
constexpr int kCompoundPredictionRange[3][2] = {
    {-5132, 9212},  // 8bpp
    {3988, 61532},  // 10bpp
    {3974, 61559},  // 12bpp
};
54
// Returns the expected 8bpp mask digest stored in slot |id|, or nullptr when
// |id| is outside the table (instead of reading out of bounds).
//
// The table holds 18 slots for mask_is_inverse = false followed by 18 slots
// for mask_is_inverse = true. Slot 3 of each half is an empty placeholder.
// NOTE(review): the placeholder was labelled kBlock4x16, but its position
// between the 8x32 and 16x8 entries suggests kBlock16x4 - confirm against the
// BlockSize enum.
const char* GetDigest8bpp(int id) {
  static const char* const kDigest[] = {
      // mask_is_inverse = false.
      "eaca5b6a96dcfe5e44f3926a071b48b3",
      "1d82c75cfdf8e57925eb1d5301647538",
      "25bd455d74fb891b97b133c528f8db60",
      "" /*unused*/,
      "1d82c75cfdf8e57925eb1d5301647538",
      "25bd455d74fb891b97b133c528f8db60",
      "62a08776db35a186406a11ab92dee71c",
      "95131d1dc0e05fcf4bd234d5ce9eea11",
      "25bd455d74fb891b97b133c528f8db60",
      "62a08776db35a186406a11ab92dee71c",
      "95131d1dc0e05fcf4bd234d5ce9eea11",
      "0b3c75272e0fb0747b9850145d340c4c",
      "95131d1dc0e05fcf4bd234d5ce9eea11",
      "0b3c75272e0fb0747b9850145d340c4c",
      "f26c43d4bc823a89c1ed47ab8708bc06",
      "0d99bbf31ecddc1c2d5063a68c0e9375",
      "0d99bbf31ecddc1c2d5063a68c0e9375",
      "5fb8ec5f582f0ebfe519ed55860f67c4",

      // mask_is_inverse = true.
      "96811f3b192828ff679e4c9ad8069d7d",
      "a04dc180c028d55af70240163445523a",
      "8513e3988233d0a7de316a0179bb6139",
      "" /*unused*/,
      "a04dc180c028d55af70240163445523a",
      "8513e3988233d0a7de316a0179bb6139",
      "f7356d42fb44a6ccb41253ba35b8b3c7",
      "3d2d61ffc203ee64fe91c9d16168a19d",
      "8513e3988233d0a7de316a0179bb6139",
      "f7356d42fb44a6ccb41253ba35b8b3c7",
      "3d2d61ffc203ee64fe91c9d16168a19d",
      "87a2011ac69fb597ca4f71bb3c35ebb0",
      "3d2d61ffc203ee64fe91c9d16168a19d",
      "87a2011ac69fb597ca4f71bb3c35ebb0",
      "97100a3639d567046dc8a99fcb84cb2e",
      "9fabe05a6523da81a45150e19f75acff",
      "9fabe05a6523da81a45150e19f75acff",
      "7c0643e4d02421d06d7ca71822a94e1d",
  };
  constexpr int kNumDigests =
      static_cast<int>(sizeof(kDigest) / sizeof(kDigest[0]));
  // Callers assert on a non-null result, so an invalid id becomes a test
  // failure rather than undefined behavior.
  if (id < 0 || id >= kNumDigests) return nullptr;
  return kDigest[id];
}
98
99 #if LIBGAV1_MAX_BITDEPTH >= 10
// Returns the expected 10bpp mask digest stored in slot |id|, or nullptr when
// |id| is outside the table (instead of reading out of bounds).
//
// The table holds 18 slots for mask_is_inverse = false followed by 18 slots
// for mask_is_inverse = true. Slot 3 of each half is an empty placeholder.
// NOTE(review): the placeholder was labelled kBlock4x16, but its position
// between the 8x32 and 16x8 entries suggests kBlock16x4 - confirm against the
// BlockSize enum.
const char* GetDigest10bpp(int id) {
  static const char* const kDigest[] = {
      // mask_is_inverse = false.
      "5ae8d64b65a671301a457b8a73368ab5",
      "61535217f179054d4b76a8d9352a223d",
      "1aa6614773570e7b021cd509849c4180",
      "" /*unused*/,
      "61535217f179054d4b76a8d9352a223d",
      "1aa6614773570e7b021cd509849c4180",
      "f04c2825cfb6408c7778658f71fa176e",
      "e1694ea1f026dac7fe7e86a84482cf86",
      "1aa6614773570e7b021cd509849c4180",
      "f04c2825cfb6408c7778658f71fa176e",
      "e1694ea1f026dac7fe7e86a84482cf86",
      "9c4855d44c013fbddb373b2e9e311080",
      "e1694ea1f026dac7fe7e86a84482cf86",
      "9c4855d44c013fbddb373b2e9e311080",
      "f510e743c3efe3b83374a98ef8a30838",
      "b6e0bd03c521c5f00e90530daa7d4432",
      "b6e0bd03c521c5f00e90530daa7d4432",
      "3270d7f621d488aec5b76bcf121debd0",

      // mask_is_inverse = true.
      "9aa00fcfe21b71e30c5393699122a020",
      "4d8ce33262cf6b5375f363530815189a",
      "428625c51ac1bd4585988f7b36dff1db",
      "" /*unused*/,
      "4d8ce33262cf6b5375f363530815189a",
      "428625c51ac1bd4585988f7b36dff1db",
      "1ef63c06a2d9c42da293fdf924032981",
      "5dd3f201d755d1c22c126a633bfbb3c0",
      "428625c51ac1bd4585988f7b36dff1db",
      "1ef63c06a2d9c42da293fdf924032981",
      "5dd3f201d755d1c22c126a633bfbb3c0",
      "fe1e6843e6f214939da516dcbea04a79",
      "5dd3f201d755d1c22c126a633bfbb3c0",
      "fe1e6843e6f214939da516dcbea04a79",
      "240187f27389b5e89f9ec6bdbd7d20a7",
      "44925dab01011a98b8ab1f0308fa852a",
      "44925dab01011a98b8ab1f0308fa852a",
      "6d984b2ccfa056278e2130771127a943",
  };
  constexpr int kNumDigests =
      static_cast<int>(sizeof(kDigest) / sizeof(kDigest[0]));
  // Callers assert on a non-null result, so an invalid id becomes a test
  // failure rather than undefined behavior.
  if (id < 0 || id >= kNumDigests) return nullptr;
  return kDigest[id];
}
143 #endif // LIBGAV1_MAX_BITDEPTH >= 10
144
145 #if LIBGAV1_MAX_BITDEPTH == 12
// Returns the expected 12bpp mask digest stored in slot |id|, or nullptr when
// |id| is outside the table (instead of reading out of bounds).
//
// The table holds 18 slots for mask_is_inverse = false followed by 18 slots
// for mask_is_inverse = true. Slot 3 of each half is an empty placeholder.
// NOTE(review): the placeholder was labelled kBlock4x16, but its position
// between the 8x32 and 16x8 entries suggests kBlock16x4 - confirm against the
// BlockSize enum.
const char* GetDigest12bpp(int id) {
  static const char* const kDigest[] = {
      // mask_is_inverse = false.
      "57629d3872fd52ff4bbec439c5517ec5",
      "dba421ceeb534756c77167e00ae91a2c",
      "72e8ac1d450ef0c6c6b03e93856d5cc2",
      "" /*unused*/,
      "dba421ceeb534756c77167e00ae91a2c",
      "72e8ac1d450ef0c6c6b03e93856d5cc2",
      "ae573eb368df04e6a0133b4e15471728",
      "ceede597b2729357b15e0d08bb9bb760",
      "72e8ac1d450ef0c6c6b03e93856d5cc2",
      "ae573eb368df04e6a0133b4e15471728",
      "ceede597b2729357b15e0d08bb9bb760",
      "c4976af803d7ad3f92ef26f25b9f3754",
      "ceede597b2729357b15e0d08bb9bb760",
      "c4976af803d7ad3f92ef26f25b9f3754",
      "1d957d49f71bb7f304705a11a597f0cb",
      "9522d5713fb951b79f42d78fbff914cf",
      "9522d5713fb951b79f42d78fbff914cf",
      "422c046013f79a9f46e2c855967570ba",

      // mask_is_inverse = true.
      "a585cca9bc459d10e081bc0eb847b6e3",
      "2fa4ec5f74fad2831d216c51c2cdad5a",
      "d6c9ac69a9eb3059f5bb6e42b486ebcd",
      "" /*unused*/,
      "2fa4ec5f74fad2831d216c51c2cdad5a",
      "d6c9ac69a9eb3059f5bb6e42b486ebcd",
      "2ddd8c8a1841501964011030e2557e20",
      "97ef2575023dda008711015cf08d7590",
      "d6c9ac69a9eb3059f5bb6e42b486ebcd",
      "2ddd8c8a1841501964011030e2557e20",
      "97ef2575023dda008711015cf08d7590",
      "d69aff1e0d43395ce305c9be0dfb4c89",
      "97ef2575023dda008711015cf08d7590",
      "d69aff1e0d43395ce305c9be0dfb4c89",
      "48786f640191dcbee5b3321672778519",
      "6ad4718230353440b01f2bb78348157e",
      "6ad4718230353440b01f2bb78348157e",
      "ad49bd7af0ea17c84f434c7dfd0a911d",
  };
  constexpr int kNumDigests =
      static_cast<int>(sizeof(kDigest) / sizeof(kDigest[0]));
  // Callers assert on a non-null result, so an invalid id becomes a test
  // failure rather than undefined behavior.
  if (id < 0 || id >= kNumDigests) return nullptr;
  return kDigest[id];
}
189 #endif // LIBGAV1_MAX_BITDEPTH == 12
190
// One parameterized test case: the block dimensions and which of the two
// mask polarities to exercise.
struct WeightMaskTestParam {
  int width;
  int height;
  bool mask_is_inverse;

  WeightMaskTestParam(int width, int height, bool mask_is_inverse)
      : width{width}, height{height}, mask_is_inverse{mask_is_inverse} {}
};
198
operator <<(std::ostream & os,const WeightMaskTestParam & param)199 std::ostream& operator<<(std::ostream& os, const WeightMaskTestParam& param) {
200 return os << param.width << "x" << param.height
201 << ", mask_is_inverse: " << param.mask_is_inverse;
202 }
203
204 template <int bitdepth>
205 class WeightMaskTest : public testing::TestWithParam<WeightMaskTestParam>,
206 public test_utils::MaxAlignedAllocable {
207 public:
208 static_assert(bitdepth >= kBitdepth8 && bitdepth <= LIBGAV1_MAX_BITDEPTH, "");
209 WeightMaskTest() = default;
210 ~WeightMaskTest() override = default;
211
SetUp()212 void SetUp() override {
213 test_utils::ResetDspTable(bitdepth);
214 WeightMaskInit_C();
215 const dsp::Dsp* const dsp = dsp::GetDspTable(bitdepth);
216 ASSERT_NE(dsp, nullptr);
217 const int width_index = FloorLog2(width_) - 3;
218 const int height_index = FloorLog2(height_) - 3;
219 const testing::TestInfo* const test_info =
220 testing::UnitTest::GetInstance()->current_test_info();
221 const char* const test_case = test_info->test_suite_name();
222 if (absl::StartsWith(test_case, "C/")) {
223 } else if (absl::StartsWith(test_case, "NEON/")) {
224 WeightMaskInit_NEON();
225 } else if (absl::StartsWith(test_case, "SSE41/")) {
226 if ((GetCpuInfo() & kSSE4_1) == 0) GTEST_SKIP() << "No SSE4.1 support!";
227 WeightMaskInit_SSE4_1();
228 }
229 func_ = dsp->weight_mask[width_index][height_index][mask_is_inverse_];
230 }
231
232 protected:
233 void SetInputData(bool use_fixed_values, int value_1, int value_2);
234 void Test(int num_runs, bool use_fixed_values, int value_1, int value_2);
235
236 private:
237 const int width_ = GetParam().width;
238 const int height_ = GetParam().height;
239 const bool mask_is_inverse_ = GetParam().mask_is_inverse;
240 using PredType =
241 typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type;
242 alignas(
243 kMaxAlignment) PredType block_1_[kMaxPredictionSize * kMaxPredictionSize];
244 alignas(
245 kMaxAlignment) PredType block_2_[kMaxPredictionSize * kMaxPredictionSize];
246 uint8_t mask_[kMaxPredictionSize * kMaxPredictionSize] = {};
247 dsp::WeightMaskFunc func_;
248 };
249
250 template <int bitdepth>
SetInputData(const bool use_fixed_values,const int value_1,const int value_2)251 void WeightMaskTest<bitdepth>::SetInputData(const bool use_fixed_values,
252 const int value_1,
253 const int value_2) {
254 if (use_fixed_values) {
255 std::fill(block_1_, block_1_ + kMaxPredictionSize * kMaxPredictionSize,
256 value_1);
257 std::fill(block_2_, block_2_ + kMaxPredictionSize * kMaxPredictionSize,
258 value_2);
259 } else {
260 constexpr int bitdepth_index = (bitdepth - 8) >> 1;
261 libvpx_test::ACMRandom rnd(libvpx_test::ACMRandom::DeterministicSeed());
262 for (int y = 0; y < height_; ++y) {
263 for (int x = 0; x < width_; ++x) {
264 const int min_val = kCompoundPredictionRange[bitdepth_index][0];
265 const int max_val = kCompoundPredictionRange[bitdepth_index][1];
266 block_1_[y * width_ + x] =
267 static_cast<PredType>(rnd(max_val - min_val) + min_val);
268 block_2_[y * width_ + x] =
269 static_cast<PredType>(rnd(max_val - min_val) + min_val);
270 }
271 }
272 }
273 }
274
DimensionsToBlockSize(int width,int height)275 BlockSize DimensionsToBlockSize(int width, int height) {
276 if (width == 4) {
277 if (height == 4) return kBlock4x4;
278 if (height == 8) return kBlock4x8;
279 if (height == 16) return kBlock4x16;
280 return kBlockInvalid;
281 }
282 if (width == 8) {
283 if (height == 4) return kBlock8x4;
284 if (height == 8) return kBlock8x8;
285 if (height == 16) return kBlock8x16;
286 if (height == 32) return kBlock8x32;
287 return kBlockInvalid;
288 }
289 if (width == 16) {
290 if (height == 4) return kBlock16x4;
291 if (height == 8) return kBlock16x8;
292 if (height == 16) return kBlock16x16;
293 if (height == 32) return kBlock16x32;
294 if (height == 64) return kBlock16x64;
295 return kBlockInvalid;
296 }
297 if (width == 32) {
298 if (height == 8) return kBlock32x8;
299 if (height == 16) return kBlock32x16;
300 if (height == 32) return kBlock32x32;
301 if (height == 64) return kBlock32x64;
302 return kBlockInvalid;
303 }
304 if (width == 64) {
305 if (height == 16) return kBlock64x16;
306 if (height == 32) return kBlock64x32;
307 if (height == 64) return kBlock64x64;
308 if (height == 128) return kBlock64x128;
309 return kBlockInvalid;
310 }
311 if (width == 128) {
312 if (height == 64) return kBlock128x64;
313 if (height == 128) return kBlock128x128;
314 return kBlockInvalid;
315 }
316 return kBlockInvalid;
317 }
318
319 template <int bitdepth>
Test(const int num_runs,const bool use_fixed_values,const int value_1,const int value_2)320 void WeightMaskTest<bitdepth>::Test(const int num_runs,
321 const bool use_fixed_values,
322 const int value_1, const int value_2) {
323 if (func_ == nullptr) return;
324 SetInputData(use_fixed_values, value_1, value_2);
325 const absl::Time start = absl::Now();
326 for (int i = 0; i < num_runs; ++i) {
327 func_(block_1_, block_2_, mask_, width_);
328 }
329 const absl::Duration elapsed_time = absl::Now() - start;
330 if (use_fixed_values) {
331 int fixed_value = (value_1 - value_2 == 0) ? 38 : 64;
332 if (mask_is_inverse_) fixed_value = 64 - fixed_value;
333 for (int y = 0; y < height_; ++y) {
334 for (int x = 0; x < width_; ++x) {
335 ASSERT_EQ(static_cast<int>(mask_[y * width_ + x]), fixed_value)
336 << "x: " << x << " y: " << y;
337 }
338 }
339 } else {
340 const int id_offset = mask_is_inverse_ ? kMaxBlockSizes - 4 : 0;
341 const int id = id_offset +
342 static_cast<int>(DimensionsToBlockSize(width_, height_)) - 4;
343 const char* expected_digest = nullptr;
344 switch (bitdepth) {
345 case 8:
346 expected_digest = GetDigest8bpp(id);
347 break;
348 #if LIBGAV1_MAX_BITDEPTH >= 10
349 case 10:
350 expected_digest = GetDigest10bpp(id);
351 break;
352 #endif
353 #if LIBGAV1_MAX_BITDEPTH == 12
354 case 12:
355 expected_digest = GetDigest12bpp(id);
356 break;
357 #endif
358 }
359 ASSERT_NE(expected_digest, nullptr);
360 test_utils::CheckMd5Digest(
361 absl::StrFormat("BlockSize %dx%d", width_, height_).c_str(),
362 "WeightMask", expected_digest, mask_, sizeof(mask_), elapsed_time);
363 }
364 }
365
366 const WeightMaskTestParam weight_mask_test_param[] = {
367 WeightMaskTestParam(8, 8, false), WeightMaskTestParam(8, 16, false),
368 WeightMaskTestParam(8, 32, false), WeightMaskTestParam(16, 8, false),
369 WeightMaskTestParam(16, 16, false), WeightMaskTestParam(16, 32, false),
370 WeightMaskTestParam(16, 64, false), WeightMaskTestParam(32, 8, false),
371 WeightMaskTestParam(32, 16, false), WeightMaskTestParam(32, 32, false),
372 WeightMaskTestParam(32, 64, false), WeightMaskTestParam(64, 16, false),
373 WeightMaskTestParam(64, 32, false), WeightMaskTestParam(64, 64, false),
374 WeightMaskTestParam(64, 128, false), WeightMaskTestParam(128, 64, false),
375 WeightMaskTestParam(128, 128, false), WeightMaskTestParam(8, 8, true),
376 WeightMaskTestParam(8, 16, true), WeightMaskTestParam(8, 32, true),
377 WeightMaskTestParam(16, 8, true), WeightMaskTestParam(16, 16, true),
378 WeightMaskTestParam(16, 32, true), WeightMaskTestParam(16, 64, true),
379 WeightMaskTestParam(32, 8, true), WeightMaskTestParam(32, 16, true),
380 WeightMaskTestParam(32, 32, true), WeightMaskTestParam(32, 64, true),
381 WeightMaskTestParam(64, 16, true), WeightMaskTestParam(64, 32, true),
382 WeightMaskTestParam(64, 64, true), WeightMaskTestParam(64, 128, true),
383 WeightMaskTestParam(128, 64, true), WeightMaskTestParam(128, 128, true),
384 };
385
386 using WeightMaskTest8bpp = WeightMaskTest<8>;
387
TEST_P(WeightMaskTest8bpp,FixedValues)388 TEST_P(WeightMaskTest8bpp, FixedValues) {
389 const int min = kCompoundPredictionRange[0][0];
390 const int max = kCompoundPredictionRange[0][1];
391 Test(1, true, min, min);
392 Test(1, true, min, max);
393 Test(1, true, max, min);
394 Test(1, true, max, max);
395 }
396
// Single run on pseudo-random input, verified against the stored digest.
TEST_P(WeightMaskTest8bpp, RandomValues) {
  // The value arguments are unused when use_fixed_values is false.
  Test(/*num_runs=*/1, /*use_fixed_values=*/false, /*value_1=*/-1,
       /*value_2=*/-1);
}
398
// Timing run on pseudo-random input; disabled unless explicitly requested.
TEST_P(WeightMaskTest8bpp, DISABLED_Speed) {
  // The value arguments are unused when use_fixed_values is false.
  Test(/*num_runs=*/kNumSpeedTests, /*use_fixed_values=*/false,
       /*value_1=*/-1, /*value_2=*/-1);
}
402
403 INSTANTIATE_TEST_SUITE_P(C, WeightMaskTest8bpp,
404 testing::ValuesIn(weight_mask_test_param));
405 #if LIBGAV1_ENABLE_NEON
406 INSTANTIATE_TEST_SUITE_P(NEON, WeightMaskTest8bpp,
407 testing::ValuesIn(weight_mask_test_param));
408 #endif
409 #if LIBGAV1_ENABLE_SSE4_1
410 INSTANTIATE_TEST_SUITE_P(SSE41, WeightMaskTest8bpp,
411 testing::ValuesIn(weight_mask_test_param));
412 #endif
413
414 #if LIBGAV1_MAX_BITDEPTH >= 10
415 using WeightMaskTest10bpp = WeightMaskTest<10>;
416
TEST_P(WeightMaskTest10bpp,FixedValues)417 TEST_P(WeightMaskTest10bpp, FixedValues) {
418 const int min = kCompoundPredictionRange[1][0];
419 const int max = kCompoundPredictionRange[1][1];
420 Test(1, true, min, min);
421 Test(1, true, min, max);
422 Test(1, true, max, min);
423 Test(1, true, max, max);
424 }
425
// Single run on pseudo-random input, verified against the stored digest.
TEST_P(WeightMaskTest10bpp, RandomValues) {
  // The value arguments are unused when use_fixed_values is false.
  Test(/*num_runs=*/1, /*use_fixed_values=*/false, /*value_1=*/-1,
       /*value_2=*/-1);
}
427
// Timing run on pseudo-random input; disabled unless explicitly requested.
TEST_P(WeightMaskTest10bpp, DISABLED_Speed) {
  // The value arguments are unused when use_fixed_values is false.
  Test(/*num_runs=*/kNumSpeedTests, /*use_fixed_values=*/false,
       /*value_1=*/-1, /*value_2=*/-1);
}
431
432 INSTANTIATE_TEST_SUITE_P(C, WeightMaskTest10bpp,
433 testing::ValuesIn(weight_mask_test_param));
434 #if LIBGAV1_ENABLE_NEON
435 INSTANTIATE_TEST_SUITE_P(NEON, WeightMaskTest10bpp,
436 testing::ValuesIn(weight_mask_test_param));
437 #endif
438 #if LIBGAV1_ENABLE_SSE4_1
439 INSTANTIATE_TEST_SUITE_P(SSE41, WeightMaskTest10bpp,
440 testing::ValuesIn(weight_mask_test_param));
441 #endif
442 #endif // LIBGAV1_MAX_BITDEPTH >= 10
443
444 #if LIBGAV1_MAX_BITDEPTH == 12
445 using WeightMaskTest12bpp = WeightMaskTest<12>;
446
TEST_P(WeightMaskTest12bpp,FixedValues)447 TEST_P(WeightMaskTest12bpp, FixedValues) {
448 const int min = kCompoundPredictionRange[2][0];
449 const int max = kCompoundPredictionRange[2][1];
450 Test(1, true, min, min);
451 Test(1, true, min, max);
452 Test(1, true, max, min);
453 Test(1, true, max, max);
454 }
455
// Single run on pseudo-random input, verified against the stored digest.
TEST_P(WeightMaskTest12bpp, RandomValues) {
  // The value arguments are unused when use_fixed_values is false.
  Test(/*num_runs=*/1, /*use_fixed_values=*/false, /*value_1=*/-1,
       /*value_2=*/-1);
}
457
// Timing run on pseudo-random input; disabled unless explicitly requested.
TEST_P(WeightMaskTest12bpp, DISABLED_Speed) {
  // The value arguments are unused when use_fixed_values is false.
  Test(/*num_runs=*/kNumSpeedTests, /*use_fixed_values=*/false,
       /*value_1=*/-1, /*value_2=*/-1);
}
461
462 INSTANTIATE_TEST_SUITE_P(C, WeightMaskTest12bpp,
463 testing::ValuesIn(weight_mask_test_param));
464 #endif // LIBGAV1_MAX_BITDEPTH == 12
465
466 } // namespace
467 } // namespace dsp
468 } // namespace libgav1
469