xref: /aosp_15_r20/external/libgav1/src/dsp/weight_mask_test.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2020 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/weight_mask.h"
16 
17 #include <algorithm>
18 #include <cstdint>
19 #include <ostream>
20 #include <string>
21 #include <type_traits>
22 
23 #include "absl/strings/match.h"
24 #include "absl/strings/str_format.h"
25 #include "absl/time/clock.h"
26 #include "absl/time/time.h"
27 #include "gtest/gtest.h"
28 #include "src/dsp/dsp.h"
29 #include "src/utils/common.h"
30 #include "src/utils/constants.h"
31 #include "src/utils/cpu.h"
32 #include "src/utils/memory.h"
33 #include "tests/third_party/libvpx/acm_random.h"
34 #include "tests/utils.h"
35 
36 namespace libgav1 {
37 namespace dsp {
38 namespace {
39 
// Number of weight_mask invocations timed by the DISABLED_Speed tests.
constexpr int kNumSpeedTests = 50000;
// Largest block dimension under test; input and mask buffers hold
// kMaxPredictionSize * kMaxPredictionSize samples.
constexpr int kMaxPredictionSize = 128;
// {min, max} compound prediction values per bitdepth, indexed by
// (bitdepth - 8) >> 1. weight_mask is only reached through
// kCompoundPredictionTypeDiffWeighted, and these bounds come from the convolve
// paths with the widest output (kCompoundOffset is included for 10bpp and
// 12bpp). See src/dsp/convolve.cc and src/dsp/warp.cc.
constexpr int kCompoundPredictionRange[3][2] = {
    {-5132, 9212},  // 8bpp
    {3988, 61532},  // 10bpp
    {3974, 61559},  // 12bpp
};
53 };
54 
// Returns the expected MD5 digest of the 8bpp mask buffer for digest index
// |id|. Test() computes id = BlockSize - 4, plus (kMaxBlockSizes - 4) when
// mask_is_inverse is set. Entry 0 is therefore kBlock8x8, and index 3 is
// kBlock16x4, which is not exercised and holds an empty placeholder.
const char* GetDigest8bpp(int id) {
  static const char* const kDigest[] = {
      "eaca5b6a96dcfe5e44f3926a071b48b3",  // kBlock8x8
      "1d82c75cfdf8e57925eb1d5301647538",  // kBlock8x16
      "25bd455d74fb891b97b133c528f8db60",  // kBlock8x32
      "" /*kBlock16x4*/,
      "1d82c75cfdf8e57925eb1d5301647538",  // kBlock16x8
      "25bd455d74fb891b97b133c528f8db60",  // kBlock16x16
      "62a08776db35a186406a11ab92dee71c",  // kBlock16x32
      "95131d1dc0e05fcf4bd234d5ce9eea11",  // kBlock16x64
      "25bd455d74fb891b97b133c528f8db60",  // kBlock32x8
      "62a08776db35a186406a11ab92dee71c",  // kBlock32x16
      "95131d1dc0e05fcf4bd234d5ce9eea11",  // kBlock32x32
      "0b3c75272e0fb0747b9850145d340c4c",  // kBlock32x64
      "95131d1dc0e05fcf4bd234d5ce9eea11",  // kBlock64x16
      "0b3c75272e0fb0747b9850145d340c4c",  // kBlock64x32
      "f26c43d4bc823a89c1ed47ab8708bc06",  // kBlock64x64
      "0d99bbf31ecddc1c2d5063a68c0e9375",  // kBlock64x128
      "0d99bbf31ecddc1c2d5063a68c0e9375",  // kBlock128x64
      "5fb8ec5f582f0ebfe519ed55860f67c4",  // kBlock128x128

      // mask_is_inverse = true, same block-size order as above.
      "96811f3b192828ff679e4c9ad8069d7d",
      "a04dc180c028d55af70240163445523a",
      "8513e3988233d0a7de316a0179bb6139",
      "" /*kBlock16x4*/,
      "a04dc180c028d55af70240163445523a",
      "8513e3988233d0a7de316a0179bb6139",
      "f7356d42fb44a6ccb41253ba35b8b3c7",
      "3d2d61ffc203ee64fe91c9d16168a19d",
      "8513e3988233d0a7de316a0179bb6139",
      "f7356d42fb44a6ccb41253ba35b8b3c7",
      "3d2d61ffc203ee64fe91c9d16168a19d",
      "87a2011ac69fb597ca4f71bb3c35ebb0",
      "3d2d61ffc203ee64fe91c9d16168a19d",
      "87a2011ac69fb597ca4f71bb3c35ebb0",
      "97100a3639d567046dc8a99fcb84cb2e",
      "9fabe05a6523da81a45150e19f75acff",
      "9fabe05a6523da81a45150e19f75acff",
      "7c0643e4d02421d06d7ca71822a94e1d",
  };
  // 18 regular entries followed by 18 inverse-mask entries; a mismatch would
  // silently misalign the inverse half.
  static_assert(sizeof(kDigest) / sizeof(kDigest[0]) == 36,
                "kDigest must hold 18 regular and 18 inverse entries");
  return kDigest[id];
}
98 
99 #if LIBGAV1_MAX_BITDEPTH >= 10
// Returns the expected MD5 digest of the 10bpp mask buffer for digest index
// |id|. Test() computes id = BlockSize - 4, plus (kMaxBlockSizes - 4) when
// mask_is_inverse is set. Entry 0 is therefore kBlock8x8, and index 3 is
// kBlock16x4, which is not exercised and holds an empty placeholder.
const char* GetDigest10bpp(int id) {
  static const char* const kDigest[] = {
      "5ae8d64b65a671301a457b8a73368ab5",  // kBlock8x8
      "61535217f179054d4b76a8d9352a223d",  // kBlock8x16
      "1aa6614773570e7b021cd509849c4180",  // kBlock8x32
      "" /*kBlock16x4*/,
      "61535217f179054d4b76a8d9352a223d",  // kBlock16x8
      "1aa6614773570e7b021cd509849c4180",  // kBlock16x16
      "f04c2825cfb6408c7778658f71fa176e",  // kBlock16x32
      "e1694ea1f026dac7fe7e86a84482cf86",  // kBlock16x64
      "1aa6614773570e7b021cd509849c4180",  // kBlock32x8
      "f04c2825cfb6408c7778658f71fa176e",  // kBlock32x16
      "e1694ea1f026dac7fe7e86a84482cf86",  // kBlock32x32
      "9c4855d44c013fbddb373b2e9e311080",  // kBlock32x64
      "e1694ea1f026dac7fe7e86a84482cf86",  // kBlock64x16
      "9c4855d44c013fbddb373b2e9e311080",  // kBlock64x32
      "f510e743c3efe3b83374a98ef8a30838",  // kBlock64x64
      "b6e0bd03c521c5f00e90530daa7d4432",  // kBlock64x128
      "b6e0bd03c521c5f00e90530daa7d4432",  // kBlock128x64
      "3270d7f621d488aec5b76bcf121debd0",  // kBlock128x128

      // mask_is_inverse = true, same block-size order as above.
      "9aa00fcfe21b71e30c5393699122a020",
      "4d8ce33262cf6b5375f363530815189a",
      "428625c51ac1bd4585988f7b36dff1db",
      "" /*kBlock16x4*/,
      "4d8ce33262cf6b5375f363530815189a",
      "428625c51ac1bd4585988f7b36dff1db",
      "1ef63c06a2d9c42da293fdf924032981",
      "5dd3f201d755d1c22c126a633bfbb3c0",
      "428625c51ac1bd4585988f7b36dff1db",
      "1ef63c06a2d9c42da293fdf924032981",
      "5dd3f201d755d1c22c126a633bfbb3c0",
      "fe1e6843e6f214939da516dcbea04a79",
      "5dd3f201d755d1c22c126a633bfbb3c0",
      "fe1e6843e6f214939da516dcbea04a79",
      "240187f27389b5e89f9ec6bdbd7d20a7",
      "44925dab01011a98b8ab1f0308fa852a",
      "44925dab01011a98b8ab1f0308fa852a",
      "6d984b2ccfa056278e2130771127a943",
  };
  // 18 regular entries followed by 18 inverse-mask entries; a mismatch would
  // silently misalign the inverse half.
  static_assert(sizeof(kDigest) / sizeof(kDigest[0]) == 36,
                "kDigest must hold 18 regular and 18 inverse entries");
  return kDigest[id];
}
143 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
144 
145 #if LIBGAV1_MAX_BITDEPTH == 12
// Returns the expected MD5 digest of the 12bpp mask buffer for digest index
// |id|. Test() computes id = BlockSize - 4, plus (kMaxBlockSizes - 4) when
// mask_is_inverse is set. Entry 0 is therefore kBlock8x8, and index 3 is
// kBlock16x4, which is not exercised and holds an empty placeholder.
const char* GetDigest12bpp(int id) {
  static const char* const kDigest[] = {
      "57629d3872fd52ff4bbec439c5517ec5",  // kBlock8x8
      "dba421ceeb534756c77167e00ae91a2c",  // kBlock8x16
      "72e8ac1d450ef0c6c6b03e93856d5cc2",  // kBlock8x32
      "" /*kBlock16x4*/,
      "dba421ceeb534756c77167e00ae91a2c",  // kBlock16x8
      "72e8ac1d450ef0c6c6b03e93856d5cc2",  // kBlock16x16
      "ae573eb368df04e6a0133b4e15471728",  // kBlock16x32
      "ceede597b2729357b15e0d08bb9bb760",  // kBlock16x64
      "72e8ac1d450ef0c6c6b03e93856d5cc2",  // kBlock32x8
      "ae573eb368df04e6a0133b4e15471728",  // kBlock32x16
      "ceede597b2729357b15e0d08bb9bb760",  // kBlock32x32
      "c4976af803d7ad3f92ef26f25b9f3754",  // kBlock32x64
      "ceede597b2729357b15e0d08bb9bb760",  // kBlock64x16
      "c4976af803d7ad3f92ef26f25b9f3754",  // kBlock64x32
      "1d957d49f71bb7f304705a11a597f0cb",  // kBlock64x64
      "9522d5713fb951b79f42d78fbff914cf",  // kBlock64x128
      "9522d5713fb951b79f42d78fbff914cf",  // kBlock128x64
      "422c046013f79a9f46e2c855967570ba",  // kBlock128x128

      // mask_is_inverse = true, same block-size order as above.
      "a585cca9bc459d10e081bc0eb847b6e3",
      "2fa4ec5f74fad2831d216c51c2cdad5a",
      "d6c9ac69a9eb3059f5bb6e42b486ebcd",
      "" /*kBlock16x4*/,
      "2fa4ec5f74fad2831d216c51c2cdad5a",
      "d6c9ac69a9eb3059f5bb6e42b486ebcd",
      "2ddd8c8a1841501964011030e2557e20",
      "97ef2575023dda008711015cf08d7590",
      "d6c9ac69a9eb3059f5bb6e42b486ebcd",
      "2ddd8c8a1841501964011030e2557e20",
      "97ef2575023dda008711015cf08d7590",
      "d69aff1e0d43395ce305c9be0dfb4c89",
      "97ef2575023dda008711015cf08d7590",
      "d69aff1e0d43395ce305c9be0dfb4c89",
      "48786f640191dcbee5b3321672778519",
      "6ad4718230353440b01f2bb78348157e",
      "6ad4718230353440b01f2bb78348157e",
      "ad49bd7af0ea17c84f434c7dfd0a911d",
  };
  // 18 regular entries followed by 18 inverse-mask entries; a mismatch would
  // silently misalign the inverse half.
  static_assert(sizeof(kDigest) / sizeof(kDigest[0]) == 36,
                "kDigest must hold 18 regular and 18 inverse entries");
  return kDigest[id];
}
189 #endif  // LIBGAV1_MAX_BITDEPTH == 12
190 
// A single weight-mask test case: the block dimensions, plus whether the
// inverse-mask variant of dsp->weight_mask is exercised.
struct WeightMaskTestParam {
  WeightMaskTestParam(int width, int height, bool mask_is_inverse)
      : width{width}, height{height}, mask_is_inverse{mask_is_inverse} {}

  int width;             // Number of mask columns (also the row stride).
  int height;            // Number of mask rows.
  bool mask_is_inverse;  // Selects the inverse weight_mask entry.
};
198 
operator <<(std::ostream & os,const WeightMaskTestParam & param)199 std::ostream& operator<<(std::ostream& os, const WeightMaskTestParam& param) {
200   return os << param.width << "x" << param.height
201             << ", mask_is_inverse: " << param.mask_is_inverse;
202 }
203 
204 template <int bitdepth>
205 class WeightMaskTest : public testing::TestWithParam<WeightMaskTestParam>,
206                        public test_utils::MaxAlignedAllocable {
207  public:
208   static_assert(bitdepth >= kBitdepth8 && bitdepth <= LIBGAV1_MAX_BITDEPTH, "");
209   WeightMaskTest() = default;
210   ~WeightMaskTest() override = default;
211 
SetUp()212   void SetUp() override {
213     test_utils::ResetDspTable(bitdepth);
214     WeightMaskInit_C();
215     const dsp::Dsp* const dsp = dsp::GetDspTable(bitdepth);
216     ASSERT_NE(dsp, nullptr);
217     const int width_index = FloorLog2(width_) - 3;
218     const int height_index = FloorLog2(height_) - 3;
219     const testing::TestInfo* const test_info =
220         testing::UnitTest::GetInstance()->current_test_info();
221     const char* const test_case = test_info->test_suite_name();
222     if (absl::StartsWith(test_case, "C/")) {
223     } else if (absl::StartsWith(test_case, "NEON/")) {
224       WeightMaskInit_NEON();
225     } else if (absl::StartsWith(test_case, "SSE41/")) {
226       if ((GetCpuInfo() & kSSE4_1) == 0) GTEST_SKIP() << "No SSE4.1 support!";
227       WeightMaskInit_SSE4_1();
228     }
229     func_ = dsp->weight_mask[width_index][height_index][mask_is_inverse_];
230   }
231 
232  protected:
233   void SetInputData(bool use_fixed_values, int value_1, int value_2);
234   void Test(int num_runs, bool use_fixed_values, int value_1, int value_2);
235 
236  private:
237   const int width_ = GetParam().width;
238   const int height_ = GetParam().height;
239   const bool mask_is_inverse_ = GetParam().mask_is_inverse;
240   using PredType =
241       typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type;
242   alignas(
243       kMaxAlignment) PredType block_1_[kMaxPredictionSize * kMaxPredictionSize];
244   alignas(
245       kMaxAlignment) PredType block_2_[kMaxPredictionSize * kMaxPredictionSize];
246   uint8_t mask_[kMaxPredictionSize * kMaxPredictionSize] = {};
247   dsp::WeightMaskFunc func_;
248 };
249 
250 template <int bitdepth>
SetInputData(const bool use_fixed_values,const int value_1,const int value_2)251 void WeightMaskTest<bitdepth>::SetInputData(const bool use_fixed_values,
252                                             const int value_1,
253                                             const int value_2) {
254   if (use_fixed_values) {
255     std::fill(block_1_, block_1_ + kMaxPredictionSize * kMaxPredictionSize,
256               value_1);
257     std::fill(block_2_, block_2_ + kMaxPredictionSize * kMaxPredictionSize,
258               value_2);
259   } else {
260     constexpr int bitdepth_index = (bitdepth - 8) >> 1;
261     libvpx_test::ACMRandom rnd(libvpx_test::ACMRandom::DeterministicSeed());
262     for (int y = 0; y < height_; ++y) {
263       for (int x = 0; x < width_; ++x) {
264         const int min_val = kCompoundPredictionRange[bitdepth_index][0];
265         const int max_val = kCompoundPredictionRange[bitdepth_index][1];
266         block_1_[y * width_ + x] =
267             static_cast<PredType>(rnd(max_val - min_val) + min_val);
268         block_2_[y * width_ + x] =
269             static_cast<PredType>(rnd(max_val - min_val) + min_val);
270       }
271     }
272   }
273 }
274 
DimensionsToBlockSize(int width,int height)275 BlockSize DimensionsToBlockSize(int width, int height) {
276   if (width == 4) {
277     if (height == 4) return kBlock4x4;
278     if (height == 8) return kBlock4x8;
279     if (height == 16) return kBlock4x16;
280     return kBlockInvalid;
281   }
282   if (width == 8) {
283     if (height == 4) return kBlock8x4;
284     if (height == 8) return kBlock8x8;
285     if (height == 16) return kBlock8x16;
286     if (height == 32) return kBlock8x32;
287     return kBlockInvalid;
288   }
289   if (width == 16) {
290     if (height == 4) return kBlock16x4;
291     if (height == 8) return kBlock16x8;
292     if (height == 16) return kBlock16x16;
293     if (height == 32) return kBlock16x32;
294     if (height == 64) return kBlock16x64;
295     return kBlockInvalid;
296   }
297   if (width == 32) {
298     if (height == 8) return kBlock32x8;
299     if (height == 16) return kBlock32x16;
300     if (height == 32) return kBlock32x32;
301     if (height == 64) return kBlock32x64;
302     return kBlockInvalid;
303   }
304   if (width == 64) {
305     if (height == 16) return kBlock64x16;
306     if (height == 32) return kBlock64x32;
307     if (height == 64) return kBlock64x64;
308     if (height == 128) return kBlock64x128;
309     return kBlockInvalid;
310   }
311   if (width == 128) {
312     if (height == 64) return kBlock128x64;
313     if (height == 128) return kBlock128x128;
314     return kBlockInvalid;
315   }
316   return kBlockInvalid;
317 }
318 
319 template <int bitdepth>
Test(const int num_runs,const bool use_fixed_values,const int value_1,const int value_2)320 void WeightMaskTest<bitdepth>::Test(const int num_runs,
321                                     const bool use_fixed_values,
322                                     const int value_1, const int value_2) {
323   if (func_ == nullptr) return;
324   SetInputData(use_fixed_values, value_1, value_2);
325   const absl::Time start = absl::Now();
326   for (int i = 0; i < num_runs; ++i) {
327     func_(block_1_, block_2_, mask_, width_);
328   }
329   const absl::Duration elapsed_time = absl::Now() - start;
330   if (use_fixed_values) {
331     int fixed_value = (value_1 - value_2 == 0) ? 38 : 64;
332     if (mask_is_inverse_) fixed_value = 64 - fixed_value;
333     for (int y = 0; y < height_; ++y) {
334       for (int x = 0; x < width_; ++x) {
335         ASSERT_EQ(static_cast<int>(mask_[y * width_ + x]), fixed_value)
336             << "x: " << x << " y: " << y;
337       }
338     }
339   } else {
340     const int id_offset = mask_is_inverse_ ? kMaxBlockSizes - 4 : 0;
341     const int id = id_offset +
342                    static_cast<int>(DimensionsToBlockSize(width_, height_)) - 4;
343     const char* expected_digest = nullptr;
344     switch (bitdepth) {
345       case 8:
346         expected_digest = GetDigest8bpp(id);
347         break;
348 #if LIBGAV1_MAX_BITDEPTH >= 10
349       case 10:
350         expected_digest = GetDigest10bpp(id);
351         break;
352 #endif
353 #if LIBGAV1_MAX_BITDEPTH == 12
354       case 12:
355         expected_digest = GetDigest12bpp(id);
356         break;
357 #endif
358     }
359     ASSERT_NE(expected_digest, nullptr);
360     test_utils::CheckMd5Digest(
361         absl::StrFormat("BlockSize %dx%d", width_, height_).c_str(),
362         "WeightMask", expected_digest, mask_, sizeof(mask_), elapsed_time);
363   }
364 }
365 
366 const WeightMaskTestParam weight_mask_test_param[] = {
367     WeightMaskTestParam(8, 8, false),     WeightMaskTestParam(8, 16, false),
368     WeightMaskTestParam(8, 32, false),    WeightMaskTestParam(16, 8, false),
369     WeightMaskTestParam(16, 16, false),   WeightMaskTestParam(16, 32, false),
370     WeightMaskTestParam(16, 64, false),   WeightMaskTestParam(32, 8, false),
371     WeightMaskTestParam(32, 16, false),   WeightMaskTestParam(32, 32, false),
372     WeightMaskTestParam(32, 64, false),   WeightMaskTestParam(64, 16, false),
373     WeightMaskTestParam(64, 32, false),   WeightMaskTestParam(64, 64, false),
374     WeightMaskTestParam(64, 128, false),  WeightMaskTestParam(128, 64, false),
375     WeightMaskTestParam(128, 128, false), WeightMaskTestParam(8, 8, true),
376     WeightMaskTestParam(8, 16, true),     WeightMaskTestParam(8, 32, true),
377     WeightMaskTestParam(16, 8, true),     WeightMaskTestParam(16, 16, true),
378     WeightMaskTestParam(16, 32, true),    WeightMaskTestParam(16, 64, true),
379     WeightMaskTestParam(32, 8, true),     WeightMaskTestParam(32, 16, true),
380     WeightMaskTestParam(32, 32, true),    WeightMaskTestParam(32, 64, true),
381     WeightMaskTestParam(64, 16, true),    WeightMaskTestParam(64, 32, true),
382     WeightMaskTestParam(64, 64, true),    WeightMaskTestParam(64, 128, true),
383     WeightMaskTestParam(128, 64, true),   WeightMaskTestParam(128, 128, true),
384 };
385 
386 using WeightMaskTest8bpp = WeightMaskTest<8>;
387 
TEST_P(WeightMaskTest8bpp,FixedValues)388 TEST_P(WeightMaskTest8bpp, FixedValues) {
389   const int min = kCompoundPredictionRange[0][0];
390   const int max = kCompoundPredictionRange[0][1];
391   Test(1, true, min, min);
392   Test(1, true, min, max);
393   Test(1, true, max, min);
394   Test(1, true, max, max);
395 }
396 
TEST_P(WeightMaskTest8bpp,RandomValues)397 TEST_P(WeightMaskTest8bpp, RandomValues) { Test(1, false, -1, -1); }
398 
TEST_P(WeightMaskTest8bpp,DISABLED_Speed)399 TEST_P(WeightMaskTest8bpp, DISABLED_Speed) {
400   Test(kNumSpeedTests, false, -1, -1);
401 }
402 
403 INSTANTIATE_TEST_SUITE_P(C, WeightMaskTest8bpp,
404                          testing::ValuesIn(weight_mask_test_param));
405 #if LIBGAV1_ENABLE_NEON
406 INSTANTIATE_TEST_SUITE_P(NEON, WeightMaskTest8bpp,
407                          testing::ValuesIn(weight_mask_test_param));
408 #endif
409 #if LIBGAV1_ENABLE_SSE4_1
410 INSTANTIATE_TEST_SUITE_P(SSE41, WeightMaskTest8bpp,
411                          testing::ValuesIn(weight_mask_test_param));
412 #endif
413 
414 #if LIBGAV1_MAX_BITDEPTH >= 10
415 using WeightMaskTest10bpp = WeightMaskTest<10>;
416 
TEST_P(WeightMaskTest10bpp,FixedValues)417 TEST_P(WeightMaskTest10bpp, FixedValues) {
418   const int min = kCompoundPredictionRange[1][0];
419   const int max = kCompoundPredictionRange[1][1];
420   Test(1, true, min, min);
421   Test(1, true, min, max);
422   Test(1, true, max, min);
423   Test(1, true, max, max);
424 }
425 
TEST_P(WeightMaskTest10bpp,RandomValues)426 TEST_P(WeightMaskTest10bpp, RandomValues) { Test(1, false, -1, -1); }
427 
TEST_P(WeightMaskTest10bpp,DISABLED_Speed)428 TEST_P(WeightMaskTest10bpp, DISABLED_Speed) {
429   Test(kNumSpeedTests, false, -1, -1);
430 }
431 
432 INSTANTIATE_TEST_SUITE_P(C, WeightMaskTest10bpp,
433                          testing::ValuesIn(weight_mask_test_param));
434 #if LIBGAV1_ENABLE_NEON
435 INSTANTIATE_TEST_SUITE_P(NEON, WeightMaskTest10bpp,
436                          testing::ValuesIn(weight_mask_test_param));
437 #endif
438 #if LIBGAV1_ENABLE_SSE4_1
439 INSTANTIATE_TEST_SUITE_P(SSE41, WeightMaskTest10bpp,
440                          testing::ValuesIn(weight_mask_test_param));
441 #endif
442 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
443 
444 #if LIBGAV1_MAX_BITDEPTH == 12
445 using WeightMaskTest12bpp = WeightMaskTest<12>;
446 
TEST_P(WeightMaskTest12bpp,FixedValues)447 TEST_P(WeightMaskTest12bpp, FixedValues) {
448   const int min = kCompoundPredictionRange[2][0];
449   const int max = kCompoundPredictionRange[2][1];
450   Test(1, true, min, min);
451   Test(1, true, min, max);
452   Test(1, true, max, min);
453   Test(1, true, max, max);
454 }
455 
TEST_P(WeightMaskTest12bpp,RandomValues)456 TEST_P(WeightMaskTest12bpp, RandomValues) { Test(1, false, -1, -1); }
457 
TEST_P(WeightMaskTest12bpp,DISABLED_Speed)458 TEST_P(WeightMaskTest12bpp, DISABLED_Speed) {
459   Test(kNumSpeedTests, false, -1, -1);
460 }
461 
462 INSTANTIATE_TEST_SUITE_P(C, WeightMaskTest12bpp,
463                          testing::ValuesIn(weight_mask_test_param));
464 #endif  // LIBGAV1_MAX_BITDEPTH == 12
465 
466 }  // namespace
467 }  // namespace dsp
468 }  // namespace libgav1
469