1 /*
2 * Copyright (c) 2018-2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "arm_compute/core/KernelDescriptors.h"
25 #include "arm_compute/core/Types.h"
26 #include "arm_compute/core/experimental/PostOps.h"
27 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
28 #include "arm_compute/runtime/CL/CLTensor.h"
29 #include "arm_compute/runtime/CL/CLTensorAllocator.h"
30 #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h"
31 #include "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.h"
32 #include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h"
33 #include "tests/CL/CLAccessor.h"
34 #include "tests/CL/Helper.h"
35 #include "tests/PaddingCalculator.h"
36 #include "tests/datasets/ShapeDatasets.h"
37 #include "tests/framework/Asserts.h"
38 #include "tests/framework/Macros.h"
39 #include "tests/framework/datasets/Datasets.h"
40 #include "tests/validation/Validation.h"
41 #include "tests/validation/fixtures/GEMMFixture.h"
42
43 namespace arm_compute
44 {
45 namespace test
46 {
47 namespace validation
48 {
49 using namespace arm_compute::misc::shape_calculator;
50 using namespace arm_compute::opencl::kernels;
51
// Create function for ClGemmReshapeLhsMatrixKernel
using CLGEMMReshapeLHSMatrix = CLSynthetizeOperator<ClGemmReshapeLhsMatrixKernel>;

// Create function for ClGemmReshapeRhsMatrixKernel
using CLGEMMReshapeRHSMatrix = CLSynthetizeOperator<ClGemmReshapeRhsMatrixKernel>;

// Create function for ClGemmMatrixMultiplyReshapedKernel
using CLGEMMMatrixMultiplyReshaped = CLSynthetizeOperator<ClGemmMatrixMultiplyReshapedKernel>;

// Fixture for CLGEMMMatrixMultiplyReshaped
template <typename T>
using CLGEMMMatrixMultiplyReshapedFixture = GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;

// Fixture for CLGEMMMatrixMultiplyReshaped with post ops
template <typename T>
using CLGEMMMatrixMultiplyReshapedWithPostOpsFixture =
    GEMMMatrixMultiplyReshapedWithPostOpsValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;

// Fixture for CLGEMMMatrixMultiplyReshaped mixed precision
// (the trailing `true` template argument selects the mixed-precision path of the fixture)
template <typename T>
using CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture =
    GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;

// Fixture for CLGEMMMatrixMultiplyReshaped mixed precision with post ops
template <typename T>
using CLGEMMMatrixMultiplyReshapedMixedPrecisionWithPostOpsFixture =
    GEMMMatrixMultiplyReshapedWithPostOpsValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;

// Fixture for CLGEMMMatrixMultiplyReshaped3D (output re-interpreted as a 3D tensor)
template <typename T>
using CLGEMMMatrixMultiplyReshaped3DFixture = GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;

// Fixture for CLGEMMMatrixMultiplyReshaped3D mixed precision
template <typename T>
using CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture =
    GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
88
89 namespace
90 {
91 // *INDENT-OFF*
92 // clang-format off
// Tolerances used by validate(): relative tolerance plus an absolute floor for
// values close to zero.
RelativeTolerance<float> rel_tolerance_f32(0.001f);
constexpr float          abs_tolerance_f32(0.0001f);

RelativeTolerance<float> rel_tolerance_f16_mixed_precision(0.001f);
constexpr float          abs_tolerance_f16_mixed_precision(0.01f);

RelativeTolerance<float> rel_tolerance_f16(0.001f);
constexpr float          abs_tolerance_f16(0.01f);

/** M values to test */
const auto m_values = framework::dataset::make("M", 17);

/** M_W values to test */
const auto m_w_values = framework::dataset::make("M_W", 5);

/** M_H values to test */
const auto m_h_values = framework::dataset::make("M_H", 7);

/** N values to test */
const auto n_values = framework::dataset::make("N", 21);

/** K values to test */
const auto k_values = framework::dataset::make("K", 13);

/** Batch size values to test
 *  NOTE(review): the two-argument make(name, start, end) form expands to a range
 *  of values rather than a single value — confirm against the framework docs. */
const auto b_values = framework::dataset::make("batch_size", 2, 3);

/** Activation values to test */
const auto act_values = framework::dataset::make("Activation",
{
    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f),
    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::ELU),
});

/** Alpha values to test - Precommit */
const auto a_values_precommit = framework::dataset::make("alpha", {-0.75f} );

/** Beta values to test - Precommit */
const auto beta_values_precommit = framework::dataset::make("beta", {-0.35f} );

/** M0 values to test - Precommit */
const auto m0_values_precommit = framework::dataset::make("M0", { 4 });

/** N0 values to test - Precommit */
const auto n0_values_precommit = framework::dataset::make("N0", { 4 });

/** K0 values to test - Precommit */
const auto k0_values_precommit = framework::dataset::make("K0", { 4 });

/** V0 values to test - Precommit */
const auto v0_values_precommit = framework::dataset::make("V0", 1, 3);

/** H0 values to test - Precommit */
const auto h0_values_precommit = framework::dataset::make("H0", 1, 3);

/** Alpha values to test - Nightly */
const auto a_values_nightly = framework::dataset::make("alpha", {1.0f} );

/** Beta values to test - Nightly */
const auto beta_values_nightly = framework::dataset::make("beta", {1.0f} );

/** M0 values to test - Nightly */
const auto m0_values_nightly = framework::dataset::make("M0", { 8 });

/** N0 values to test - Nightly */
const auto n0_values_nightly = framework::dataset::make("N0", { 8 });

/** K0 values to test - Nightly */
const auto k0_values_nightly = framework::dataset::make("K0", { 4 });

/** N0 values to test with export to OpenCL image object - Nightly */
const auto n0_export_to_cl_image_values_nightly = framework::dataset::make("N0", { 4, 8, 16 });

/** K0 values to test with export to OpenCL image object - Nightly */
const auto k0_export_to_cl_image_values_nightly = framework::dataset::make("K0", { 4, 8, 16 });

/** V0 values to test - Nightly */
const auto v0_values_nightly = framework::dataset::make("V0", 1, 3);

/** H0 values to test - Nightly */
const auto h0_values_nightly = framework::dataset::make("H0", 1, 3);

/** Interleave values to test with LHS matrix */
const auto i_values_lhs = framework::dataset::make("interleave_lhs", { true, false });

/** Interleave values to test with RHS matrix */
const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, false });

/** Broadcast bias from vector to matrix */
const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } );

