xref: /aosp_15_r20/external/ComputeLibrary/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2018-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/core/KernelDescriptors.h"
25 #include "arm_compute/core/Types.h"
26 #include "arm_compute/core/experimental/PostOps.h"
27 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
28 #include "arm_compute/runtime/CL/CLTensor.h"
29 #include "arm_compute/runtime/CL/CLTensorAllocator.h"
30 #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h"
31 #include "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.h"
32 #include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h"
33 #include "tests/CL/CLAccessor.h"
34 #include "tests/CL/Helper.h"
35 #include "tests/PaddingCalculator.h"
36 #include "tests/datasets/ShapeDatasets.h"
37 #include "tests/framework/Asserts.h"
38 #include "tests/framework/Macros.h"
39 #include "tests/framework/datasets/Datasets.h"
40 #include "tests/validation/Validation.h"
41 #include "tests/validation/fixtures/GEMMFixture.h"
42 
43 namespace arm_compute
44 {
45 namespace test
46 {
47 namespace validation
48 {
49 using namespace arm_compute::misc::shape_calculator;
50 using namespace arm_compute::opencl::kernels;
51 
52 // Create function for ClGemmReshapeLhsMatrixKernel
53 using CLGEMMReshapeLHSMatrix = CLSynthetizeOperator<ClGemmReshapeLhsMatrixKernel>;
54 
55 // Create function for ClGemmReshapeRhsMatrixKernel
56 using CLGEMMReshapeRHSMatrix = CLSynthetizeOperator<ClGemmReshapeRhsMatrixKernel>;
57 
58 // Create function for ClGemmMatrixMultiplyReshapedKernel
59 using CLGEMMMatrixMultiplyReshaped = CLSynthetizeOperator<ClGemmMatrixMultiplyReshapedKernel>;
60 
61 // Fixture for CLGEMMMatrixMultiplyReshaped
62 template <typename T>
63 using CLGEMMMatrixMultiplyReshapedFixture = GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
64 
65 // Fixture for CLGEMMMatrixMultiplyReshaped with post ops
66 template <typename T>
67 using CLGEMMMatrixMultiplyReshapedWithPostOpsFixture =
68     GEMMMatrixMultiplyReshapedWithPostOpsValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
69 
70 // Fixture for CLGEMMMatrixMultiplyReshaped mixed precision
71 template <typename T>
72 using CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture =
73     GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
74 
75 // Fixture for CLGEMMMatrixMultiplyReshaped mixed precision with post ops
76 template <typename T>
77 using CLGEMMMatrixMultiplyReshapedMixedPrecisionWithPostOpsFixture =
78     GEMMMatrixMultiplyReshapedWithPostOpsValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
79 
80 // Fixture for CLGEMMMatrixMultiplyReshaped3D
81 template <typename T>
82 using CLGEMMMatrixMultiplyReshaped3DFixture = GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
83 
84 // Fixture for CLGEMMMatrixMultiplyReshaped3D mixed precision
85 template <typename T>
86 using CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture =
87     GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
88 
89 namespace
90 {
91 // *INDENT-OFF*
92 // clang-format off
93 RelativeTolerance<float> rel_tolerance_f32(0.001f);
94 constexpr float          abs_tolerance_f32(0.0001f);
95 
96 RelativeTolerance<float> rel_tolerance_f16_mixed_precision(0.001f);
97 constexpr float          abs_tolerance_f16_mixed_precision(0.01f);
98 
99 RelativeTolerance<float> rel_tolerance_f16(0.001f);
100 constexpr float          abs_tolerance_f16(0.01f);
101 
102 /** M values to test */
103 const auto m_values = framework::dataset::make("M", 17);
104 
105 /** M_W values to test */
106 const auto m_w_values = framework::dataset::make("M_W", 5);
107 
108 /** M_H values to test */
109 const auto m_h_values = framework::dataset::make("M_H", 7);
110 
111 /** N values to test */
112 const auto n_values = framework::dataset::make("N", 21);
113 
114 /** K values to test */
115 const auto k_values = framework::dataset::make("K", 13);
116 
117 /** Batch size values to test */
118 const auto b_values = framework::dataset::make("batch_size", 2, 3);
119 
120 /** Activation values to test */
121 const auto act_values = framework::dataset::make("Activation",
122 {
123     ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f),
124     ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::ELU),
125 });
126 
127 /** Alpha values to test - Precommit */
128 const auto a_values_precommit = framework::dataset::make("alpha", {-0.75f} );
129 
130 /** Beta values to test - Precommit */
131 const auto beta_values_precommit = framework::dataset::make("beta", {-0.35f} );
132 
133 /** M0 values to test - Precommit */
134 const auto m0_values_precommit = framework::dataset::make("M0", { 4 });
135 
136 /** N0 values to test - Precommit */
137 const auto n0_values_precommit = framework::dataset::make("N0", { 4 });
138 
139 /** K0 values to test - Precommit */
140 const auto k0_values_precommit = framework::dataset::make("K0", { 4 });
141 
142 /** V0 values to test - Precommit */
143 const auto v0_values_precommit = framework::dataset::make("V0", 1, 3);
144 
145 /** H0 values to test - Precommit */
146 const auto h0_values_precommit = framework::dataset::make("H0", 1, 3);
147 
148 /** Alpha values to test - Nightly */
149 const auto a_values_nightly = framework::dataset::make("alpha", {1.0f} );
150 
151 /** Beta values to test - Nightly */
152 const auto beta_values_nightly = framework::dataset::make("beta", {1.0f} );
153 
154 /** M0 values to test - Nightly */
155 const auto m0_values_nightly = framework::dataset::make("M0", { 8 });
156 
157 /** N0 values to test - Nightly */
158 const auto n0_values_nightly = framework::dataset::make("N0", { 8 });
159 
160 /** K0 values to test - Nightly */
161 const auto k0_values_nightly = framework::dataset::make("K0", { 4 });
162 
163 /** N0 values to test with export to OpenCL image object - Nightly */
164 const auto n0_export_to_cl_image_values_nightly = framework::dataset::make("N0", { 4, 8, 16 });
165 
166 /** K0 values to test with export to OpenCL image object - Nightly */
167 const auto k0_export_to_cl_image_values_nightly = framework::dataset::make("K0", { 4, 8, 16 });
168 
169 /** V0 values to test - Nightly */
170 const auto v0_values_nightly = framework::dataset::make("V0", 1, 3);
171 
172 /** H0 values to test - Nightly */
173 const auto h0_values_nightly = framework::dataset::make("H0", 1, 3);
174 
175 /** Interleave values to test with LHS matrix */
176 const auto i_values_lhs = framework::dataset::make("interleave_lhs", { true, false });
177 
178 /** Interleave values to test with RHS matrix */
179 const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, false });
180 
181 /** Broadcast bias from vector to matrix */
182 const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } );
183 
184 /** LHS transposed values */
185 const auto lhs_transpose_values = framework::dataset::make("lhs_transpose", { false, true } );
186 
187 /** Post Ops */
188 using PostOpArgBroadcast =  CLGEMMMatrixMultiplyReshapedWithPostOpsFixture<float>::PostOpArgBroadcast;
post_ops_1()189 experimental::PostOpList<PostOpArgBroadcast> post_ops_1()
190 {
191     experimental::PostOpList<PostOpArgBroadcast> post_ops{};
192     post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
193     post_ops.push_back_op<experimental::PostOpEltwiseAdd<PostOpArgBroadcast>>(
194         std::make_tuple(true, true, false),   // If broadcast in dims 0, 1 and 2
195         0,
196         ConvertPolicy::SATURATE);
197     post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
198     return post_ops;
199 }
post_ops_2()200 experimental::PostOpList<PostOpArgBroadcast> post_ops_2()
201 {
202     experimental::PostOpList<PostOpArgBroadcast> post_ops{};
203     post_ops.push_back_op<experimental::PostOpEltwiseAdd<PostOpArgBroadcast>>(
204         std::make_tuple(false, true, true),   // If broadcast in dims 0, 1 and 2
205         1,
206         ConvertPolicy::SATURATE);
207     post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
208     return post_ops;
209 }
post_ops_3()210 experimental::PostOpList<PostOpArgBroadcast> post_ops_3()
211 {
212     experimental::PostOpList<PostOpArgBroadcast> post_ops{};
213     post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
214     post_ops.push_back_op<experimental::PostOpEltwiseAdd<PostOpArgBroadcast>>(
215         std::make_tuple(false, false, true),  // If broadcast in dims 0, 1 and 2
216         1,
217         ConvertPolicy::SATURATE);
218     return post_ops;
219 }
220 // To test that the output of the main op is the first parameter in prelu post op
post_ops_4()221 experimental::PostOpList<PostOpArgBroadcast> post_ops_4()
222 {
223     experimental::PostOpList<PostOpArgBroadcast> post_ops{};
224     post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
225     post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
226         std::make_tuple(false, false, true),   // If true, broadcast in corresponding dim: 0, 1 or 2
227         0,
228         ConvertPolicy::SATURATE);
229     post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
230     return post_ops;
231 }
232 // To test that the output of the main op is the second parameter in prelu post op i.e. it is the alpha_param
post_ops_5()233 experimental::PostOpList<PostOpArgBroadcast> post_ops_5()
234 {
235     experimental::PostOpList<PostOpArgBroadcast> post_ops{};
236     post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
237     post_ops.push_back_op<experimental::PostOpEltwisePRelu<PostOpArgBroadcast>>(
238         std::make_tuple(false, false, false),   // If true, broadcast in corresponding dim: 0, 1 or 2
239         1,
240         ConvertPolicy::SATURATE);
241     post_ops.push_back_op<experimental::PostOpAct<PostOpArgBroadcast>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::RELU, 2.1F, 1.3F});
242     return post_ops;
243 }
244 /** Different Post Op Lists */
245 const auto post_op_lists = framework::dataset::make("post_op_lists", {
246     post_ops_1(),
247     post_ops_2(),
248     post_ops_3(),
249     post_ops_4(),
250     post_ops_5()
251  } );
252 
is_post_op_list_valid(unsigned int m,unsigned int n,unsigned int k,unsigned int batch,DataType data_type,const experimental::PostOpList<ITensorInfo * > & post_ops)253 bool is_post_op_list_valid(unsigned int m, unsigned int n, unsigned int k, unsigned int batch, DataType data_type, const experimental::PostOpList<ITensorInfo*>& post_ops)
254 {
255     const auto lhs_info = GEMMLHSMatrixInfo(4,4,1,false,true);
256     const auto rhs_info = GEMMRHSMatrixInfo(4,4,1,true,true,false);
257 
258     // Create TensorInfo for post op arguments
259     TensorInfo input0_info(TensorShape(k, m, batch), 1, data_type);
260     TensorInfo input1_info(TensorShape(n, k, batch), 1, data_type);
261     TensorInfo input2_info(TensorShape(n), 1, data_type);
262     TensorInfo output_info(TensorShape(n, m, batch), 1, data_type);
263 
264     const TensorInfo reshaped_input0_info = input0_info.clone()->set_tensor_shape(misc::shape_calculator::compute_lhs_reshaped_shape(input0_info, lhs_info));
265     const TensorInfo reshaped_input1_info = input1_info.clone()->set_tensor_shape(misc::shape_calculator::compute_rhs_reshaped_shape(input1_info, rhs_info));
266 
267     GEMMKernelInfo gemm_info(m, n, k, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
268              false /**< reinterpret the input as 3D */,
269              true  /**< Flag used to broadcast the bias addition */,
270              false /**< wider accumm */,
271              false /**< has pad y */,
272            ActivationLayerInfo::ActivationFunction::IDENTITY,
273              1   /**< Multiplication factor for the width of the 1xW transposed block */,
274              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
275              lhs_info,
276              rhs_info,
277              0  /**< Offset to be added to each element of the matrix A */,
278              0 /**< Offset to be added to each element of the matrix B */,
279              post_ops);
280     return bool(ClGemmMatrixMultiplyReshapedKernel::validate(&reshaped_input0_info.clone()->set_is_resizable(true),
281                                                           &reshaped_input1_info.clone()->set_is_resizable(true),
282                                                           &input2_info.clone()->set_is_resizable(true),
283                                                           &output_info.clone()->set_is_resizable(true),1.f,1.f,
284                                                           lhs_info,
285                                                           rhs_info,
286                                                           gemm_info));
287 }
288 
289 } // namespace
290 
291 TEST_SUITE(CL)
TEST_SUITE(GEMMMatrixMultiplyReshaped)292 TEST_SUITE(GEMMMatrixMultiplyReshaped)
293 
294 // *INDENT-OFF*
295 // clang-format off
296 DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
297                framework::dataset::make("Input0Info", { TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F32),      // OK
298                                                         TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),      // OK
299                                                         TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::QASYMM8),  // Data type not supported
300                                                         TensorInfo(TensorShape(10U, 5U, 2U), 1, DataType::F32),      // Incorrect dimension bias
301                                                         TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F32),      // Mismatching shapes
302                                                         TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),      // OK, do not broadcast bias
303                                                         TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),      // OK, wider accummulation
304                                                         TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::F16),      // OK, RHS 4,4,2
305 
306                                                       }),
307                framework::dataset::make("Input1Info",{ TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F32),
308                                                        TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F16),
309                                                        TensorInfo(TensorShape(64U, 5U, 2U), 1, DataType::QASYMM8),
310                                                        TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F32),
311                                                        TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
312                                                        TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F16),
313                                                        TensorInfo(TensorShape(64U, 6U, 2U), 1, DataType::F16),
314                                                        TensorInfo(TensorShape(128U, 3U, 2U), 1, DataType::F16),
315 
316                       })),
317                framework::dataset::make("Input2Info", { TensorInfo(TensorShape(21U), 1, DataType::F32),
318                                                         TensorInfo(TensorShape(21U), 1, DataType::F16),
319                                                         TensorInfo(TensorShape(21U), 1, DataType::QASYMM8),
320                                                         TensorInfo(TensorShape(21U), 1, DataType::F32),
321                                                         TensorInfo(TensorShape(21U), 1, DataType::F32),
322                                                         TensorInfo(TensorShape(21U,17U), 1, DataType::F16),
323                                                         TensorInfo(TensorShape(21U,17U), 1, DataType::F16),
324                                                         TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F16),
325 
326                                                       })),
327                framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F32),
328                                                        TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F16),
329                                                        TensorInfo(TensorShape(21U,17U,2U), 1, DataType::QASYMM8),
330                                                        TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F32),
331                                                        TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F32),
332                                                        TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F16),
333                                                        TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F16),
334                                                        TensorInfo(TensorShape(21U,17U,2U), 1, DataType::F16),
335 
336                            })),
337                framework::dataset::make("LHSMInfo",{
338                                                           GEMMLHSMatrixInfo(4,4,1,false,true),
339                                                           GEMMLHSMatrixInfo(4,4,1,false,true),
340                                                           GEMMLHSMatrixInfo(4,4,1,false,true),
341                                                           GEMMLHSMatrixInfo(4,2,4,false,false),
342                                                           GEMMLHSMatrixInfo(4,2,4,false,false),
343                                                           GEMMLHSMatrixInfo(4,4,1,false,true),
344                                                           GEMMLHSMatrixInfo(4,4,1,false,true),
345                                                           GEMMLHSMatrixInfo(4,4,1,false,true),
346 
347                                 })),
348                framework::dataset::make("RHSMInfo",{
349                                                           GEMMRHSMatrixInfo(4,4,1,true,true,false),
350                                                           GEMMRHSMatrixInfo(4,4,1,true,true,false),
351                                                           GEMMRHSMatrixInfo(4,4,1,true,true,false),
352                                                           GEMMRHSMatrixInfo(2,2,1,true,false,false),
353                                                           GEMMRHSMatrixInfo(2,2,1,true,false,false),
354                                                           GEMMRHSMatrixInfo(4,4,1,true,true,false),
355                                                           GEMMRHSMatrixInfo(4,4,1,true,true,false),
356                                                           GEMMRHSMatrixInfo(4,4,2,true,false,false),
357 
358 
359                            })),
360 
361 
362                framework::dataset::make("GEMMInfo",{
363                                                             GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
364                                                                             21 /**<N Number of RHS columns*/,
365                                                                             13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
366                                                                      false /**< reinterpret the input as 3D */,
367                                                                      true  /**< Flag used to broadcast the bias addition */,
368                                                                      false /**< wider accumm */,
369                                                                      false /**< has pad y */,
370                                                                    ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
371                                                                      1   /**< Multiplication factor for the width of the 1xW transposed block */,
372                                                                      1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
373                                                                      GEMMLHSMatrixInfo(4,4,1,false,true),
374                                                                      GEMMRHSMatrixInfo(4,4,1,true,true,false),
375                                                                      0  /**< Offset to be added to each element of the matrix A */,
376                                                                      0 /**< Offset to be added to each element of the matrix B */),
377 
378                                                             GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
379                                                                             21 /**<N Number of RHS columns*/,
380                                                                             13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
381                                                                      false /**< reinterpret the input as 3D */,
382                                                                      true  /**< Flag used to broadcast the bias addition */,
383                                                                      false /**< wider accumm */,
384                                                                      false /**< has pad y */,
385                                                                    ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
386                                                                      1   /**< Multiplication factor for the width of the 1xW transposed block */,
387                                                                      1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
388                                                                      GEMMLHSMatrixInfo(4,4,1,false,true),
389                                                                      GEMMRHSMatrixInfo(4,4,1,true,true,false),
390                                                                      0  /**< Offset to be added to each element of the matrix A */,
391                                                                      0 /**< Offset to be added to each element of the matrix B */),
392                                                             GEMMKernelInfo(),
393                                                             GEMMKernelInfo(),
394                                                             GEMMKernelInfo(),
395 
396                                                             GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
397                                                                             21 /**<N Number of RHS columns*/,
398                                                                             13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
399                                                                      false /**< reinterpret the input as 3D */,
400                                                                      false  /**< Flag used to broadcast the bias addition */,
401                                                                      false /**< wider accumm */,
402                                                                      false /**< has pad y */,
403                                                                    ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
404                                                                      1   /**< Multiplication factor for the width of the 1xW transposed block */,
405                                                                      1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
406                                                                      GEMMLHSMatrixInfo(4,4,1,false,true),
407                                                                      GEMMRHSMatrixInfo(4,4,1,true,true,false),
408                                                                      0  /**< Offset to be added to each element of the matrix A */,
409                                                                      0 /**< Offset to be added to each element of the matrix B */),
410 
411 
412                                                             GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
413                                                                             21 /**<N Number of RHS columns*/,
414                                                                             13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
415                                                                      false /**< reinterpret the input as 3D */,
416                                                                      false  /**< Flag used to broadcast the bias addition */,
417                                                                      true /**< wider accumm */,
418                                                                      true /**< has pad y */,
419                                                                    ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
420                                                                      1   /**< Multiplication factor for the width of the 1xW transposed block */,
421                                                                      1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
422                                                                      GEMMLHSMatrixInfo(4,4,1,false,true),
423                                                                      GEMMRHSMatrixInfo(4,4,1,true,true,false),
424                                                                      0  /**< Offset to be added to each element of the matrix A */,
425                                                                      0 /**< Offset to be added to each element of the matrix B */),
426 
427                                                             GEMMKernelInfo( 17 /**<M Number of LHS rows*/,
428                                                                             21 /**<N Number of RHS columns*/,
429                                                                             13 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
430                                                                      false /**< reinterpret the input as 3D */,
431                                                                      false  /**< Flag used to broadcast the bias addition */,
432                                                                      false /**< wider accumm */,
433                                                                      false /**< has pad y */,
434                                                                    ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
435                                                                      1   /**< Multiplication factor for the width of the 1xW transposed block */,
436                                                                      1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
437                                                                      GEMMLHSMatrixInfo(4,4,1,false,true),
438                                                                      GEMMRHSMatrixInfo(4,4,2,true,false,false),
439                                                                      0  /**< Offset to be added to each element of the matrix A */,
440                                                                      0 /**< Offset to be added to each element of the matrix B */),
441                                                     })),
442                framework::dataset::make("Expected", { true, true, false, false, false, true, true,true})),
443                     input0_info ,input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, expected)
444 {
445     ARM_COMPUTE_EXPECT(bool(ClGemmMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
446                                                           &input1_info.clone()->set_is_resizable(true),
447                                                           &input2_info.clone()->set_is_resizable(true),
448                                                           &output_info.clone()->set_is_resizable(true),1.f,1.f,
449                                                           lhs_info,
450                                                           rhs_info,
451                                                           gemm_info)) == expected, framework::LogLevel::ERRORS);
452 }
453 TEST_SUITE(ValidateFusedPostOpsConfigs)
TEST_SUITE(Invalid)454 TEST_SUITE(Invalid)
455 TEST_CASE(UnsupportedPostOpSequence, framework::DatasetMode::ALL)
456 {
457     const auto data_type = DataType::F32;
458     const unsigned int m = 17;
459     const unsigned int n = 1;
460     const unsigned int k = 13;
461     const unsigned int batch = 2;
462     TensorShape post_op_arg0_shape(n, m, batch);
463     TensorInfo post_op_arg_info(post_op_arg0_shape, 1, data_type);
464     auto post_op_arg1_info = post_op_arg_info.clone();
465 
466     // Unsupported sequence of post ops
467     experimental::PostOpList<ITensorInfo*> post_ops{};
468     post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>(
469         &post_op_arg_info,
470         1,
471         ConvertPolicy::SATURATE);
472     post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>(
473         post_op_arg1_info.get(),
474         0,
475         ConvertPolicy::SATURATE);
476 
477     ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
478 }
TEST_CASE(OutputWidened,framework::DatasetMode::ALL)479 TEST_CASE(OutputWidened, framework::DatasetMode::ALL)
480 {
481     // Invalid broadcast: post op tensors "widen" the output tensor
482     const auto data_type = DataType::F32;
483     const unsigned int m = 17;
484     const unsigned int n = 1;
485     const unsigned int k = 13;
486     const unsigned int batch = 2;
487     TensorShape post_op_arg_shape(n + 4, m, batch); // output's X dimension (n) is "widened", which is not allowed
488     TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
489     experimental::PostOpList<ITensorInfo*> post_ops{};
490     post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &post_op_arg_info, 0, ConvertPolicy::SATURATE);
491 
492     ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
493 }
TEST_CASE(BroadcastInXDimOnly,framework::DatasetMode::ALL)494 TEST_CASE(BroadcastInXDimOnly, framework::DatasetMode::ALL)
495 {
496     // Invalid broadcast: post op tensors broadcast in the first dimension (X) only
497     const auto data_type = DataType::F32;
498     const unsigned int m = 22;
499     const unsigned int n = 16;
500     const unsigned int k = 15;
501     const unsigned int batch = 3;
502     TensorShape post_op_arg_shape(1, m, batch);
503     TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
504     experimental::PostOpList<ITensorInfo*> post_ops{};
505     post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &post_op_arg_info, 0, ConvertPolicy::SATURATE);
506 
507     ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == false, framework::LogLevel::ERRORS);
508 }
509 TEST_SUITE_END() // Invalid
TEST_SUITE(Valid)510 TEST_SUITE(Valid)
511 TEST_CASE(EmptyPostOpList, framework::DatasetMode::ALL)
512 {
513     const auto data_type = DataType::F32;
514     const unsigned int m = 22;
515     const unsigned int n = 16;
516     const unsigned int k = 15;
517     const unsigned int batch = 3;
518     experimental::PostOpList<ITensorInfo*> post_ops{};
519 
520     ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
521 }
TEST_CASE(BroadcastInYDimOnly,framework::DatasetMode::ALL)522 TEST_CASE(BroadcastInYDimOnly, framework::DatasetMode::ALL)
523 {
524     const auto data_type = DataType::F32;
525     const unsigned int m = 22;
526     const unsigned int n = 16;
527     const unsigned int k = 15;
528     const unsigned int batch = 3;
529     TensorShape post_op_arg_shape(n, 1, batch);
530     TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
531     experimental::PostOpList<ITensorInfo*> post_ops{};
532     post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &post_op_arg_info, 0, ConvertPolicy::SATURATE);
533 
534     ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
535 }
TEST_CASE(BroadcastInBothXandYDims,framework::DatasetMode::ALL)536 TEST_CASE(BroadcastInBothXandYDims, framework::DatasetMode::ALL)
537 {
538     const auto data_type = DataType::F32;
539     const unsigned int m = 22;
540     const unsigned int n = 16;
541     const unsigned int k = 15;
542     const unsigned int batch = 3;
543     TensorShape post_op_arg_shape(1, 1, batch);
544     TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
545     experimental::PostOpList<ITensorInfo*> post_ops{};
546     post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &post_op_arg_info, 0, ConvertPolicy::SATURATE);
547 
548     ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
549 }
TEST_CASE(BroadcastInAllDims,framework::DatasetMode::ALL)550 TEST_CASE(BroadcastInAllDims, framework::DatasetMode::ALL)
551 {
552     const auto data_type = DataType::F32;
553     const unsigned int m = 22;
554     const unsigned int n = 16;
555     const unsigned int k = 15;
556     const unsigned int batch = 3;
557     TensorShape post_op_arg_shape(1, 1, 1);
558     TensorInfo post_op_arg_info(post_op_arg_shape, 1, data_type);
559     experimental::PostOpList<ITensorInfo*> post_ops{};
560     post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo*>>( &post_op_arg_info, 0, ConvertPolicy::SATURATE);
561 
562     ARM_COMPUTE_EXPECT(is_post_op_list_valid(m, n, k, batch, data_type, post_ops) == true, framework::LogLevel::ERRORS);
563 }
564 TEST_SUITE_END() // Valid
TEST_SUITE_END()565 TEST_SUITE_END() // ValidateFusedPostOps
566 TEST_SUITE(Float)
567 TEST_SUITE(FP32)
568 
569 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL,
570                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
571                                                                    m_values,
572                                                                    n_values),
573                                                                    k_values),
574                                                                    b_values),
575                                                                    m0_values_precommit),
576                                                                    n0_values_precommit),
577                                                                    k0_values_precommit),
578                                                                    v0_values_precommit),
579                                                                    h0_values_precommit),
580                                                                    i_values_lhs),
581                                                                    i_values_rhs),
582                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
583                                                                    framework::dataset::make("DataType", DataType::F32)),
584                                                                    a_values_precommit),
585                                                                    beta_values_precommit),
586                                                                    broadcast_bias_values),
587                                                                    lhs_transpose_values),
588                                                                    act_values))
589 {
590     // Validate output
591     if(validate_result)
592     {
593         validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
594     }
595     else
596     {
597         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
598         framework::ARM_COMPUTE_PRINT_INFO();
599     }
600 }
601 
602 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::DISABLED,
603                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
604                                                                    m_values,
605                                                                    n_values),
606                                                                    k_values),
607                                                                    b_values),
608                                                                    m0_values_nightly),
609                                                                    n0_values_nightly),
610                                                                    k0_values_nightly),
611                                                                    v0_values_nightly),
612                                                                    h0_values_nightly),
613                                                                    i_values_lhs),
614                                                                    i_values_rhs),
615                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
616                                                                    framework::dataset::make("DataType", DataType::F32)),
617                                                                    a_values_nightly),
618                                                                    beta_values_nightly),
619                                                                    broadcast_bias_values),
620                                                                    lhs_transpose_values),
621                                                                    act_values))
622 {
623     // Validate output
624     if(validate_result)
625     {
626         validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
627     }
628     else
629     {
630         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
631         framework::ARM_COMPUTE_PRINT_INFO();
632     }
633 }
634 
635 FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::ALL,
636                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
637                                                                    m_w_values,
638                                                                    m_h_values),
639                                                                    n_values),
640                                                                    k_values),
641                                                                    b_values),
642                                                                    m0_values_precommit),
643                                                                    n0_values_precommit),
644                                                                    k0_values_precommit),
645                                                                    v0_values_precommit),
646                                                                    h0_values_precommit),
647                                                                    i_values_lhs),
648                                                                    i_values_rhs),
649                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
650                                                                    framework::dataset::make("DataType", DataType::F32)),
651                                                                    a_values_precommit),
652                                                                    beta_values_precommit),
653                                                                    lhs_transpose_values),
654                                                                    act_values))
655 {
656     // Validate output
657     if(validate_result)
658     {
659         validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
660     }
661     else
662     {
663         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
664         framework::ARM_COMPUTE_PRINT_INFO();
665     }
666 }
667 
668 FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::DISABLED,
669                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
670                                                                    m_w_values,
671                                                                    m_h_values),
672                                                                    n_values),
673                                                                    k_values),
674                                                                    b_values),
675                                                                    m0_values_nightly),
676                                                                    n0_values_nightly),
677                                                                    k0_values_nightly),
678                                                                    v0_values_nightly),
679                                                                    h0_values_nightly),
680                                                                    i_values_lhs),
681                                                                    i_values_rhs),
682                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
683                                                                    framework::dataset::make("DataType", DataType::F32)),
684                                                                    a_values_nightly),
685                                                                    beta_values_nightly),
686                                                                    lhs_transpose_values),
687                                                                    act_values))
688 {
689     // Validate output
690     if(validate_result)
691     {
692         validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
693     }
694     else
695     {
696         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
697         framework::ARM_COMPUTE_PRINT_INFO();
698     }
699 }
700 TEST_SUITE(FusedPostOps)
701 
702 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedWithPostOpsFixture<float>, framework::DatasetMode::ALL,
703                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
704                                                                    m_values,
705                                                                    n_values),
706                                                                    k_values),
707                                                                    b_values),
708                                                                    m0_values_precommit),
709                                                                    n0_values_precommit),
710                                                                    k0_values_precommit),
711                                                                    v0_values_precommit),
712                                                                    h0_values_precommit),
713                                                                    framework::dataset::make("interleave_lhs", { false })),
714                                                                    framework::dataset::make("interleave_rhs", { false })),
715                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
716                                                                    framework::dataset::make("DataType", DataType::F32)),
717                                                                    a_values_precommit),
718                                                                    beta_values_precommit),
719                                                                    framework::dataset::make("broadcast_bias", { true } )),
720                                                                    lhs_transpose_values),
721                                                                    act_values),
722                                                                    post_op_lists)
723                                                                    )
724 {
725     // Validate output
726     if(validate_result)
727     {
728         validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
729     }
730     else
731     {
732         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
733         framework::ARM_COMPUTE_PRINT_INFO();
734     }
735 }
736 
737 TEST_SUITE_END() //  FusedPostOps
738 
TEST_SUITE(ExportToCLImage)739 TEST_SUITE(ExportToCLImage)
740 DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
741                framework::dataset::make("Input0Info", { TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),  // OK or incorrect if cl_khr_image2d_from_buffer not supported
742                                                         TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),  // OK or incorrect if cl_khr_image2d_from_buffer not supported
743                                                         TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),  // OK or incorrect if cl_khr_image2d_from_buffer not supported
744                                                         TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),  // Incorrect k0
745                                                         TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),  // Incorrect n0
746 
747                                                       }),
748                framework::dataset::make("Input1Info",{ TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),
749                                                        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),
750                                                        TensorInfo(TensorShape(512U, 8U, 2U), 1, DataType::F32),
751                                                        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F32),
752                                                        TensorInfo(TensorShape(128U, 32U, 2U), 1, DataType::F32),
753 
754                       })),
755                framework::dataset::make("Input2Info", { TensorInfo(TensorShape(64U), 1, DataType::F32),
756                                                         TensorInfo(TensorShape(64U), 1, DataType::F32),
757                                                         TensorInfo(TensorShape(64U), 1, DataType::F32),
758                                                         TensorInfo(TensorShape(64U), 1, DataType::F32),
759                                                         TensorInfo(TensorShape(64U), 1, DataType::F32),
760 
761                                                       })),
762                framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
763                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
764                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
765                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
766                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
767                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F32),
768 
769                            })),
770                framework::dataset::make("LHSMInfo",{
771                                                           GEMMLHSMatrixInfo(4, 4, 1, false, true),
772                                                           GEMMLHSMatrixInfo(4, 8, 1, false, true),
773                                                           GEMMLHSMatrixInfo(4, 4, 1, false, true),
774                                                           GEMMLHSMatrixInfo(4, 2, 1, false, false),
775                                                           GEMMLHSMatrixInfo(4, 4, 1, false, false),
776 
777                                 })),
778                framework::dataset::make("RHSMInfo",{
779                                                           GEMMRHSMatrixInfo(4, 4, 1, true, true, true),
780                                                           GEMMRHSMatrixInfo(4, 8, 1, true, true, true),
781                                                           GEMMRHSMatrixInfo(8, 4, 1, true, true, true),
782                                                           GEMMRHSMatrixInfo(4, 2, 1, true, false, true),
783                                                           GEMMRHSMatrixInfo(2, 4, 1, true, false, true),
784                            })),
785                framework::dataset::make("GEMMInfo",{GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
786                                                                     64 /**<N Number of RHS columns*/,
787                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
788                                                              false /**< reinterpret the input as 3D */,
789                                                              true  /**< Flag used to broadcast the bias addition */,
790                                                              false /**< wider accumm */,
791                                                              false /**< has pad y */,
792                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
793                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
794                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
795                                                              GEMMLHSMatrixInfo(),
796                                                              GEMMRHSMatrixInfo(),
797                                                              0  /**< Offset to be added to each element of the matrix A */,
798                                                              0 /**< Offset to be added to each element of the matrix B */),
799                                                     GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
800                                                                     64 /**<N Number of RHS columns*/,
801                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
802                                                              false /**< reinterpret the input as 3D */,
803                                                              true  /**< Flag used to broadcast the bias addition */,
804                                                              false /**< wider accumm */,
805                                                              false /**< has pad y */,
806                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
807                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
808                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
809                                                              GEMMLHSMatrixInfo(),
810                                                              GEMMRHSMatrixInfo(),
811                                                              0  /**< Offset to be added to each element of the matrix A */,
812                                                              0 /**< Offset to be added to each element of the matrix B */),
813                                                     GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
814                                                                     64 /**<N Number of RHS columns*/,
815                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
816                                                              false /**< reinterpret the input as 3D */,
817                                                              true  /**< Flag used to broadcast the bias addition */,
818                                                              false /**< wider accumm */,
819                                                              false /**< has pad y */,
820                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
821                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
822                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
823                                                              GEMMLHSMatrixInfo(),
824                                                              GEMMRHSMatrixInfo(),
825                                                              0  /**< Offset to be added to each element of the matrix A */,
826                                                              0 /**< Offset to be added to each element of the matrix B */),
827 
828                                                     GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
829                                                                     64 /**<N Number of RHS columns*/,
830                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
831                                                              false /**< reinterpret the input as 3D */,
832                                                              true  /**< Flag used to broadcast the bias addition */,
833                                                              false /**< wider accumm */,
834                                                              false /**< has pad y */,
835                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
836                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
837                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
838                                                              GEMMLHSMatrixInfo(),
839                                                              GEMMRHSMatrixInfo(),
840                                                              0  /**< Offset to be added to each element of the matrix A */,
841                                                              0 /**< Offset to be added to each element of the matrix B */),
842                                                     GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
843                                                                     64 /**<N Number of RHS columns*/,
844                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
845                                                              false /**< reinterpret the input as 3D */,
846                                                              true  /**< Flag used to broadcast the bias addition */,
847                                                              false /**< wider accumm */,
848                                                              false /**< has pad y */,
849                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
850                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
851                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
852                                                              GEMMLHSMatrixInfo(),
853                                                              GEMMRHSMatrixInfo(),
854                                                              0  /**< Offset to be added to each element of the matrix A */,
855                                                              0 /**< Offset to be added to each element of the matrix B */)
856                                                     })),
857                framework::dataset::make("Expected", { true,
858                                                       true,
859                                                       true,
860                                                       false,
861                                                       true})),
862                     input0_info ,input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, expected)
863 {
864    ARM_COMPUTE_EXPECT(bool(ClGemmMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
865                                                           &input1_info.clone()->set_is_resizable(true),
866                                                           &input2_info.clone()->set_is_resizable(true),
867                                                           &output_info.clone()->set_is_resizable(true),1.f,1.f,
868                                                           lhs_info,
869                                                           rhs_info,
870                                                           gemm_info)) == (expected && image2d_from_buffer_supported(CLKernelLibrary::get().get_device())), framework::LogLevel::ERRORS);
871 }
872 
873 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL,
874                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
875                                                                    m_values,
876                                                                    n_values),
877                                                                    k_values),
878                                                                    b_values),
879                                                                    m0_values_precommit),
880                                                                    n0_values_precommit),
881                                                                    k0_values_precommit),
882                                                                    v0_values_precommit),
883                                                                    h0_values_precommit),
884                                                                    i_values_lhs),
885                                                                    i_values_rhs),
886                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
887                                                                    framework::dataset::make("DataType", DataType::F32)),
888                                                                    a_values_precommit),
889                                                                    beta_values_precommit),
890                                                                    broadcast_bias_values),
891                                                                    lhs_transpose_values),
892                                                                    act_values))
893 {
894      // Validate output only if validate() is successful
895     if(validate_result)
896     {
897         validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
898     }
899     else
900     {
901         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
902         framework::ARM_COMPUTE_PRINT_INFO();
903     }
904 
905 }
906 
907 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::NIGHTLY,
908                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
909                                                                    m_values,
910                                                                    n_values),
911                                                                    k_values),
912                                                                    b_values),
913                                                                    m0_values_nightly),
914                                                                    n0_export_to_cl_image_values_nightly),
915                                                                    k0_export_to_cl_image_values_nightly),
916                                                                    v0_values_nightly),
917                                                                    h0_values_nightly),
918                                                                    i_values_lhs),
919                                                                    i_values_rhs),
920                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
921                                                                    framework::dataset::make("DataType", DataType::F32)),
922                                                                    a_values_nightly),
923                                                                    beta_values_nightly),
924                                                                    broadcast_bias_values),
925                                                                    lhs_transpose_values),
926                                                                    act_values))
927 {
928      // Validate output only if validate() is successful
929     if(validate_result)
930     {
931         validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
932     }
933     else
934     {
935         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
936         framework::ARM_COMPUTE_PRINT_INFO();
937     }
938 }
939 
940 FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::ALL,
941                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
942                                                                    m_w_values,
943                                                                    m_h_values),
944                                                                    n_values),
945                                                                    k_values),
946                                                                    b_values),
947                                                                    m0_values_precommit),
948                                                                    n0_values_precommit),
949                                                                    k0_values_precommit),
950                                                                    v0_values_precommit),
951                                                                    h0_values_precommit),
952                                                                    i_values_lhs),
953                                                                    i_values_rhs),
954                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
955                                                                    framework::dataset::make("DataType", DataType::F32)),
956                                                                    a_values_precommit),
957                                                                    beta_values_precommit),
958                                                                    lhs_transpose_values),
959                                                                    act_values))
960 {
961      // Validate output only if validate() is successful
962     if(validate_result)
963     {
964         validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
965     }
966     else
967     {
968         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
969         framework::ARM_COMPUTE_PRINT_INFO();
970     }
971 }
972 
973 FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::NIGHTLY,
974                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
975                                                                    m_w_values,
976                                                                    m_h_values),
977                                                                    n_values),
978                                                                    k_values),
979                                                                    b_values),
980                                                                    m0_values_nightly),
981                                                                    n0_export_to_cl_image_values_nightly),
982                                                                    k0_export_to_cl_image_values_nightly),
983                                                                    v0_values_nightly),
984                                                                    h0_values_nightly),
985                                                                    i_values_lhs),
986                                                                    i_values_rhs),
987                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
988                                                                    framework::dataset::make("DataType", DataType::F32)),
989                                                                    a_values_nightly),
990                                                                    beta_values_nightly),
991                                                                    lhs_transpose_values),
992                                                                    act_values))
993 {
994     // Validate output only if validate() is successful
995     if(validate_result)
996     {
997         validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
998     }
999     else
1000     {
1001         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1002         framework::ARM_COMPUTE_PRINT_INFO();
1003     }
1004 }
1005 TEST_SUITE(FusedPostOps)
1006 
1007 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedWithPostOpsFixture<float>, framework::DatasetMode::ALL,
1008                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1009                                                                    m_values,
1010                                                                    n_values),
1011                                                                    k_values),
1012                                                                    b_values),
1013                                                                    m0_values_precommit),
1014                                                                    n0_values_precommit),
1015                                                                    k0_values_precommit),
1016                                                                    v0_values_precommit),
1017                                                                    h0_values_precommit),
1018                                                                    framework::dataset::make("interleave_lhs", { false })),
1019                                                                    framework::dataset::make("interleave_rhs", { false })),
1020                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
1021                                                                    framework::dataset::make("DataType", DataType::F32)),
1022                                                                    a_values_precommit),
1023                                                                    beta_values_precommit),
1024                                                                    framework::dataset::make("broadcast_bias", { true } )),
1025                                                                    lhs_transpose_values),
1026                                                                    act_values),
1027                                                                    post_op_lists)
1028                                                                    )
1029 {
1030     // Validate output only if validate() is successful
1031     if(validate_result)
1032     {
1033         validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
1034     }
1035     else
1036     {
1037         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1038         framework::ARM_COMPUTE_PRINT_INFO();
1039     }
1040 }
1041 
1042 TEST_SUITE_END() //  FusedPostOps
1043 
TEST_SUITE_END()1044 TEST_SUITE_END() // ExportToCLImage
1045 TEST_SUITE_END() // FP32
1046 
1047 TEST_SUITE(FP16)
1048 
1049 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::ALL,
1050                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1051                                                                    m_values,
1052                                                                    n_values),
1053                                                                    k_values),
1054                                                                    b_values),
1055                                                                    m0_values_precommit),
1056                                                                    n0_values_precommit),
1057                                                                    k0_values_precommit),
1058                                                                    v0_values_precommit),
1059                                                                    h0_values_precommit),
1060                                                                    i_values_lhs),
1061                                                                    i_values_rhs),
1062                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
1063                                                                    framework::dataset::make("DataType", DataType::F16)),
1064                                                                    a_values_precommit),
1065                                                                    beta_values_precommit),
1066                                                                    broadcast_bias_values),
1067                                                                    lhs_transpose_values),
1068                                                                    act_values))
1069 {
1070     // Validate output
1071     if(validate_result)
1072     {
1073         validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1074     }
1075     else
1076     {
1077         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1078         framework::ARM_COMPUTE_PRINT_INFO();
1079     }
1080 }
1081 
1082 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::DISABLED,
1083                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1084                                                                    m_values,
1085                                                                    n_values),
1086                                                                    k_values),
1087                                                                    b_values),
1088                                                                    m0_values_nightly),
1089                                                                    n0_values_nightly),
1090                                                                    k0_values_nightly),
1091                                                                    v0_values_nightly),
1092                                                                    h0_values_nightly),
1093                                                                    i_values_lhs),
1094                                                                    i_values_rhs),
1095                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
1096                                                                    framework::dataset::make("DataType", DataType::F16)),
1097                                                                    a_values_nightly),
1098                                                                    beta_values_nightly),
1099                                                                    broadcast_bias_values),
1100                                                                    lhs_transpose_values),
1101                                                                    act_values))
1102 {
1103     // Validate output
1104     if(validate_result)
1105     {
1106         validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1107     }
1108     else
1109     {
1110         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1111         framework::ARM_COMPUTE_PRINT_INFO();
1112     }
1113 }
1114 
1115 FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::ALL,
1116                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1117                                                                    m_w_values,
1118                                                                    m_h_values),
1119                                                                    n_values),
1120                                                                    k_values),
1121                                                                    b_values),
1122                                                                    m0_values_precommit),
1123                                                                    n0_values_precommit),
1124                                                                    k0_values_precommit),
1125                                                                    v0_values_precommit),
1126                                                                    h0_values_precommit),
1127                                                                    i_values_lhs),
1128                                                                    i_values_rhs),
1129                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
1130                                                                    framework::dataset::make("DataType", DataType::F16)),
1131                                                                    a_values_precommit),
1132                                                                    beta_values_precommit),
1133                                                                    lhs_transpose_values),
1134                                                                    act_values))
1135 {
1136     // Validate output
1137     if(validate_result)
1138     {
1139         validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1140     }
1141     else
1142     {
1143         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1144         framework::ARM_COMPUTE_PRINT_INFO();
1145     }
1146 }
1147 
1148 FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::DISABLED,
1149                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1150                                                                    m_w_values,
1151                                                                    m_h_values),
1152                                                                    n_values),
1153                                                                    k_values),
1154                                                                    b_values),
1155                                                                    m0_values_nightly),
1156                                                                    n0_values_nightly),
1157                                                                    k0_values_nightly),
1158                                                                    v0_values_nightly),
1159                                                                    h0_values_nightly),
1160                                                                    i_values_lhs),
1161                                                                    i_values_rhs),
1162                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
1163                                                                    framework::dataset::make("DataType", DataType::F16)),
1164                                                                    a_values_nightly),
1165                                                                    beta_values_nightly),
1166                                                                    lhs_transpose_values),
1167                                                                    act_values))
1168 {
1169     // Validate output
1170     if(validate_result)
1171     {
1172         validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1173     }
1174     else
1175     {
1176         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1177         framework::ARM_COMPUTE_PRINT_INFO();
1178     }
1179 }
1180 
1181 TEST_SUITE(FusedPostOps)
1182 
1183 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedWithPostOpsFixture<half>, framework::DatasetMode::ALL,
1184                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1185                                                                    m_values,
1186                                                                    n_values),
1187                                                                    k_values),
1188                                                                    b_values),
1189                                                                    m0_values_precommit),
1190                                                                    n0_values_precommit),
1191                                                                    k0_values_precommit),
1192                                                                    v0_values_precommit),
1193                                                                    h0_values_precommit),
1194                                                                    framework::dataset::make("interleave_lhs", { false })),
1195                                                                    framework::dataset::make("interleave_rhs", { false })),
1196                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
1197                                                                    framework::dataset::make("DataType", DataType::F16)),
1198                                                                    a_values_precommit),
1199                                                                    beta_values_precommit),
1200                                                                    framework::dataset::make("broadcast_bias", { true } )),
1201                                                                    lhs_transpose_values),
1202                                                                    act_values),
1203                                                                    post_op_lists)
1204                                                                    )
1205 {
1206     // Validate output
1207     if(validate_result)
1208     {
1209         validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1210     }
1211     else
1212     {
1213         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1214         framework::ARM_COMPUTE_PRINT_INFO();
1215     }
1216 }
1217 
1218 TEST_SUITE_END() //  FusedPostOps
1219 
// FP16 export-to-CL-image suite; opened exactly once (the extracted line carried a duplicated
// TEST_SUITE(ExportToCLImage), which would open a spurious nested suite).
TEST_SUITE(ExportToCLImage)
1221 DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
1222                framework::dataset::make("Input0Info", { TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),  // OK or incorrect if cl_khr_image2d_from_buffer not supported
1223                                                         TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),  // OK or incorrect if cl_khr_image2d_from_buffer not supported
1224                                                         TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),  // OK or incorrect if cl_khr_image2d_from_buffer not supported
1225                                                         TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),  // Incorrect k0
1226                                                         TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),  // Incorrect n0
1227 
1228                                                       }),
1229                framework::dataset::make("Input1Info",{ TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
1230                                                        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
1231                                                        TensorInfo(TensorShape(512U, 8U, 2U), 1, DataType::F16),
1232                                                        TensorInfo(TensorShape(256U, 16U, 2U), 1, DataType::F16),
1233                                                        TensorInfo(TensorShape(128U, 32U, 2U), 1, DataType::F16),
1234 
1235                       })),
1236                framework::dataset::make("Input2Info", { TensorInfo(TensorShape(64U), 1, DataType::F16),
1237                                                         TensorInfo(TensorShape(64U), 1, DataType::F16),
1238                                                         TensorInfo(TensorShape(64U), 1, DataType::F16),
1239                                                         TensorInfo(TensorShape(64U), 1, DataType::F16),
1240                                                         TensorInfo(TensorShape(64U), 1, DataType::F16),
1241 
1242                                                       })),
1243                framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
1244                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
1245                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
1246                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
1247                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
1248                                                        TensorInfo(TensorShape(64U, 64U, 2U), 1, DataType::F16),
1249 
1250                            })),
1251                framework::dataset::make("LHSMInfo",{
1252                                                           GEMMLHSMatrixInfo(4, 4, 1, false, true),
1253                                                           GEMMLHSMatrixInfo(4, 8, 1, false, true),
1254                                                           GEMMLHSMatrixInfo(4, 4, 1, false, true),
1255                                                           GEMMLHSMatrixInfo(4, 2, 1, false, false),
1256                                                           GEMMLHSMatrixInfo(4, 4, 1, false, false),
1257 
1258                                 })),
1259                framework::dataset::make("RHSMInfo",{
1260                                                           GEMMRHSMatrixInfo(4, 4, 1, true, true, true),
1261                                                           GEMMRHSMatrixInfo(4, 8, 1, true, true, true),
1262                                                           GEMMRHSMatrixInfo(8, 4, 1, true, true, true),
1263                                                           GEMMRHSMatrixInfo(4, 2, 1, true, false, true),
1264                                                           GEMMRHSMatrixInfo(2, 4, 1, true, false, true),
1265                            })),
1266                framework::dataset::make("GEMMInfo",{GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
1267                                                                     64 /**<N Number of RHS columns*/,
1268                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
1269                                                              false /**< reinterpret the input as 3D */,
1270                                                              true  /**< Flag used to broadcast the bias addition */,
1271                                                              false /**< wider accumm */,
1272                                                              false /**< has pad y */,
1273                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
1274                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
1275                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
1276                                                              GEMMLHSMatrixInfo(),
1277                                                              GEMMRHSMatrixInfo(),
1278                                                              0  /**< Offset to be added to each element of the matrix A */,
1279                                                              0 /**< Offset to be added to each element of the matrix B */),
1280                                                     GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
1281                                                                     64 /**<N Number of RHS columns*/,
1282                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
1283                                                              false /**< reinterpret the input as 3D */,
1284                                                              true  /**< Flag used to broadcast the bias addition */,
1285                                                              false /**< wider accumm */,
1286                                                              false /**< has pad y */,
1287                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
1288                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
1289                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
1290                                                              GEMMLHSMatrixInfo(),
1291                                                              GEMMRHSMatrixInfo(),
1292                                                              0  /**< Offset to be added to each element of the matrix A */,
1293                                                              0 /**< Offset to be added to each element of the matrix B */),
1294                                                     GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
1295                                                                     64 /**<N Number of RHS columns*/,
1296                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
1297                                                              false /**< reinterpret the input as 3D */,
1298                                                              true  /**< Flag used to broadcast the bias addition */,
1299                                                              false /**< wider accumm */,
1300                                                              false /**< has pad y */,
1301                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
1302                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
1303                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
1304                                                              GEMMLHSMatrixInfo(),
1305                                                              GEMMRHSMatrixInfo(),
1306                                                              0  /**< Offset to be added to each element of the matrix A */,
1307                                                              0 /**< Offset to be added to each element of the matrix B */),
1308 
1309                                                     GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
1310                                                                     64 /**<N Number of RHS columns*/,
1311                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
1312                                                              false /**< reinterpret the input as 3D */,
1313                                                              true  /**< Flag used to broadcast the bias addition */,
1314                                                              false /**< wider accumm */,
1315                                                              false /**< has pad y */,
1316                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
1317                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
1318                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
1319                                                              GEMMLHSMatrixInfo(),
1320                                                              GEMMRHSMatrixInfo(),
1321                                                              0  /**< Offset to be added to each element of the matrix A */,
1322                                                              0 /**< Offset to be added to each element of the matrix B */),
1323                                                     GEMMKernelInfo( 64 /**<M Number of LHS rows*/,
1324                                                                     64 /**<N Number of RHS columns*/,
1325                                                                     64 /**<K Number of LHS columns or RHS rows */, 0 /**< Depth of the output tensor in case is reinterpreted as 3D */,
1326                                                              false /**< reinterpret the input as 3D */,
1327                                                              true  /**< Flag used to broadcast the bias addition */,
1328                                                              false /**< wider accumm */,
1329                                                              false /**< has pad y */,
1330                                                            ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
1331                                                              1   /**< Multiplication factor for the width of the 1xW transposed block */,
1332                                                              1   /**< Multiplication factor for the height of the 4x4 interleaved block */,
1333                                                              GEMMLHSMatrixInfo(),
1334                                                              GEMMRHSMatrixInfo(),
1335                                                              0  /**< Offset to be added to each element of the matrix A */,
1336                                                              0 /**< Offset to be added to each element of the matrix B */)
1337                                                     })),
1338                framework::dataset::make("Expected", { true,
1339                                                       true,
1340                                                       true,
1341                                                       false,
1342                                                       true})),
1343                     input0_info ,input1_info, input2_info, output_info, lhs_info, rhs_info, gemm_info, expected)
1344 {
1345    ARM_COMPUTE_EXPECT(bool(ClGemmMatrixMultiplyReshapedKernel::validate(&input0_info.clone()->set_is_resizable(true),
1346                                                           &input1_info.clone()->set_is_resizable(true),
1347                                                           &input2_info.clone()->set_is_resizable(true),
1348                                                           &output_info.clone()->set_is_resizable(true),1.f,1.f,
1349                                                           lhs_info,
1350                                                           rhs_info,
1351                                                           gemm_info)) == (expected && image2d_from_buffer_supported(CLKernelLibrary::get().get_device())), framework::LogLevel::ERRORS);
1352 }
1353 
1354 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::ALL,
1355                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1356                                                                    m_values,
1357                                                                    n_values),
1358                                                                    k_values),
1359                                                                    b_values),
1360                                                                    m0_values_precommit),
1361                                                                    n0_values_precommit),
1362                                                                    k0_values_precommit),
1363                                                                    v0_values_precommit),
1364                                                                    h0_values_precommit),
1365                                                                    i_values_lhs),
1366                                                                    i_values_rhs),
1367                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
1368                                                                    framework::dataset::make("DataType", DataType::F16)),
1369                                                                    a_values_precommit),
1370                                                                    beta_values_precommit),
1371                                                                    broadcast_bias_values),
1372                                                                    lhs_transpose_values),
1373                                                                    act_values))
1374 {
1375     // Validate output only if validate() is successful
1376     if(validate_result)
1377     {
1378         validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1379     }
1380     else
1381     {
1382         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1383         framework::ARM_COMPUTE_PRINT_INFO();
1384     }
1385 
1386 }
1387 
1388 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::NIGHTLY,
1389                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1390                                                                    m_values,
1391                                                                    n_values),
1392                                                                    k_values),
1393                                                                    b_values),
1394                                                                    m0_values_nightly),
1395                                                                    n0_export_to_cl_image_values_nightly),
1396                                                                    k0_export_to_cl_image_values_nightly),
1397                                                                    v0_values_nightly),
1398                                                                    h0_values_nightly),
1399                                                                    i_values_lhs),
1400                                                                    i_values_rhs),
1401                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
1402                                                                    framework::dataset::make("DataType", DataType::F16)),
1403                                                                    a_values_nightly),
1404                                                                    beta_values_nightly),
1405                                                                    broadcast_bias_values),
1406                                                                    lhs_transpose_values),
1407                                                                    act_values))
1408 {
1409     // Validate output only if validate() is successful
1410     if(validate_result)
1411     {
1412         validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1413     }
1414     else
1415     {
1416         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1417         framework::ARM_COMPUTE_PRINT_INFO();
1418     }
1419 }
1420 
1421 FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::ALL,
1422                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1423                                                                    m_w_values,
1424                                                                    m_h_values),
1425                                                                    n_values),
1426                                                                    k_values),
1427                                                                    b_values),
1428                                                                    m0_values_precommit),
1429                                                                    n0_values_precommit),
1430                                                                    k0_values_precommit),
1431                                                                    v0_values_precommit),
1432                                                                    h0_values_precommit),
1433                                                                    i_values_lhs),
1434                                                                    i_values_rhs),
1435                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
1436                                                                    framework::dataset::make("DataType", DataType::F16)),
1437                                                                    a_values_precommit),
1438                                                                    beta_values_precommit),
1439                                                                    lhs_transpose_values),
1440                                                                    act_values))
1441 {
1442     // Validate output only if validate() is successful
1443     if(validate_result)
1444     {
1445         validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1446     }
1447     else
1448     {
1449         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1450         framework::ARM_COMPUTE_PRINT_INFO();
1451     }
1452 }
1453 
1454 FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::NIGHTLY,
1455                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1456                                                                    m_w_values,
1457                                                                    m_h_values),
1458                                                                    n_values),
1459                                                                    k_values),
1460                                                                    b_values),
1461                                                                    m0_values_nightly),
1462                                                                    n0_export_to_cl_image_values_nightly),
1463                                                                    k0_export_to_cl_image_values_nightly),
1464                                                                    v0_values_nightly),
1465                                                                    h0_values_nightly),
1466                                                                    i_values_lhs),
1467                                                                    i_values_rhs),
1468                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
1469                                                                    framework::dataset::make("DataType", DataType::F16)),
1470                                                                    a_values_nightly),
1471                                                                    beta_values_nightly),
1472                                                                    lhs_transpose_values),
1473                                                                    act_values))
1474 {
1475     // Validate output only if validate() is successful
1476     if(validate_result)
1477     {
1478         validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1479     }
1480     else
1481     {
1482         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1483         framework::ARM_COMPUTE_PRINT_INFO();
1484     }
1485 }
1486 TEST_SUITE(FusedPostOps)
1487 
1488 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedWithPostOpsFixture<half>, framework::DatasetMode::ALL,
1489                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1490                                                                    m_values,
1491                                                                    n_values),
1492                                                                    k_values),
1493                                                                    b_values),
1494                                                                    m0_values_precommit),
1495                                                                    n0_values_precommit),
1496                                                                    k0_values_precommit),
1497                                                                    v0_values_precommit),
1498                                                                    h0_values_precommit),
1499                                                                    framework::dataset::make("interleave_lhs", { false })),
1500                                                                    framework::dataset::make("interleave_rhs", { false })),
1501                                                                    framework::dataset::make("export_to_cl_image_rhs", true)),
1502                                                                    framework::dataset::make("DataType", DataType::F16)),
1503                                                                    a_values_precommit),
1504                                                                    beta_values_precommit),
1505                                                                    framework::dataset::make("broadcast_bias", { true } )),
1506                                                                    lhs_transpose_values),
1507                                                                    act_values),
1508                                                                    post_op_lists)
1509                                                                    )
1510 {
1511     // Validate output only if validate() is successful
1512     if(validate_result)
1513     {
1514         validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
1515     }
1516     else
1517     {
1518         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1519         framework::ARM_COMPUTE_PRINT_INFO();
1520     }
1521 }
1522 
1523 TEST_SUITE_END() //  FusedPostOps
1524 
TEST_SUITE_END()1525 TEST_SUITE_END() // ExportToCLImage
1526 TEST_SUITE_END() // FP16
1527 
1528 TEST_SUITE(MixedPrecision)
1529 
1530 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::ALL,
1531                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1532                                                                    m_values,
1533                                                                    n_values),
1534                                                                    k_values),
1535                                                                    b_values),
1536                                                                    m0_values_precommit),
1537                                                                    n0_values_precommit),
1538                                                                    k0_values_precommit),
1539                                                                    v0_values_precommit),
1540                                                                    h0_values_precommit),
1541                                                                    i_values_lhs),
1542                                                                    i_values_rhs),
1543                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
1544                                                                    framework::dataset::make("DataType", DataType::F16)),
1545                                                                    a_values_precommit),
1546                                                                    beta_values_precommit),
1547                                                                    broadcast_bias_values),
1548                                                                    lhs_transpose_values),
1549                                                                    act_values))
1550 {
1551     // Validate output
1552     if(validate_result)
1553     {
1554         validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1555     }
1556     else
1557     {
1558         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1559         framework::ARM_COMPUTE_PRINT_INFO();
1560     }
1561 }
1562 
1563 FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::DISABLED,
1564                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1565                                                                    m_values,
1566                                                                    n_values),
1567                                                                    k_values),
1568                                                                    b_values),
1569                                                                    m0_values_nightly),
1570                                                                    n0_values_nightly),
1571                                                                    k0_values_nightly),
1572                                                                    v0_values_nightly),
1573                                                                    h0_values_nightly),
1574                                                                    i_values_lhs),
1575                                                                    i_values_rhs),
1576                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
1577                                                                    framework::dataset::make("DataType", DataType::F16)),
1578                                                                    a_values_nightly),
1579                                                                    beta_values_nightly),
1580                                                                    broadcast_bias_values),
1581                                                                    lhs_transpose_values),
1582                                                                    act_values))
1583 {
1584     // Validate output
1585     if(validate_result)
1586     {
1587         validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1588     }
1589     else
1590     {
1591         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1592         framework::ARM_COMPUTE_PRINT_INFO();
1593     }
1594 }
1595 
1596 FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::ALL,
1597                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1598                                                                    m_w_values,
1599                                                                    m_h_values),
1600                                                                    n_values),
1601                                                                    k_values),
1602                                                                    b_values),
1603                                                                    m0_values_precommit),
1604                                                                    n0_values_precommit),
1605                                                                    k0_values_precommit),
1606                                                                    v0_values_precommit),
1607                                                                    h0_values_precommit),
1608                                                                    i_values_lhs),
1609                                                                    i_values_rhs),
1610                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
1611                                                                    framework::dataset::make("DataType", DataType::F16)),
1612                                                                    a_values_precommit),
1613                                                                    beta_values_precommit),
1614                                                                    lhs_transpose_values),
1615                                                                    act_values))
1616 {
1617     // Validate output
1618     if(validate_result)
1619     {
1620         validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1621     }
1622     else
1623     {
1624         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1625         framework::ARM_COMPUTE_PRINT_INFO();
1626     }
1627 }
1628 
1629 FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::DISABLED,
1630                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1631                                                                    m_w_values,
1632                                                                    m_h_values),
1633                                                                    n_values),
1634                                                                    k_values),
1635                                                                    b_values),
1636                                                                    m0_values_nightly),
1637                                                                    n0_values_nightly),
1638                                                                    k0_values_nightly),
1639                                                                    v0_values_nightly),
1640                                                                    h0_values_nightly),
1641                                                                    i_values_lhs),
1642                                                                    i_values_rhs),
1643                                                                    framework::dataset::make("export_to_cl_image_rhs", false)),
1644                                                                    framework::dataset::make("DataType", DataType::F16)),
1645                                                                    a_values_nightly),
1646                                                                    beta_values_nightly),
1647                                                                    lhs_transpose_values),
1648                                                                    act_values))
1649 {
1650     // Validate output
1651     if(validate_result)
1652     {
1653         validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1654     }
1655     else
1656     {
1657         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1658         framework::ARM_COMPUTE_PRINT_INFO();
1659     }
1660 }
1661 
1662 TEST_SUITE(FusedPostOps)
1663 
1664 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedMixedPrecisionWithPostOpsFixture<half>, framework::DatasetMode::ALL,
1665                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
1666                                                                    m_values,
1667                                                                    n_values),
1668                                                                    k_values),
1669                                                                    b_values),
1670                                                                    m0_values_precommit),
1671                                                                    n0_values_precommit),
1672                                                                    k0_values_precommit),
1673                                                                    v0_values_precommit),
1674                                                                    h0_values_precommit),
1675                                                                    framework::dataset::make("interleave_lhs", { false })),
1676                                                                    framework::dataset::make("interleave_rhs", { false })),
1677                                                                    framework::dataset::make("export_to_cl_image_rhs", { true, false })),
1678                                                                    framework::dataset::make("DataType", DataType::F16)),
1679                                                                    a_values_precommit),
1680                                                                    beta_values_precommit),
1681                                                                    framework::dataset::make("broadcast_bias", { true } )),
1682                                                                    lhs_transpose_values),
1683                                                                    act_values),
1684                                                                    post_op_lists)
1685                                                                    )
1686 {
1687     // Validate output
1688     if(validate_result)
1689     {
1690         validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
1691     }
1692     else
1693     {
1694         ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
1695         framework::ARM_COMPUTE_PRINT_INFO();
1696     }
1697 }
1698 
1699 TEST_SUITE_END() // FusedPostOps
1700 
1701 TEST_SUITE_END() // MixedPrecision
1702 TEST_SUITE_END() // Float
1703 TEST_SUITE_END() // GEMMMatrixMultiplyReshaped
1704 TEST_SUITE_END() // CL
1705 } // namespace validation
1706 } // namespace test
1707 } // namespace arm_compute
1708