1 /*
2 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <memory>
#include <tuple>

#include "aom_dsp/aom_dsp_common.h"
#include "gtest/gtest.h"

#include "config/av1_rtcd.h"
#include "config/aom_dsp_rtcd.h"
#include "test/acm_random.h"
#include "test/register_state_check.h"
#include "test/transform_test_base.h"
#include "test/util.h"
#include "av1/common/entropy.h"
#include "aom/aom_codec.h"
#include "aom/aom_integer.h"
#include "aom_ports/mem.h"
30
31 using libaom_test::ACMRandom;
32
33 namespace {
34 typedef void (*FdctFunc)(const int16_t *in, tran_low_t *out, int stride);
35 typedef void (*IdctFunc)(const tran_low_t *in, uint8_t *out, int stride);
36
37 using libaom_test::FhtFunc;
38
39 typedef std::tuple<FdctFunc, IdctFunc, TX_TYPE, aom_bit_depth_t, int, FdctFunc>
40 Dct4x4Param;
41
fwht4x4_ref(const int16_t * in,tran_low_t * out,int stride,TxfmParam *)42 void fwht4x4_ref(const int16_t *in, tran_low_t *out, int stride,
43 TxfmParam * /*txfm_param*/) {
44 av1_fwht4x4_c(in, out, stride);
45 }
46
iwht4x4_10_c(const tran_low_t * in,uint8_t * out,int stride)47 void iwht4x4_10_c(const tran_low_t *in, uint8_t *out, int stride) {
48 av1_highbd_iwht4x4_16_add_c(in, out, stride, 10);
49 }
50
iwht4x4_12_c(const tran_low_t * in,uint8_t * out,int stride)51 void iwht4x4_12_c(const tran_low_t *in, uint8_t *out, int stride) {
52 av1_highbd_iwht4x4_16_add_c(in, out, stride, 12);
53 }
54
#if HAVE_SSE4_1

// Adapts av1_highbd_iwht4x4_16_add_sse4_1() to the three-argument IdctFunc
// shape by fixing its last argument to 10.
void iwht4x4_10_sse4_1(const tran_low_t *in, uint8_t *out, int stride) {
  const int bd = 10;
  av1_highbd_iwht4x4_16_add_sse4_1(in, out, stride, bd);
}

// Adapts av1_highbd_iwht4x4_16_add_sse4_1() to the three-argument IdctFunc
// shape by fixing its last argument to 12.
void iwht4x4_12_sse4_1(const tran_low_t *in, uint8_t *out, int stride) {
  const int bd = 12;
  av1_highbd_iwht4x4_16_add_sse4_1(in, out, stride, bd);
}

