1 /*
2  * Copyright (c) 2019-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/core/Types.h"
25 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
26 #include "arm_compute/runtime/CL/CLTensor.h"
27 #include "arm_compute/runtime/CL/CLTensorAllocator.h"
28 #include "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel.h"
29 #include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h"
30 #include "tests/CL/CLAccessor.h"
31 #include "tests/CL/Helper.h"
32 #include "tests/PaddingCalculator.h"
33 #include "tests/datasets/ShapeDatasets.h"
34 #include "tests/framework/Asserts.h"
35 #include "tests/framework/Macros.h"
36 #include "tests/framework/datasets/Datasets.h"
37 #include "tests/validation/Validation.h"
38 #include "tests/validation/fixtures/GEMMLowpFixture.h"
39 
40 namespace arm_compute
41 {
42 namespace test
43 {
44 namespace validation
45 {
46 using namespace arm_compute::misc::shape_calculator;
47 
48 // Create function for CLGEMMReshapeRHSMatrixKernel
49 using CLGEMMReshapeRHSMatrix = CLSynthetizeOperator<opencl::kernels::ClGemmReshapeRhsMatrixKernel>;
50 
51 // Create function for CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel
52 using CLGEMMLowpMatrixMultiplyReshapedOnlyRHS = CLSynthetizeOperator<opencl::kernels::ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel>;
53 
54 // Fixture for CLGEMMLowpMatrixMultiplyReshapedOnlyRHS
55 using CLGEMMLowpMatrixMultiplyReshapedOnlyRHSFixture = GEMMLowpMatrixMultiplyReshapedOnlyRHSValidationFixture<CLTensor, CLAccessor, CLGEMMReshapeRHSMatrix, CLGEMMLowpMatrixMultiplyReshapedOnlyRHS>;
56 
57 // Fixture for CLGEMMLowpMatrixMultiplyReshapedOnlyRHS3D
58 using CLGEMMLowpMatrixMultiplyReshapedOnlyRHS3DFixture =
59     GEMMLowpMatrixMultiplyReshapedOnlyRHS3DValidationFixture<CLTensor, CLAccessor, CLGEMMReshapeRHSMatrix, CLGEMMLowpMatrixMultiplyReshapedOnlyRHS>;
60 
61 namespace
62 {
63 // *INDENT-OFF*
64 // clang-format off
65 
/** M values to test (single value; used by the non-3D test cases) */
const auto m_values = framework::dataset::make("M", 37);

/** M_W values to test (single value; used instead of M by the 3D test cases) */
const auto m_w_values = framework::dataset::make("M_W", 5);

/** M_H values to test (single value; used instead of M by the 3D test cases).
 *  NOTE(review): presumably the 3D fixture treats M as M_W x M_H (= 35 here) — confirm in GEMMLowpFixture.h */
const auto m_h_values = framework::dataset::make("M_H", 7);

/** N values to test (single value) */
const auto n_values = framework::dataset::make("N", 51);

/** K values to test (single value) */
const auto k_values = framework::dataset::make("K", 23);

/** Batch size values to test.
 *  NOTE(review): the two-bound make() presumably builds a half-open range [1, 3), i.e. batch sizes 1 and 2 —
 *  confirm against framework::dataset range semantics. */
const auto b_values = framework::dataset::make("batch_size", 1, 3);

/** M0 values to test - Precommit.
 *  _1 is paired with QASYMM8 (Configuration, RunSmall_1, RunSmall3D_1);
 *  _2 is paired with QASYMM8_SIGNED (RunSmall_2, RunSmall3D_2). */
const auto m0_values_precommit_1 = framework::dataset::make("M0", {4});
const auto m0_values_precommit_2 = framework::dataset::make("M0", {6});

/** N0 values to test - Precommit */
const auto n0_values_precommit = framework::dataset::make("N0", { 4 });

/** K0 values to test - Precommit (shared by the LHS and RHS block configuration) */
const auto k0_values_precommit = framework::dataset::make("K0", { 16 });

/** H0 values to test - Precommit (two-bound make(); presumably the range [1, 3) — confirm) */
const auto h0_values_precommit = framework::dataset::make("H0", 1, 3);

/** M0 values to test - Nightly (two-bound make(); presumably the range [2, 8) — confirm) */
const auto m0_values_nightly = framework::dataset::make("M0", 2, 8);

/** N0 values to test - Nightly */
const auto n0_values_nightly = framework::dataset::make("N0", { 2, 3, 4, 8 });

/** K0 values to test - Nightly */
const auto k0_values_nightly = framework::dataset::make("K0", { 2, 3, 4, 8, 16 });

/** H0 values to test - Nightly (two-bound make(); presumably the range [1, 4) — confirm) */
const auto h0_values_nightly = framework::dataset::make("H0", 1, 4);

/** Interleave values to test with RHS matrix (both interleaved and non-interleaved reshaping) */
const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, false });