/** LHS transposed values */
const auto lhs_transpose_values = framework::dataset::make("lhs_transpose", { false, true } );
186
187 /** Post Ops */
188 using PostOpArgBroadcast = CLGEMMMatrixMultiplyReshapedWithPostOpsFixture<float>::PostOpArgBroadcast;
post_ops_1()189 experimental::PostOpList<PostOpArgBroadcast> post_ops_1()
190 {
191 experimental::PostOpList<PostOpArgBroadcast> post_ops{};
192 post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
193 post_ops.push_back_op<experimental::PostOpEltwiseAdd<PostOpArgBroadcast>>(
194 std::make_tuple(true, true, false), // If broadcast in dims 0, 1 and 2
195 0,
196 ConvertPolicy::SATURATE);
197 post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
198 return post_ops;
199 }
post_ops_2()200 experimental::PostOpList<PostOpArgBroadcast> post_ops_2()
201 {
202 experimental::PostOpList<PostOpArgBroadcast> post_ops{};
203 post_ops.push_back_op<experimental::PostOpEltwiseAdd<PostOpArgBroadcast>>(
204 std::make_tuple(false, true, true), // If broadcast in dims 0, 1 and 2
205 1,
206 ConvertPolicy::SATURATE);
207 post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
208 return post_ops;
209 }
post_ops_3()210 experimental::PostOpList<PostOpArgBroadcast> post_ops_3()
211 {
212 experimental::PostOpList<PostOpArgBroadcast> post_ops{};
213 post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
214 post_ops.push_back_op<experimental::PostOpEltwiseAdd<PostOpArgBroadcast>>(
215 std::make_tuple(false, false, true), // If broadcast in dims 0, 1 and 2
216 1,
217 ConvertPolicy::SATURATE);
218 return post_ops;
219 }
220 // To test that the output of the main op is the first parameter in prelu post op
post_ops_4()221 experimental::PostOpList<PostOpArgBroadcast> post_ops_4()
222 {
223 experimental::PostOpList<PostOpArgBroadcast> post_ops{};
224 post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
225 post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
226 std::make_tuple(false, false, true), // If true, broadcast in corresponding dim: 0, 1 or 2
227 0,
228 ConvertPolicy::SATURATE);
229 post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
230 return post_ops;
231 }
232 // To test that the output of the main op is the second parameter in prelu post op i.e. it is the alpha_param
post_ops_5()233 experimental::PostOpList<PostOpArgBroadcast> post_ops_5()
234 {
235 experimental::PostOpList<PostOpArgBroadcast> post_ops{};
236 post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
237 post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
238 std::make_tuple(false, false, false), // If true, broadcast in corresponding dim: 0, 1 or 2
239 1,
240 ConvertPolicy::SATURATE);
241 post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
242 return post_ops;
243 }
/** Different Post Op Lists — one dataset entry per post-op chain defined above */
const auto post_op_lists = framework::dataset::make("post_op_lists", {
    post_ops_1(),
    post_ops_2(),
    post_ops_3(),
    post_ops_4(),
    post_ops_5()
} );
252
/** Check whether a list of fused post ops is accepted by ClGemmMatrixMultiplyReshapedKernel
 *  for a GEMM of the given problem size and data type.
 *
 *  Builds LHS/RHS tensor infos reshaped with a fixed 4x4 block configuration and a
 *  matching GEMMKernelInfo carrying @p post_ops, then delegates to the kernel's
 *  static validate().
 *
 * @param m         Number of LHS rows
 * @param n         Number of RHS columns
 * @param k         Number of LHS columns / RHS rows
 * @param batch     Batch size
 * @param data_type Data type of all tensors
 * @param post_ops  Post-op list to validate
 *
 * @return true if the kernel accepts the configuration, false otherwise
 */
bool is_post_op_list_valid(unsigned int m, unsigned int n, unsigned int k, unsigned int batch, DataType data_type, const experimental::PostOpList<ITensorInfo*>& post_ops)
{
    const auto lhs_info = GEMMLHSMatrixInfo(4,4,1,false,true);
    const auto rhs_info = GEMMRHSMatrixInfo(4,4,1,true,true,false);

    // Create TensorInfo for post op arguments
    TensorInfo input0_info(TensorShape(k, m, batch), 1, data_type);
    TensorInfo input1_info(TensorShape(n, k, batch), 1, data_type);
    TensorInfo input2_info(TensorShape(n), 1, data_type);
    TensorInfo output_info(TensorShape(n, m, batch), 1, data_type);

    // The kernel consumes the *reshaped* LHS/RHS, so compute their shapes first
    const TensorInfo reshaped_input0_info = input0_info.clone()->set_tensor_shape(misc::shape_calculator::compute_lhs_reshaped_shape(input0_info, lhs_info));
    const TensorInfo reshaped_input1_info = input1_info.clone()->set_tensor_shape(misc::shape_calculator::compute_rhs_reshaped_shape(input1_info, rhs_info));

    GEMMKernelInfo gemm_info(m, n, k, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
             false /**< reinterpret the input as 3D */,
             true /**< Flag used to broadcast the bias addition */,
             false /**< wider accum */,
             false /**< has pad y */,
           ActivationLayerInfo::ActivationFunction::IDENTITY,
             1   /**< Multiplication factor for the width of the 1xW transposed block */,
             1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
             lhs_info,
             rhs_info,
             0  /**< Offset to be added to each element of the matrix A */,
             0  /**< Offset to be added to each element of the matrix B */,
             post_ops);
    return bool(ClGemmMatrixMultiplyReshapedKernel::validate(&reshaped_input0_info.clone()->set_is_resizable(true),
                                                             &reshaped_input1_info.clone()->set_is_resizable(true),
                                                             &input2_info.clone()->set_is_resizable(true),
                                                             &output_info.clone()->set_is_resizable(true),1.f,1.f,
                                                             lhs_info,
                                                             rhs_info,
                                                             gemm_info));
}
288
289 } // namespace
290
TEST_SUITE(CL)
TEST_SUITE(GEMMMatrixMultiplyReshaped)