#endif  // HAVE_SSE4_1
66
67 class Trans4x4WHT : public libaom_test::TransformTestBase<tran_low_t>,
68 public ::testing::TestWithParam<Dct4x4Param> {
69 public:
70 ~Trans4x4WHT() override = default;
71
SetUp()72 void SetUp() override {
73 fwd_txfm_ = GET_PARAM(0);
74 inv_txfm_ = GET_PARAM(1);
75 pitch_ = 4;
76 height_ = 4;
77 fwd_txfm_ref = fwht4x4_ref;
78 bit_depth_ = GET_PARAM(3);
79 mask_ = (1 << bit_depth_) - 1;
80 num_coeffs_ = GET_PARAM(4);
81 fwd_txfm_c_ = GET_PARAM(5);
82 }
83
84 protected:
RunFwdTxfm(const int16_t * in,tran_low_t * out,int stride)85 void RunFwdTxfm(const int16_t *in, tran_low_t *out, int stride) override {
86 fwd_txfm_(in, out, stride);
87 }
RunInvTxfm(const tran_low_t * out,uint8_t * dst,int stride)88 void RunInvTxfm(const tran_low_t *out, uint8_t *dst, int stride) override {
89 inv_txfm_(out, dst, stride);
90 }
RunSpeedTest()91 void RunSpeedTest() {
92 if (!fwd_txfm_c_) {
93 GTEST_SKIP();
94 } else {
95 ACMRandom rnd(ACMRandom::DeterministicSeed());
96 const int count_test_block = 10;
97 const int numIter = 5000;
98
99 int c_sum_time = 0;
100 int simd_sum_time = 0;
101
102 int stride = 96;
103
104 int16_t *input_block = reinterpret_cast<int16_t *>(
105 aom_memalign(16, sizeof(int16_t) * stride * height_));
106 ASSERT_NE(input_block, nullptr);
107 tran_low_t *output_ref_block = reinterpret_cast<tran_low_t *>(
108 aom_memalign(16, sizeof(output_ref_block[0]) * num_coeffs_));
109 ASSERT_NE(output_ref_block, nullptr);
110 tran_low_t *output_block = reinterpret_cast<tran_low_t *>(
111 aom_memalign(16, sizeof(output_block[0]) * num_coeffs_));
112 ASSERT_NE(output_block, nullptr);
113
114 for (int i = 0; i < count_test_block; ++i) {
115 for (int j = 0; j < height_; ++j) {
116 for (int k = 0; k < pitch_; ++k) {
117 int in_idx = j * stride + k;
118 int out_idx = j * pitch_ + k;
119 input_block[in_idx] =
120 (rnd.Rand16() & mask_) - (rnd.Rand16() & mask_);
121 if (bit_depth_ == AOM_BITS_8) {
122 output_block[out_idx] = output_ref_block[out_idx] = rnd.Rand8();
123 } else {
124 output_block[out_idx] = output_ref_block[out_idx] =
125 rnd.Rand16() & mask_;
126 }
127 }
128 }
129
130 aom_usec_timer c_timer_;
131 aom_usec_timer_start(&c_timer_);
132 for (int iter = 0; iter < numIter; iter++) {
133 API_REGISTER_STATE_CHECK(
134 fwd_txfm_c_(input_block, output_ref_block, stride));
135 }
136 aom_usec_timer_mark(&c_timer_);
137
138 aom_usec_timer simd_timer_;
139 aom_usec_timer_start(&simd_timer_);
140
141 for (int iter = 0; iter < numIter; iter++) {
142 API_REGISTER_STATE_CHECK(
143 fwd_txfm_(input_block, output_block, stride));
144 }
145 aom_usec_timer_mark(&simd_timer_);
146
147 c_sum_time += static_cast<int>(aom_usec_timer_elapsed(&c_timer_));
148 simd_sum_time += static_cast<int>(aom_usec_timer_elapsed(&simd_timer_));
149
150 // The minimum quant value is 4.
151 for (int j = 0; j < height_; ++j) {
152 for (int k = 0; k < pitch_; ++k) {
153 int out_idx = j * pitch_ + k;
154 ASSERT_EQ(output_block[out_idx], output_ref_block[out_idx])
155 << "Error: not bit-exact result at index: " << out_idx
156 << " at test block: " << i;
157 }
158 }
159 }
160
161 printf(
162 "c_time = %d \t simd_time = %d \t Gain = %4.2f \n", c_sum_time,
163 simd_sum_time,
164 (static_cast<float>(c_sum_time) / static_cast<float>(simd_sum_time)));
165
166 aom_free(input_block);
167 aom_free(output_ref_block);
168 aom_free(output_block);
169 }
170 }
171
172 FdctFunc fwd_txfm_;
173 IdctFunc inv_txfm_;
174
175 FdctFunc fwd_txfm_c_; // C version of forward transform for speed test.
176 };
177
// Forward-transform accuracy check via RunAccuracyCheck(), which is not
// defined in this fixture (presumably inherited from TransformTestBase).
// NOTE(review): the two arguments look like error limits - confirm there.
TEST_P(Trans4x4WHT, AccuracyCheck) {
  RunAccuracyCheck(0, 0.00001);
}
179
// Coefficient check via RunCoeffCheck(), which is not defined in this fixture
// (presumably inherited from TransformTestBase).
TEST_P(Trans4x4WHT, CoeffCheck) {
  RunCoeffCheck();
}
181
// Memory check via RunMemCheck(), which is not defined in this fixture
// (presumably inherited from TransformTestBase).
TEST_P(Trans4x4WHT, MemCheck) {
  RunMemCheck();
}
183
// Inverse-transform accuracy check via RunInvAccuracyCheck(), which is not
// defined in this fixture (presumably inherited from TransformTestBase).
// NOTE(review): the argument 0 looks like an error limit - confirm there.
TEST_P(Trans4x4WHT, InvAccuracyCheck) {
  RunInvAccuracyCheck(0);
}
185
// Speed comparison against the C baseline. The DISABLED_ prefix keeps
// GoogleTest from running it unless disabled tests are explicitly requested.
TEST_P(Trans4x4WHT, DISABLED_Speed) {
  RunSpeedTest();
}
187
188 using std::make_tuple;
189
190 INSTANTIATE_TEST_SUITE_P(
191 C, Trans4x4WHT,
192 ::testing::Values(make_tuple(&av1_fwht4x4_c, &iwht4x4_10_c, DCT_DCT,
193 AOM_BITS_10, 16,
194 static_cast<FdctFunc>(nullptr)),
195 make_tuple(&av1_fwht4x4_c, &iwht4x4_12_c, DCT_DCT,
196 AOM_BITS_12, 16,
197 static_cast<FdctFunc>(nullptr))));
198
#if HAVE_SSE4_1

// Parameter sets for the SSE4.1 forward transform with the 10- and 12-bit
// SSE4.1 inverse wrappers. The last element is a null baseline pointer.
const Dct4x4Param kFwht4x4ParamsSse41[] = {
  Dct4x4Param(&av1_fwht4x4_sse4_1, &iwht4x4_10_sse4_1, DCT_DCT, AOM_BITS_10,
              16, static_cast<FdctFunc>(nullptr)),
  Dct4x4Param(&av1_fwht4x4_sse4_1, &iwht4x4_12_sse4_1, DCT_DCT, AOM_BITS_12,
              16, static_cast<FdctFunc>(nullptr)),
};

INSTANTIATE_TEST_SUITE_P(SSE4_1, Trans4x4WHT,
                         ::testing::ValuesIn(kFwht4x4ParamsSse41));

#endif  // HAVE_SSE4_1
211
#if HAVE_NEON

// Parameter sets for the NEON forward transform, paired with the C inverse
// wrappers. The last element supplies av1_fwht4x4_c as the baseline.
const Dct4x4Param kFwht4x4ParamsNeon[] = {
  Dct4x4Param(&av1_fwht4x4_neon, &iwht4x4_10_c, DCT_DCT, AOM_BITS_10, 16,
              &av1_fwht4x4_c),
  Dct4x4Param(&av1_fwht4x4_neon, &iwht4x4_12_c, DCT_DCT, AOM_BITS_12, 16,
              &av1_fwht4x4_c),
};

INSTANTIATE_TEST_SUITE_P(NEON, Trans4x4WHT,
                         ::testing::ValuesIn(kFwht4x4ParamsNeon));

#endif  // HAVE_NEON
222
223 } // namespace
224