/** Transpose values to test with RHS matrix.
 *  Only "true" is exercised; validate_configuration() likewise hard-codes rhs_info.transpose = true. */
const auto t_values_rhs = framework::dataset::make("transpose_rhs", { true });
114 
115 /** Configuration test */
validate_configuration(unsigned int m_value,unsigned int n_value,unsigned int k_value,unsigned int b_value,unsigned int m0_value,unsigned int n0_value,unsigned int k0_value,unsigned int h0_value,bool i_value_rhs)116 void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int h0_value, bool i_value_rhs)
117 {
118     const unsigned int M = m_value;
119     const unsigned int N = n_value;
120     const unsigned int K = k_value;
121 
122     GEMMLHSMatrixInfo lhs_info;
123     lhs_info.m0         = m0_value;
124     lhs_info.k0         = k0_value;
125 
126     GEMMRHSMatrixInfo rhs_info;
127     rhs_info.n0         = n0_value;
128     rhs_info.k0         = k0_value;
129     rhs_info.h0         = h0_value;
130     rhs_info.interleave = i_value_rhs;
131     rhs_info.transpose  = true;
132 
133     GEMMKernelInfo gemm_info;
134     gemm_info.m = M;
135     gemm_info.n = N;
136     gemm_info.k = K;
137     gemm_info.lhs_info = lhs_info;
138     gemm_info.rhs_info = rhs_info;
139 
140     const TensorShape lhs_shape(K, M, b_value);
141     const TensorShape rhs_shape(N, K, b_value);
142     const TensorShape rhs_shape_reshaped = compute_rhs_reshaped_shape(TensorInfo(rhs_shape, 1, DataType::QASYMM8),
143                                                                       rhs_info);
144 
145     const TensorShape dst_shape = compute_mm_shape(TensorInfo(lhs_shape, 1, DataType::QASYMM8),
146                                                    TensorInfo(rhs_shape_reshaped, 1, DataType::QASYMM8),
147                                                    gemm_info);
148 
149     // Create tensors
150     CLTensor lhs          = create_tensor<CLTensor>(lhs_shape, DataType::QASYMM8);
151     CLTensor rhs_reshaped = create_tensor<CLTensor>(rhs_shape_reshaped, DataType::QASYMM8);
152     CLTensor dst          = create_tensor<CLTensor>(dst_shape, DataType::S32);
153 
154     ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
155     ARM_COMPUTE_EXPECT(rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
156     ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
157 
158     // Create and configure function
159     CLGEMMLowpMatrixMultiplyReshapedOnlyRHS gemm;
160     gemm.configure(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info);
161 }
162 } // namespace
163 
164 TEST_SUITE(CL)
TEST_SUITE(GEMMLowpMatrixMultiplyReshapedOnlyRHS)165 TEST_SUITE(GEMMLowpMatrixMultiplyReshapedOnlyRHS)
/** Configuration test case (DatasetMode::ALL).
 *  Iterates the cartesian combination of the listed datasets (batch size fixed to 1, M0 taken from
 *  m0_values_precommit_1) and calls validate_configuration() for each, which only creates the tensors
 *  and configures the operator. NOTE(review): the test presumably fails only if configure() reports an
 *  error or one of its ARM_COMPUTE_EXPECT checks fails — confirm. */
DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_values,
                                                                   n_values),
                                                                   k_values),
                                                                   framework::dataset::make("batch_size", 1)),
                                                                   m0_values_precommit_1),
                                                                   n0_values_precommit),
                                                                   k0_values_precommit),
                                                                   h0_values_precommit),
                                                                   i_values_rhs),
m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, h0_value, i_value_rhs)
{
    validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, h0_value, i_value_rhs);
}
180 
/** Precommit run (DatasetMode::ALL) of the non-3D fixture with QASYMM8 inputs and M0 from
 *  m0_values_precommit_1, over all combinations of the precommit datasets.
 *  The CL result (_target, read through CLAccessor) is validated against the fixture's _reference. */
FIXTURE_DATA_TEST_CASE(RunSmall_1, CLGEMMLowpMatrixMultiplyReshapedOnlyRHSFixture, framework::DatasetMode::ALL,
                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_values,
                                                                   n_values),
                                                                   k_values),
                                                                   b_values),
                                                                   m0_values_precommit_1),
                                                                   n0_values_precommit),
                                                                   k0_values_precommit),
                                                                   h0_values_precommit),
                                                                   i_values_rhs),
                                                                   t_values_rhs),
                    framework::dataset::make("DataType", { DataType::QASYMM8 })))
{
    // Validate output
    validate(CLAccessor(_target), _reference);
}
198 
/** Precommit run (DatasetMode::ALL) of the non-3D fixture with QASYMM8_SIGNED inputs and M0 from
 *  m0_values_precommit_2, over all combinations of the precommit datasets.
 *  The CL result (_target, read through CLAccessor) is validated against the fixture's _reference. */