// *INDENT-OFF*
// clang-format off
// Static validation: each row across the zipped datasets below describes one
// configuration (tensor infos, reshape infos, GEMM kernel info) and the
// expected outcome of ClGemmMatrixMultiplyReshapedKernel::validate().
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
               framework::dataset::make("Input0Info", { TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F32),      // OK
                                                        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),      // OK
                                                        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::QASYMM8),  // Data type not supported
                                                        TensorInfo(TensorShape(10U, 5U, 2U), 1, DataType::F32),      // Incorrect dimension bias
                                                        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F32),      // Mismatching shapes
                                                        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),      // OK, do not broadcast bias
                                                        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),      // OK, wider accumulation
                                                        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),      // OK, RHS 4,4,2

                                                      }),
               framework::dataset::make("Input1Info",{ TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F32),
                                                       TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F16),
                                                       TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::QASYMM8),
                                                       TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F32),
                                                       TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
                                                       TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F16),
                                                       TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F16),
                                                       TensorInfo(TensorShape(128U, 3U, 2U), 1, DataType::F16),

                                                     })),
               framework::dataset::make("Input2Info", { TensorInfo(TensorShape(21U), 1, DataType::F32),
                                                        TensorInfo(TensorShape(21U), 1, DataType::F16),
                                                        TensorInfo(TensorShape(21U), 1, DataType::QASYMM8),
                                                        TensorInfo(TensorShape(21U), 1, DataType::F32),
                                                        TensorInfo(TensorShape(21U), 1, DataType::F32),
                                                        TensorInfo(TensorShape(21U,17U), 1, DataType::F16),
                                                        TensorInfo(TensorShape(21U,17U), 1, DataType::F16),
                                                        TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F16),

                                                      })),
               framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F32),
                                                       TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F16),
                                                       TensorInfo(TensorShape(21U,17U,2U), 1, DataType::QASYMM8),
                                                       TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F32),
                                                       TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F32),
                                                       TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F16),
                                                       TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F16),
                                                       TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F16),

                                                     })),
               framework::dataset::make("LHSMInfo",{
                                                     GEMMLHSMatrixInfo(4,4,1,false,true),
                                                     GEMMLHSMatrixInfo(4,4,1,false,true),
                                                     GEMMLHSMatrixInfo(4,4,1,false,true),
                                                     GEMMLHSMatrixInfo(4,2,4,false,false),
                                                     GEMMLHSMatrixInfo(4,2,4,false,false),
                                                     GEMMLHSMatrixInfo(4,4,1,false,true),
                                                     GEMMLHSMatrixInfo(4,4,1,false,true),
                                                     GEMMLHSMatrixInfo(4,4,1,false,true),

                                                   })),
               framework::dataset::make("RHSMInfo",{
                                                     GEMMRHSMatrixInfo(4,4,1,true,true,false),
                                                     GEMMRHSMatrixInfo(4,4,1,true,true,false),
                                                     GEMMRHSMatrixInfo(4,4,1,true,true,false),
                                                     GEMMRHSMatrixInfo(2,2,1,true,false,false),
                                                     GEMMRHSMatrixInfo(2,2,1,true,false,false),
                                                     GEMMRHSMatrixInfo(4,4,1,true,true,false),
                                                     GEMMRHSMatrixInfo(4,4,1,true,true,false),
                                                     GEMMRHSMatrixInfo(4,4,2,true,false,false),


                                                   })),


               framework::dataset::make("GEMMInfo",{
                                                     GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
                                                                     21 /**<N Number of RHS columns*/,
                                                                     13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
                                                                     false /**< reinterpret the input as 3D */,
                                                                     true /**< Flag used to broadcast the bias addition */,
                                                                     false /**< wider accum */,
                                                                     false /**< has pad y */,
                                                                   ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                                                     1   /**< Multiplication factor for the width of the 1xW transposed block */,
                                                                     1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                                                     GEMMLHSMatrixInfo(4,4,1,false,true),
                                                                     GEMMRHSMatrixInfo(4,4,1,true,true,false),
                                                                     0  /**< Offset to be added to each element of the matrix A */,
                                                                     0 /**< Offset to be added to each element of the matrix B */),

                                                     GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
                                                                     21 /**<N Number of RHS columns*/,
                                                                     13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
                                                                     false /**< reinterpret the input as 3D */,
                                                                     true /**< Flag used to broadcast the bias addition */,
                                                                     false /**< wider accum */,
                                                                     false /**< has pad y */,
                                                                   ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                                                     1   /**< Multiplication factor for the width of the 1xW transposed block */,
                                                                     1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                                                     GEMMLHSMatrixInfo(4,4,1,false,true),
                                                                     GEMMRHSMatrixInfo(4,4,1,true,true,false),
                                                                     0  /**< Offset to be added to each element of the matrix A */,
                                                                     0 /**< Offset to be added to each element of the matrix B */),
                                                     // Default-constructed infos for the three expected-to-fail rows
                                                     GEMMKernelInfo(),
                                                     GEMMKernelInfo(),
                                                     GEMMKernelInfo(),

                                                     GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
                                                                     21 /**<N Number of RHS columns*/,
                                                                     13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
                                                                     false /**< reinterpret the input as 3D */,
                                                                     false /**< Flag used to broadcast the bias addition */,
                                                                     false /**< wider accum */,
                                                                     false /**< has pad y */,
                                                                   ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                                                     1   /**< Multiplication factor for the width of the 1xW transposed block */,
                                                                     1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                                                     GEMMLHSMatrixInfo(4,4,1,false,true),
                                                                     GEMMRHSMatrixInfo(4,4,1,true,true,false),
                                                                     0  /**< Offset to be added to each element of the matrix A */,
                                                                     0 /**< Offset to be added to each element of the matrix B */),


                                                     GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
                                                                     21 /**<N Number of RHS columns*/,
                                                                     13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
                                                                     false /**< reinterpret the input as 3D */,
                                                                     false /**< Flag used to broadcast the bias addition */,
                                                                     true /**< wider accum */,
                                                                     true /**< has pad y */,
                                                                   ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                                                     1   /**< Multiplication factor for the width of the 1xW transposed block */,
                                                                     1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                                                     GEMMLHSMatrixInfo(4,4,1,false,true),
                                                                     GEMMRHSMatrixInfo(4,4,1,true,true,false),
                                                                     0  /**< Offset to be added to each element of the matrix A */,
                                                                     0 /**< Offset to be added to each element of the matrix B */),

                                                     GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
                                                                     21 /**<N Number of RHS columns*/,
                                                                     13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
                                                                     false /**< reinterpret the input as 3D */,
                                                                     false /**< Flag used to broadcast the bias addition */,
                                                                     false /**< wider accum */,
                                                                     false /**< has pad y */,
                                                                   ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
                                                                     1   /**< Multiplication factor for the width of the 1xW transposed block */,
                                                                     1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
                                                                     GEMMLHSMatrixInfo(4,4,1,false,true),
                                                                     GEMMRHSMatrixInfo(4,4,2,true,false,false),
                                                                     0  /**< Offset to be added to each element of the matrix A */,
                                                                     0 /**< Offset to be added to each element of the matrix B */),
                                                   })),
               framework::dataset::make("Expected", { true, true, false, false, false, true, true,true})),
               input0_info ,input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, expected)
{
    ARM_COMPUTE_EXPECT(bool(ClGemmMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
                                                                         &input1_info.clone()->set_is_resizable(true),
                                                                         &input2_info.clone()->set_is_resizable(true),
                                                                         &output_info.clone()->set_is_resizable(true),1.f,1.f,
                                                                         lhs_info,
                                                                         rhs_info,
                                                                         gemm_info)) == expected, framework::LogLevel::ERRORS);
}
453 TEST_SUITE(ValidateFusedPostOpsConfigs)
TEST_SUITE(Invalid)454 TEST_SUITE(Invalid)
455 TEST_CASE(UnsupportedPostOpSequence, framework::DatasetMode::ALL)
456 {
457 const auto data_type = DataType::F32;
458 const unsigned int m = 17;
459 const unsigned int n = 1;
460 const unsigned int k = 13;
461 const unsigned int batch = 2;
462 TensorShape post_op_arg0_shape(n, m, batch);
463 TensorInfo post_op_arg_info(post_op_arg0_shape, 1, data_type);
464 auto post_op_arg1_info = post_op_arg_info.clone();
465
466 // Unsupported sequence of post ops
467 experimental::PostOpList<ITensorInfo*> post_ops{};
468 post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>(
469 &post_op_arg_info,
470 1,
471 ConvertPolicy::SATURATE);
472 post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>(
473 post_op_arg1_info.get(),
474 0,
475 ConvertPolicy::SATURATE);
476
477 ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
478 }
TEST_CASE(OutputWidened, framework::DatasetMode::ALL)
{
    // Invalid broadcast: post op tensors "widen" the output tensor
    const auto         dt     = DataType::F32;
    const unsigned int gemm_m = 17;
    const unsigned int gemm_n = 1;
    const unsigned int gemm_k = 13;
    const unsigned int gemm_b = 2;

    // output's X dimension (n) is "widened", which is not allowed
    TensorShape widened_shape(gemm_n + 4, gemm_m, gemm_b);
    TensorInfo  widened_info(widened_shape, 1, dt);

    experimental::PostOpList<ITensorInfo*> ops{};
    ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &widened_info, 0, ConvertPolicy::SATURATE);

    ARM_COMPUTE_EXPECT(is_post_op_list_valid(gemm_m, gemm_n, gemm_k, gemm_b, dt, ops) == false, framework::LogLevel::ERRORS);
}
TEST_CASE(BroadcastInXDimOnly, framework::DatasetMode::ALL)
{
    // Invalid broadcast: post op tensors broadcast in the first dimension (X) only
    const auto         dt     = DataType::F32;
    const unsigned int gemm_m = 22;
    const unsigned int gemm_n = 16;
    const unsigned int gemm_k = 15;
    const unsigned int gemm_b = 3;

    TensorShape bcast_shape(1, gemm_m, gemm_b);
    TensorInfo  bcast_info(bcast_shape, 1, dt);

    experimental::PostOpList<ITensorInfo*> ops{};
    ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &bcast_info, 0, ConvertPolicy::SATURATE);

    ARM_COMPUTE_EXPECT(is_post_op_list_valid(gemm_m, gemm_n, gemm_k, gemm_b, dt, ops) == false, framework::LogLevel::ERRORS);
}
509 TEST_SUITE_END() // Invalid
TEST_SUITE(Valid)510 TEST_SUITE(Valid)
511 TEST_CASE(EmptyPostOpList, framework::DatasetMode::ALL)
512 {
513 const auto data_type = DataType::F32;
514 const unsigned int m = 22;
515 const unsigned int n = 16;
516 const unsigned int k = 15;
517 const unsigned int batch = 3;
518 experimental::PostOpList<ITensorInfo*> post_ops{};
519
520 ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
521 }
TEST_CASE(BroadcastInYDimOnly, framework::DatasetMode::ALL)
{
    // Valid broadcast: post-op tensor broadcasts in the second dimension (Y) only
    const auto         dt     = DataType::F32;
    const unsigned int gemm_m = 22;
    const unsigned int gemm_n = 16;
    const unsigned int gemm_k = 15;
    const unsigned int gemm_b = 3;

    TensorShape bcast_shape(gemm_n, 1, gemm_b);
    TensorInfo  bcast_info(bcast_shape, 1, dt);

    experimental::PostOpList<ITensorInfo*> ops{};
    ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &bcast_info, 0, ConvertPolicy::SATURATE);

    ARM_COMPUTE_EXPECT(is_post_op_list_valid(gemm_m, gemm_n, gemm_k, gemm_b, dt, ops) == true, framework::LogLevel::ERRORS);
}
TEST_CASE(BroadcastInBothXandYDims, framework::DatasetMode::ALL)
{
    // Valid broadcast: post-op tensor broadcasts in both X and Y dimensions
    const auto         dt     = DataType::F32;
    const unsigned int gemm_m = 22;
    const unsigned int gemm_n = 16;
    const unsigned int gemm_k = 15;
    const unsigned int gemm_b = 3;

    TensorShape bcast_shape(1, 1, gemm_b);
    TensorInfo  bcast_info(bcast_shape, 1, dt);

    experimental::PostOpList<ITensorInfo*> ops{};
    ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &bcast_info, 0, ConvertPolicy::SATURATE);

    ARM_COMPUTE_EXPECT(is_post_op_list_valid(gemm_m, gemm_n, gemm_k, gemm_b, dt, ops) == true, framework::LogLevel::ERRORS);
}
TEST_CASE(BroadcastInAllDims, framework::DatasetMode::ALL)
{
    // Valid broadcast: post-op tensor broadcasts in every dimension
    const auto         dt     = DataType::F32;
    const unsigned int gemm_m = 22;
    const unsigned int gemm_n = 16;
    const unsigned int gemm_k = 15;
    const unsigned int gemm_b = 3;

    TensorShape bcast_shape(1, 1, 1);
    TensorInfo  bcast_info(bcast_shape, 1, dt);

    experimental::PostOpList<ITensorInfo*> ops{};
    ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &bcast_info, 0, ConvertPolicy::SATURATE);

    ARM_COMPUTE_EXPECT(is_post_op_list_valid(gemm_m, gemm_n, gemm_k, gemm_b, dt, ops) == true, framework::LogLevel::ERRORS);
}
564 TEST_SUITE_END() // Valid
TEST_SUITE_END()565 TEST_SUITE_END() // ValidateFusedPostOps
566 TEST_SUITE(Float)
567 TEST_SUITE(FP32)
568
// Precommit-sized FP32 run: small problem with the precommit block-size datasets.
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL,
                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_values,
                                                                   n_values),
                                                                   k_values),
                                                                   b_values),
                                                                   m0_values_precommit),
                                                                   n0_values_precommit),
                                                                   k0_values_precommit),
                                                                   v0_values_precommit),
                                                                   h0_values_precommit),
                                                                   i_values_lhs),
                                                                   i_values_rhs),
                                                                   framework::dataset::make("export_to_cl_image_rhs", false)),
                                                                   framework::dataset::make("DataType", DataType::F32)),
                                                                   a_values_precommit),
                                                                   beta_values_precommit),
                                                                   broadcast_bias_values),
                                                                   lhs_transpose_values),
                                                                   act_values))
{
    // Validate output; the fixture sets validate_result to false when the
    // configuration could not run on the device
    if(validate_result)
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}
601
// Nightly-sized FP32 run; DISABLED so it does not run in the default suites.
FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::DISABLED,
                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_values,
                                                                   n_values),
                                                                   k_values),
                                                                   b_values),
                                                                   m0_values_nightly),
                                                                   n0_values_nightly),
                                                                   k0_values_nightly),
                                                                   v0_values_nightly),
                                                                   h0_values_nightly),
                                                                   i_values_lhs),
                                                                   i_values_rhs),
                                                                   framework::dataset::make("export_to_cl_image_rhs", false)),
                                                                   framework::dataset::make("DataType", DataType::F32)),
                                                                   a_values_nightly),
                                                                   beta_values_nightly),
                                                                   broadcast_bias_values),
                                                                   lhs_transpose_values),
                                                                   act_values))
{
    // Validate output; skipped if the device could not run the configuration
    if(validate_result)
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}
634
// Precommit-sized FP32 run with the output re-interpreted as 3D (M = M_W * M_H).
FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::ALL,
                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_w_values,
                                                                   m_h_values),
                                                                   n_values),
                                                                   k_values),
                                                                   b_values),
                                                                   m0_values_precommit),
                                                                   n0_values_precommit),
                                                                   k0_values_precommit),
                                                                   v0_values_precommit),
                                                                   h0_values_precommit),
                                                                   i_values_lhs),
                                                                   i_values_rhs),
                                                                   framework::dataset::make("export_to_cl_image_rhs", false)),
                                                                   framework::dataset::make("DataType", DataType::F32)),
                                                                   a_values_precommit),
                                                                   beta_values_precommit),
                                                                   lhs_transpose_values),
                                                                   act_values))
{
    // Validate output; skipped if the device could not run the configuration
    if(validate_result)
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}
667
// Nightly-sized 3D-output FP32 run; DISABLED so it does not run in the default suites.
FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::DISABLED,
                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_w_values,
                                                                   m_h_values),
                                                                   n_values),
                                                                   k_values),
                                                                   b_values),
                                                                   m0_values_nightly),
                                                                   n0_values_nightly),
                                                                   k0_values_nightly),
                                                                   v0_values_nightly),
                                                                   h0_values_nightly),
                                                                   i_values_lhs),
                                                                   i_values_rhs),
                                                                   framework::dataset::make("export_to_cl_image_rhs", false)),
                                                                   framework::dataset::make("DataType", DataType::F32)),
                                                                   a_values_nightly),
                                                                   beta_values_nightly),
                                                                   lhs_transpose_values),
                                                                   act_values))
{
    // Validate output; skipped if the device could not run the configuration
    if(validate_result)
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}
700 TEST_SUITE(FusedPostOps)
701
// Precommit-sized FP32 run with each of the fused post-op lists appended
// (note: interleaving disabled and bias always broadcast for this suite).
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedWithPostOpsFixture<float>, framework::DatasetMode::ALL,
                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                                                                   m_values,
                                                                   n_values),
                                                                   k_values),
                                                                   b_values),
                                                                   m0_values_precommit),
                                                                   n0_values_precommit),
                                                                   k0_values_precommit),
                                                                   v0_values_precommit),
                                                                   h0_values_precommit),
                                                                   framework::dataset::make("interleave_lhs", { false })),
                                                                   framework::dataset::make("interleave_rhs", { false })),
                                                                   framework::dataset::make("export_to_cl_image_rhs", false)),
                                                                   framework::dataset::make("DataType", DataType::F32)),
                                                                   a_values_precommit),
                                                                   beta_values_precommit),
                                                                   framework::dataset::make("broadcast_bias", { true } )),
                                                                   lhs_transpose_values),
                                                                   act_values),
                                                                   post_op_lists)
                                                                   )
{
    // Validate output; skipped if the device could not run the configuration
    if(validate_result)
    {
        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
    }
    else
    {
        ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
        framework::ARM_COMPUTE_PRINT_INFO();
    }
}
736
737 TEST_SUITE_END() // FusedPostOps
738
TEST_SUITE(ExportToCLImage)739 TEST_SUITE(ExportToCLImage)
740 DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
741 framework::dataset::make("Input0Info", { TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32), // OK or incorrect if cl_khr_image2d_from_buffer not supported
742 TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32), // OK or incorrect if cl_khr_image2d_from_buffer not supported
743 TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32), // OK or incorrect if cl_khr_image2d_from_buffer not supported
744 TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32), // Incorrect k0
745 TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32), // Incorrect n0
746
747 }),
748 framework::dataset::make("Input1Info",{ TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),
749 TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),
750 TensorInfo(TensorShape(512U, 8U, 2U), 1, DataType::F32),
751 TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),
752 TensorInfo(TensorShape(128U, 32U, 2U), 1, DataType::F32),
753
754 })),
755 framework::dataset::make("Input2Info", { TensorInfo(TensorShape(64U), 1, DataType::F32),
756 TensorInfo(TensorShape(64U), 1, DataType::F32),
757 TensorInfo(TensorShape(64U), 1, DataType::F32),
758 TensorInfo(TensorShape(64U), 1, DataType::F32),
759 TensorInfo(TensorShape(64U), 1, DataType::F32),
760
761 })),
762 framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
763 TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
764 TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
765 TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
766 TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
767 TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
768
769 })),
770 framework::dataset::make("LHSMInfo",{
771 GEMMLHSMatrixInfo(4, 4, 1, false, true),
772 GEMMLHSMatrixInfo(4, 8, 1, false, true),
773 GEMMLHSMatrixInfo(4, 4, 1, false, true),
774 GEMMLHSMatrixInfo(4, 2, 1, false, false),
775 GEMMLHSMatrixInfo(4, 4, 1, false, false),
776
777 })),
778 framework::dataset::make("RHSMInfo",{
779 GEMMRHSMatrixInfo(4, 4, 1, true, true, true),
780 GEMMRHSMatrixInfo(4, 8, 1, true, true, true),
781 GEMMRHSMatrixInfo(8, 4, 1, true, true, true),
782 GEMMRHSMatrixInfo(4, 2, 1, true, false, true),
783 GEMMRHSMatrixInfo(2, 4, 1, true, false, true),
784 })),
785 framework::dataset::make("GEMMInfo",{GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
786 64 /**<N Number of RHS columns*/,
787 64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
788 false /**< reinterpret the input as 3D */,
789 true /**< Flag used to broadcast the bias addition */,
790 false /**< wider accumm */,
791 false /**< has pad y */,
792 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
793 1 /**< Multiplication factor for the width of the 1xW transposed block */,
794 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
795 GEMMLHSMatrixInfo(),
796 GEMMRHSMatrixInfo(),
797 0 /**< Offset to be added to each element of the matrix A */,
798 0 /**< Offset to be added to each element of the matrix B */),
799 GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
800 64 /**<N Number of RHS columns*/,
801 64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
802 false /**< reinterpret the input as 3D */,
803 true /**< Flag used to broadcast the bias addition */,
804 false /**< wider accumm */,
805 false /**< has pad y */,
806 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
807 1 /**< Multiplication factor for the width of the 1xW transposed block */,
808 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
809 GEMMLHSMatrixInfo(),
810 GEMMRHSMatrixInfo(),
811 0 /**< Offset to be added to each element of the matrix A */,
812 0 /**< Offset to be added to each element of the matrix B */),
813 GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
814 64 /**<N Number of RHS columns*/,
815 64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
816 false /**< reinterpret the input as 3D */,
817 true /**< Flag used to broadcast the bias addition */,
818 false /**< wider accumm */,
819 false /**< has pad y */,
820 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
821 1 /**< Multiplication factor for the width of the 1xW transposed block */,
822 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
823 GEMMLHSMatrixInfo(),
824 GEMMRHSMatrixInfo(),
825 0 /**< Offset to be added to each element of the matrix A */,
826 0 /**< Offset to be added to each element of the matrix B */),
827
828 GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
829 64 /**<N Number of RHS columns*/,
830 64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
831 false /**< reinterpret the input as 3D */,
832 true /**< Flag used to broadcast the bias addition */,
833 false /**< wider accumm */,
834 false /**< has pad y */,
835 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
836 1 /**< Multiplication factor for the width of the 1xW transposed block */,
837 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
838 GEMMLHSMatrixInfo(),
839 GEMMRHSMatrixInfo(),
840 0 /**< Offset to be added to each element of the matrix A */,
841 0 /**< Offset to be added to each element of the matrix B */),
842 GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
843 64 /**<N Number of RHS columns*/,
844 64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
845 false /**< reinterpret the input as 3D */,
846 true /**< Flag used to broadcast the bias addition */,
847 false /**< wider accumm */,
848 false /**< has pad y */,
849 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
850 1 /**< Multiplication factor for the width of the 1xW transposed block */,
851 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
852 GEMMLHSMatrixInfo(),
853 GEMMRHSMatrixInfo(),
854 0 /**< Offset to be added to each element of the matrix A */,
855 0 /**< Offset to be added to each element of the matrix B */)
856 })),
857 framework::dataset::make("Expected", { true,
858 true,
859 true,
860 false,
861 true})),
862 input0_info ,input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, expected)
863 {
864 ARM_COMPUTE_EXPECT(bool(ClGemmMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
865 &input1_info.clone()->set_is_resizable(true),
866 &input2_info.clone()->set_is_resizable(true),
867 &output_info.clone()->set_is_resizable(true),1.f,1.f,
868 lhs_info,
869 rhs_info,
870 gemm_info)) == (expected && image2d_from_buffer_supported(CLKernelLibrary::get().get_device())), framework::LogLevel::ERRORS);
871 }
872
// Precommit test of the reshaped GEMM kernel with the RHS exported to an OpenCL image
// (export_to_cl_image_rhs = true), FP32.
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
b_values),
m0_values_precommit),
n0_values_precommit),
k0_values_precommit),
v0_values_precommit),
h0_values_precommit),
i_values_lhs),
i_values_rhs),
framework::dataset::make("export_to_cl_image_rhs", true)),
framework::dataset::make("DataType", DataType::F32)),
a_values_precommit),
beta_values_precommit),
broadcast_bias_values),
lhs_transpose_values),
act_values))
{
// Validate output only if validate() is successful
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
else
{
// Image export requires cl_khr_image2d_from_buffer; without it the fixture skips the run.
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}

}
906
// Nightly test of the reshaped GEMM kernel with the RHS exported to an OpenCL image, FP32.
// Uses the export-specific n0/k0 nightly datasets (image export constrains the block sizes).
FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::NIGHTLY,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
b_values),
m0_values_nightly),
n0_export_to_cl_image_values_nightly),
k0_export_to_cl_image_values_nightly),
v0_values_nightly),
h0_values_nightly),
i_values_lhs),
i_values_rhs),
framework::dataset::make("export_to_cl_image_rhs", true)),
framework::dataset::make("DataType", DataType::F32)),
a_values_nightly),
beta_values_nightly),
broadcast_bias_values),
lhs_transpose_values),
act_values))
{
// Validate output only if validate() is successful
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
else
{
// Image export requires cl_khr_image2d_from_buffer; without it the fixture skips the run.
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
939
// Precommit test of the reshaped GEMM kernel, output re-interpreted as 3D (M = m_w x m_h),
// RHS exported to an OpenCL image, FP32.
FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_w_values,
m_h_values),
n_values),
k_values),
b_values),
m0_values_precommit),
n0_values_precommit),
k0_values_precommit),
v0_values_precommit),
h0_values_precommit),
i_values_lhs),
i_values_rhs),
framework::dataset::make("export_to_cl_image_rhs", true)),
framework::dataset::make("DataType", DataType::F32)),
a_values_precommit),
beta_values_precommit),
lhs_transpose_values),
act_values))
{
// Validate output only if validate() is successful
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
else
{
// Image export requires cl_khr_image2d_from_buffer; without it the fixture skips the run.
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
972
// Nightly 3D-output test of the reshaped GEMM kernel, RHS exported to an OpenCL image, FP32.
// Uses the export-specific n0/k0 nightly datasets (image export constrains the block sizes).
FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::NIGHTLY,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_w_values,
m_h_values),
n_values),
k_values),
b_values),
m0_values_nightly),
n0_export_to_cl_image_values_nightly),
k0_export_to_cl_image_values_nightly),
v0_values_nightly),
h0_values_nightly),
i_values_lhs),
i_values_rhs),
framework::dataset::make("export_to_cl_image_rhs", true)),
framework::dataset::make("DataType", DataType::F32)),
a_values_nightly),
beta_values_nightly),
lhs_transpose_values),
act_values))
{
// Validate output only if validate() is successful
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
else
{
// Image export requires cl_khr_image2d_from_buffer; without it the fixture skips the run.
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
1005 TEST_SUITE(FusedPostOps)
1006
// Precommit test of the reshaped GEMM kernel with fused post-ops (post_op_lists) AND the RHS
// exported to an OpenCL image, FP32. Interleaving pinned to false, broadcast_bias to true.
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedWithPostOpsFixture<float>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
b_values),
m0_values_precommit),
n0_values_precommit),
k0_values_precommit),
v0_values_precommit),
h0_values_precommit),
framework::dataset::make("interleave_lhs", { false })),
framework::dataset::make("interleave_rhs", { false })),
framework::dataset::make("export_to_cl_image_rhs", true)),
framework::dataset::make("DataType", DataType::F32)),
a_values_precommit),
beta_values_precommit),
framework::dataset::make("broadcast_bias", { true } )),
lhs_transpose_values),
act_values),
post_op_lists)
)
{
// Validate output only if validate() is successful
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
else
{
// Image export requires cl_khr_image2d_from_buffer; without it the fixture skips the run.
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
1041
1042 TEST_SUITE_END() // FusedPostOps
1043
TEST_SUITE_END()1044 TEST_SUITE_END() // ExportToCLImage
1045 TEST_SUITE_END() // FP32
1046
1047 TEST_SUITE(FP16)
1048
// Precommit test of the reshaped GEMM kernel, FP16, RHS kept in a buffer
// (export_to_cl_image_rhs = false). Validated against the F16 tolerances.
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
b_values),
m0_values_precommit),
n0_values_precommit),
k0_values_precommit),
v0_values_precommit),
h0_values_precommit),
i_values_lhs),
i_values_rhs),
framework::dataset::make("export_to_cl_image_rhs", false)),
framework::dataset::make("DataType", DataType::F16)),
a_values_precommit),
beta_values_precommit),
broadcast_bias_values),
lhs_transpose_values),
act_values))
{
// Validate output
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
else
{
// validate_result is set by the fixture; presumably false when the device lacks
// cl_khr_image2d_from_buffer support — the case is then reported as skipped, not failed.
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
1081
// Nightly-sized FP16 test of the reshaped GEMM kernel, RHS kept in a buffer.
// NOTE(review): DatasetMode::DISABLED — this large FP16 configuration is currently switched
// off; confirm whether it should run as NIGHTLY like the FP32 equivalent.
FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::DISABLED,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
b_values),
m0_values_nightly),
n0_values_nightly),
k0_values_nightly),
v0_values_nightly),
h0_values_nightly),
i_values_lhs),
i_values_rhs),
framework::dataset::make("export_to_cl_image_rhs", false)),
framework::dataset::make("DataType", DataType::F16)),
a_values_nightly),
beta_values_nightly),
broadcast_bias_values),
lhs_transpose_values),
act_values))
{
// Validate output
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
else
{
// validate_result is set by the fixture; presumably false when the device lacks
// cl_khr_image2d_from_buffer support — the case is then reported as skipped, not failed.
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
1114
// Precommit FP16 test of the reshaped GEMM kernel with the output re-interpreted as 3D
// (M = m_w x m_h), RHS kept in a buffer.
FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_w_values,
m_h_values),
n_values),
k_values),
b_values),
m0_values_precommit),
n0_values_precommit),
k0_values_precommit),
v0_values_precommit),
h0_values_precommit),
i_values_lhs),
i_values_rhs),
framework::dataset::make("export_to_cl_image_rhs", false)),
framework::dataset::make("DataType", DataType::F16)),
a_values_precommit),
beta_values_precommit),
lhs_transpose_values),
act_values))
{
// Validate output
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
else
{
// validate_result is set by the fixture; presumably false when the device lacks
// cl_khr_image2d_from_buffer support — the case is then reported as skipped, not failed.
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
1147
// Nightly-sized FP16 3D-output test of the reshaped GEMM kernel, RHS kept in a buffer.
// NOTE(review): DatasetMode::DISABLED — currently switched off; confirm intent.
FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::DISABLED,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_w_values,
m_h_values),
n_values),
k_values),
b_values),
m0_values_nightly),
n0_values_nightly),
k0_values_nightly),
v0_values_nightly),
h0_values_nightly),
i_values_lhs),
i_values_rhs),
framework::dataset::make("export_to_cl_image_rhs", false)),
framework::dataset::make("DataType", DataType::F16)),
a_values_nightly),
beta_values_nightly),
lhs_transpose_values),
act_values))
{
// Validate output
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
else
{
// validate_result is set by the fixture; presumably false when the device lacks
// cl_khr_image2d_from_buffer support — the case is then reported as skipped, not failed.
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
1180
1181 TEST_SUITE(FusedPostOps)
1182
// Precommit FP16 test of the reshaped GEMM kernel with fused post-ops (post_op_lists),
// RHS kept in a buffer. Interleaving pinned to false, broadcast_bias to true.
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedWithPostOpsFixture<half>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
b_values),
m0_values_precommit),
n0_values_precommit),
k0_values_precommit),
v0_values_precommit),
h0_values_precommit),
framework::dataset::make("interleave_lhs", { false })),
framework::dataset::make("interleave_rhs", { false })),
framework::dataset::make("export_to_cl_image_rhs", false)),
framework::dataset::make("DataType", DataType::F16)),
a_values_precommit),
beta_values_precommit),
framework::dataset::make("broadcast_bias", { true } )),
lhs_transpose_values),
act_values),
post_op_lists)
)
{
// Validate output
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
else
{
// validate_result is set by the fixture; presumably false when the device lacks
// cl_khr_image2d_from_buffer support — the case is then reported as skipped, not failed.
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
1217
1218 TEST_SUITE_END() // FusedPostOps
1219
TEST_SUITE(ExportToCLImage)1220 TEST_SUITE(ExportToCLImage)
1221 DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
1222 framework::dataset::make("Input0Info", { TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // OK or incorrect if cl_khr_image2d_from_buffer not supported
1223 TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // OK or incorrect if cl_khr_image2d_from_buffer not supported
1224 TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // OK or incorrect if cl_khr_image2d_from_buffer not supported
1225 TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // Incorrect k0
1226 TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16), // Incorrect n0
1227
1228 }),
1229 framework::dataset::make("Input1Info",{ TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
1230 TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
1231 TensorInfo(TensorShape(512U, 8U, 2U), 1, DataType::F16),
1232 TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
1233 TensorInfo(TensorShape(128U, 32U, 2U), 1, DataType::F16),
1234
1235 })),
1236 framework::dataset::make("Input2Info", { TensorInfo(TensorShape(64U), 1, DataType::F16),
1237 TensorInfo(TensorShape(64U), 1, DataType::F16),
1238 TensorInfo(TensorShape(64U), 1, DataType::F16),
1239 TensorInfo(TensorShape(64U), 1, DataType::F16),
1240 TensorInfo(TensorShape(64U), 1, DataType::F16),
1241
1242 })),
1243 framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
1244 TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
1245 TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
1246 TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
1247 TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
1248 TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
1249
1250 })),
1251 framework::dataset::make("LHSMInfo",{
1252 GEMMLHSMatrixInfo(4, 4, 1, false, true),
1253 GEMMLHSMatrixInfo(4, 8, 1, false, true),
1254 GEMMLHSMatrixInfo(4, 4, 1, false, true),
1255 GEMMLHSMatrixInfo(4, 2, 1, false, false),
1256 GEMMLHSMatrixInfo(4, 4, 1, false, false),
1257
1258 })),
1259 framework::dataset::make("RHSMInfo",{
1260 GEMMRHSMatrixInfo(4, 4, 1, true, true, true),
1261 GEMMRHSMatrixInfo(4, 8, 1, true, true, true),
1262 GEMMRHSMatrixInfo(8, 4, 1, true, true, true),
1263 GEMMRHSMatrixInfo(4, 2, 1, true, false, true),
1264 GEMMRHSMatrixInfo(2, 4, 1, true, false, true),
1265 })),
1266 framework::dataset::make("GEMMInfo",{GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
1267 64 /**<N Number of RHS columns*/,
1268 64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
1269 false /**< reinterpret the input as 3D */,
1270 true /**< Flag used to broadcast the bias addition */,
1271 false /**< wider accumm */,
1272 false /**< has pad y */,
1273 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
1274 1 /**< Multiplication factor for the width of the 1xW transposed block */,
1275 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
1276 GEMMLHSMatrixInfo(),
1277 GEMMRHSMatrixInfo(),
1278 0 /**< Offset to be added to each element of the matrix A */,
1279 0 /**< Offset to be added to each element of the matrix B */),
1280 GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
1281 64 /**<N Number of RHS columns*/,
1282 64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
1283 false /**< reinterpret the input as 3D */,
1284 true /**< Flag used to broadcast the bias addition */,
1285 false /**< wider accumm */,
1286 false /**< has pad y */,
1287 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
1288 1 /**< Multiplication factor for the width of the 1xW transposed block */,
1289 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
1290 GEMMLHSMatrixInfo(),
1291 GEMMRHSMatrixInfo(),
1292 0 /**< Offset to be added to each element of the matrix A */,
1293 0 /**< Offset to be added to each element of the matrix B */),
1294 GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
1295 64 /**<N Number of RHS columns*/,
1296 64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
1297 false /**< reinterpret the input as 3D */,
1298 true /**< Flag used to broadcast the bias addition */,
1299 false /**< wider accumm */,
1300 false /**< has pad y */,
1301 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
1302 1 /**< Multiplication factor for the width of the 1xW transposed block */,
1303 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
1304 GEMMLHSMatrixInfo(),
1305 GEMMRHSMatrixInfo(),
1306 0 /**< Offset to be added to each element of the matrix A */,
1307 0 /**< Offset to be added to each element of the matrix B */),
1308
1309 GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
1310 64 /**<N Number of RHS columns*/,
1311 64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
1312 false /**< reinterpret the input as 3D */,
1313 true /**< Flag used to broadcast the bias addition */,
1314 false /**< wider accumm */,
1315 false /**< has pad y */,
1316 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
1317 1 /**< Multiplication factor for the width of the 1xW transposed block */,
1318 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
1319 GEMMLHSMatrixInfo(),
1320 GEMMRHSMatrixInfo(),
1321 0 /**< Offset to be added to each element of the matrix A */,
1322 0 /**< Offset to be added to each element of the matrix B */),
1323 GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
1324 64 /**<N Number of RHS columns*/,
1325 64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
1326 false /**< reinterpret the input as 3D */,
1327 true /**< Flag used to broadcast the bias addition */,
1328 false /**< wider accumm */,
1329 false /**< has pad y */,
1330 ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
1331 1 /**< Multiplication factor for the width of the 1xW transposed block */,
1332 1 /**< Multiplication factor for the height of the 4x4 interleaved block */,
1333 GEMMLHSMatrixInfo(),
1334 GEMMRHSMatrixInfo(),
1335 0 /**< Offset to be added to each element of the matrix A */,
1336 0 /**< Offset to be added to each element of the matrix B */)
1337 })),
1338 framework::dataset::make("Expected", { true,
1339 true,
1340 true,
1341 false,
1342 true})),
1343 input0_info ,input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, expected)
1344 {
1345 ARM_COMPUTE_EXPECT(bool(ClGemmMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
1346 &input1_info.clone()->set_is_resizable(true),
1347 &input2_info.clone()->set_is_resizable(true),
1348 &output_info.clone()->set_is_resizable(true),1.f,1.f,
1349 lhs_info,
1350 rhs_info,
1351 gemm_info)) == (expected && image2d_from_buffer_supported(CLKernelLibrary::get().get_device())), framework::LogLevel::ERRORS);
1352 }
1353
// Precommit FP16 test of the reshaped GEMM kernel with the RHS exported to an OpenCL image
// (export_to_cl_image_rhs = true).
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
b_values),
m0_values_precommit),
n0_values_precommit),
k0_values_precommit),
v0_values_precommit),
h0_values_precommit),
i_values_lhs),
i_values_rhs),
framework::dataset::make("export_to_cl_image_rhs", true)),
framework::dataset::make("DataType", DataType::F16)),
a_values_precommit),
beta_values_precommit),
broadcast_bias_values),
lhs_transpose_values),
act_values))
{
// Validate output only if validate() is successful
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
else
{
// Image export requires cl_khr_image2d_from_buffer; without it the fixture skips the run.
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}

}
1387
// Nightly FP16 test of the reshaped GEMM kernel with the RHS exported to an OpenCL image.
// Uses the export-specific n0/k0 nightly datasets (image export constrains the block sizes).
FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::NIGHTLY,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
b_values),
m0_values_nightly),
n0_export_to_cl_image_values_nightly),
k0_export_to_cl_image_values_nightly),
v0_values_nightly),
h0_values_nightly),
i_values_lhs),
i_values_rhs),
framework::dataset::make("export_to_cl_image_rhs", true)),
framework::dataset::make("DataType", DataType::F16)),
a_values_nightly),
beta_values_nightly),
broadcast_bias_values),
lhs_transpose_values),
act_values))
{
// Validate output only if validate() is successful
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
else
{
// Image export requires cl_khr_image2d_from_buffer; without it the fixture skips the run.
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
1420
// Precommit FP16 3D-output test (M = m_w x m_h) of the reshaped GEMM kernel with the RHS
// exported to an OpenCL image.
FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_w_values,
m_h_values),
n_values),
k_values),
b_values),
m0_values_precommit),
n0_values_precommit),
k0_values_precommit),
v0_values_precommit),
h0_values_precommit),
i_values_lhs),
i_values_rhs),
framework::dataset::make("export_to_cl_image_rhs", true)),
framework::dataset::make("DataType", DataType::F16)),
a_values_precommit),
beta_values_precommit),
lhs_transpose_values),
act_values))
{
// Validate output only if validate() is successful
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
else
{
// Image export requires cl_khr_image2d_from_buffer; without it the fixture skips the run.
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
1453
// Nightly FP16 3D-output test of the reshaped GEMM kernel with the RHS exported to an OpenCL
// image. Uses the export-specific n0/k0 nightly datasets.
FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::NIGHTLY,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_w_values,
m_h_values),
n_values),
k_values),
b_values),
m0_values_nightly),
n0_export_to_cl_image_values_nightly),
k0_export_to_cl_image_values_nightly),
v0_values_nightly),
h0_values_nightly),
i_values_lhs),
i_values_rhs),
framework::dataset::make("export_to_cl_image_rhs", true)),
framework::dataset::make("DataType", DataType::F16)),
a_values_nightly),
beta_values_nightly),
lhs_transpose_values),
act_values))
{
// Validate output only if validate() is successful
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
else
{
// Image export requires cl_khr_image2d_from_buffer; without it the fixture skips the run.
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
1486 TEST_SUITE(FusedPostOps)
1487
// Precommit FP16 test of the reshaped GEMM kernel with fused post-ops (post_op_lists) AND
// the RHS exported to an OpenCL image. Interleaving pinned to false, broadcast_bias to true.
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedWithPostOpsFixture<half>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
b_values),
m0_values_precommit),
n0_values_precommit),
k0_values_precommit),
v0_values_precommit),
h0_values_precommit),
framework::dataset::make("interleave_lhs", { false })),
framework::dataset::make("interleave_rhs", { false })),
framework::dataset::make("export_to_cl_image_rhs", true)),
framework::dataset::make("DataType", DataType::F16)),
a_values_precommit),
beta_values_precommit),
framework::dataset::make("broadcast_bias", { true } )),
lhs_transpose_values),
act_values),
post_op_lists)
)
{
// Validate output only if validate() is successful
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
else
{
// Image export requires cl_khr_image2d_from_buffer; without it the fixture skips the run.
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
1522
1523 TEST_SUITE_END() // FusedPostOps
1524
TEST_SUITE_END()1525 TEST_SUITE_END() // ExportToCLImage
1526 TEST_SUITE_END() // FP16
1527
1528 TEST_SUITE(MixedPrecision)
1529
// Mixed-precision reshaped GEMM with F16 inputs/outputs, small (precommit)
// shapes, RHS kept in a buffer ("export_to_cl_image_rhs" = false).
// NOTE(review): "MixedPrecision" presumably accumulates in a wider type than
// F16 (hence the dedicated tolerances below) — confirm in the fixture.
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
b_values),
m0_values_precommit),
n0_values_precommit),
k0_values_precommit),
v0_values_precommit),
h0_values_precommit),
i_values_lhs),
i_values_rhs),
framework::dataset::make("export_to_cl_image_rhs", false)),
framework::dataset::make("DataType", DataType::F16)),
a_values_precommit),
beta_values_precommit),
broadcast_bias_values),
lhs_transpose_values),
act_values))
{
// Validate output against the reference using the mixed-precision tolerances
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
}
else
{
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
1562
// Mixed-precision reshaped GEMM over the larger (nightly) datasets.
// NOTE(review): DatasetMode::DISABLED — this case is registered but never
// scheduled; confirm whether it is intentionally kept disabled.
FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::DISABLED,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
b_values),
m0_values_nightly),
n0_values_nightly),
k0_values_nightly),
v0_values_nightly),
h0_values_nightly),
i_values_lhs),
i_values_rhs),
framework::dataset::make("export_to_cl_image_rhs", false)),
framework::dataset::make("DataType", DataType::F16)),
a_values_nightly),
beta_values_nightly),
broadcast_bias_values),
lhs_transpose_values),
act_values))
{
// Validate output against the reference using the mixed-precision tolerances
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
}
else
{
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
1595
// 3D variant of the mixed-precision test: M is split into width/height
// (m_w_values x m_h_values) to exercise the reinterpret-output-as-3D path.
// Note: no broadcast_bias dataset here, unlike the 2D cases above.
FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_w_values,
m_h_values),
n_values),
k_values),
b_values),
m0_values_precommit),
n0_values_precommit),
k0_values_precommit),
v0_values_precommit),
h0_values_precommit),
i_values_lhs),
i_values_rhs),
framework::dataset::make("export_to_cl_image_rhs", false)),
framework::dataset::make("DataType", DataType::F16)),
a_values_precommit),
beta_values_precommit),
lhs_transpose_values),
act_values))
{
// Validate output against the reference using the mixed-precision tolerances
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
}
else
{
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
1628
// Nightly-sized 3D mixed-precision variant.
// NOTE(review): DatasetMode::DISABLED — registered but never scheduled;
// confirm whether it is intentionally kept disabled.
FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::DISABLED,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_w_values,
m_h_values),
n_values),
k_values),
b_values),
m0_values_nightly),
n0_values_nightly),
k0_values_nightly),
v0_values_nightly),
h0_values_nightly),
i_values_lhs),
i_values_rhs),
framework::dataset::make("export_to_cl_image_rhs", false)),
framework::dataset::make("DataType", DataType::F16)),
a_values_nightly),
beta_values_nightly),
lhs_transpose_values),
act_values))
{
// Validate output against the reference using the mixed-precision tolerances
if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
}
else
{
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
1661
// Sub-suite: mixed-precision reshaped GEMM with fused post-operations.
TEST_SUITE(FusedPostOps)

// Mixed-precision reshaped GEMM with fused post-ops. Covers both the
// cl_image and buffer RHS paths ("export_to_cl_image_rhs" = {true, false});
// interleaving is disabled and broadcast_bias is pinned to true.
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedMixedPrecisionWithPostOpsFixture<half>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
b_values),
m0_values_precommit),
n0_values_precommit),
k0_values_precommit),
v0_values_precommit),
h0_values_precommit),
framework::dataset::make("interleave_lhs", { false })),
framework::dataset::make("interleave_rhs", { false })),
framework::dataset::make("export_to_cl_image_rhs", { true, false })),
framework::dataset::make("DataType", DataType::F16)),
a_values_precommit),
beta_values_precommit),
framework::dataset::make("broadcast_bias", { true } )),
lhs_transpose_values),
act_values),
post_op_lists)
)
{
// Validate output against the reference using the mixed-precision tolerances
if(validate_result)
{
// Skipped (not failed) when the device lacks cl_khr_image2d_from_buffer
validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
}
else
{
ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
framework::ARM_COMPUTE_PRINT_INFO();
}
}
1698
TEST_SUITE_END() // FusedPostOps

// Close the remaining suites opened earlier in the file, innermost first.
TEST_SUITE_END() // MixedPrecision
TEST_SUITE_END() // Float
TEST_SUITE_END() // GEMMMatrixMultiplyReshaped
TEST_SUITE_END() // CL
} // namespace validation
} // namespace test
} // namespace arm_compute