FIXTURE_DATA_TEST_CASE(RunSmall_2, CLGEMMLowpMatrixMultiplyReshapedOnlyRHSFixture, framework::DatasetMode::ALL,
                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_values,
                                                                   n_values),
                                                                   k_values),
                                                                   b_values),
                                                                   m0_values_precommit_2),
                                                                   n0_values_precommit),
                                                                   k0_values_precommit),
                                                                   h0_values_precommit),
                                                                   i_values_rhs),
                                                                   t_values_rhs),
                    framework::dataset::make("DataType", { DataType::QASYMM8_SIGNED })))
{
    // Validate output
    validate(CLAccessor(_target), _reference);
}
216 
/** Large sweep of the non-3D fixture with QASYMM8 inputs over the nightly M0/N0/K0/H0 datasets.
 *  Registered with DatasetMode::DISABLED. NOTE(review): presumably this means it is skipped unless
 *  disabled tests are explicitly requested — confirm against the test framework's mode handling.
 *  The CL result (_target, read through CLAccessor) is validated against the fixture's _reference. */
FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMLowpMatrixMultiplyReshapedOnlyRHSFixture, framework::DatasetMode::DISABLED,
                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_values,
                                                                   n_values),
                                                                   k_values),
                                                                   b_values),
                                                                   m0_values_nightly),
                                                                   n0_values_nightly),
                                                                   k0_values_nightly),
                                                                   h0_values_nightly),
                                                                   i_values_rhs),
                                                                   t_values_rhs),
                    framework::dataset::make("DataType", { DataType::QASYMM8 })))
{
    // Validate output
    validate(CLAccessor(_target), _reference);
}
234 
/** Precommit run (DatasetMode::ALL) of the 3D fixture: M is supplied as M_W and M_H datasets.
 *  QASYMM8 inputs, M0 from m0_values_precommit_1, over all combinations of the precommit datasets.
 *  The CL result (_target, read through CLAccessor) is validated against the fixture's _reference. */
FIXTURE_DATA_TEST_CASE(RunSmall3D_1, CLGEMMLowpMatrixMultiplyReshapedOnlyRHS3DFixture, framework::DatasetMode::ALL,
                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_w_values,
                                                                   m_h_values),
                                                                   n_values),
                                                                   k_values),
                                                                   b_values),
                                                                   m0_values_precommit_1),
                                                                   n0_values_precommit),
                                                                   k0_values_precommit),
                                                                   h0_values_precommit),
                                                                   i_values_rhs),
                                                                   t_values_rhs),
                    framework::dataset::make("DataType", { DataType::QASYMM8 })))
{
    // Validate output
    validate(CLAccessor(_target), _reference);
}
253 
/** Precommit run (DatasetMode::ALL) of the 3D fixture: M is supplied as M_W and M_H datasets.
 *  QASYMM8_SIGNED inputs, M0 from m0_values_precommit_2, over all combinations of the precommit datasets.
 *  The CL result (_target, read through CLAccessor) is validated against the fixture's _reference. */
FIXTURE_DATA_TEST_CASE(RunSmall3D_2, CLGEMMLowpMatrixMultiplyReshapedOnlyRHS3DFixture, framework::DatasetMode::ALL,
                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_w_values,
                                                                   m_h_values),
                                                                   n_values),
                                                                   k_values),
                                                                   b_values),
                                                                   m0_values_precommit_2),
                                                                   n0_values_precommit),
                                                                   k0_values_precommit),
                                                                   h0_values_precommit),
                                                                   i_values_rhs),
                                                                   t_values_rhs),
                    framework::dataset::make("DataType", { DataType::QASYMM8_SIGNED })))
{
    // Validate output
    validate(CLAccessor(_target), _reference);
}
272 
/** Large sweep of the 3D fixture (M supplied as M_W and M_H) with QASYMM8 inputs over the nightly
 *  M0/N0/K0/H0 datasets. Registered with DatasetMode::DISABLED. NOTE(review): presumably skipped unless
 *  disabled tests are explicitly requested — confirm against the test framework's mode handling.
 *  The CL result (_target, read through CLAccessor) is validated against the fixture's _reference. */
FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMLowpMatrixMultiplyReshapedOnlyRHS3DFixture, framework::DatasetMode::DISABLED,
                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_w_values,
                                                                   m_h_values),
                                                                   n_values),
                                                                   k_values),
                                                                   b_values),
                                                                   m0_values_nightly),
                                                                   n0_values_nightly),
                                                                   k0_values_nightly),
                                                                   h0_values_nightly),
                                                                   i_values_rhs),
                                                                   t_values_rhs),
                    framework::dataset::make("DataType", { DataType::QASYMM8 })))
{
    // Validate output
    validate(CLAccessor(_target), _reference);
}
291 TEST_SUITE_END() // GEMMLowpMatrixMultiplyReshapedOnlyRHS
292 TEST_SUITE_END() // CL
293 } // namespace validation
294 } // namespace test
295 } // namespace arm_compute