1*c217d954SCole Faust /*
2*c217d954SCole Faust * Copyright (c) 2017-2023 Arm Limited.
3*c217d954SCole Faust *
4*c217d954SCole Faust * SPDX-License-Identifier: MIT
5*c217d954SCole Faust *
6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust *
13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust * copies or substantial portions of the Software.
15*c217d954SCole Faust *
16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust * SOFTWARE.
23*c217d954SCole Faust */
24*c217d954SCole Faust #include "arm_compute/core/Types.h"
25*c217d954SCole Faust #include "arm_compute/runtime/NEON/functions/NEGEMM.h"
26*c217d954SCole Faust #include "arm_compute/runtime/Tensor.h"
27*c217d954SCole Faust #include "arm_compute/runtime/TensorAllocator.h"
28*c217d954SCole Faust #include "src/core/helpers/MemoryHelpers.h"
29*c217d954SCole Faust #include "src/cpu/kernels/CpuGemmInterleave4x4Kernel.h"
30*c217d954SCole Faust #include "src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h"
31*c217d954SCole Faust #include "src/cpu/kernels/CpuGemmTranspose1xWKernel.h"
32*c217d954SCole Faust #include "src/cpu/operators/CpuGemm.h"
33*c217d954SCole Faust #include "tests/NEON/Accessor.h"
34*c217d954SCole Faust #include "tests/NEON/Helper.h"
35*c217d954SCole Faust #include "tests/PaddingCalculator.h"
36*c217d954SCole Faust #include "tests/datasets/LargeGEMMDataset.h"
37*c217d954SCole Faust #include "tests/datasets/SmallGEMMDataset.h"
38*c217d954SCole Faust #include "tests/datasets/TinyGEMMDataset.h"
39*c217d954SCole Faust #include "tests/framework/Asserts.h"
40*c217d954SCole Faust #include "tests/framework/Macros.h"
41*c217d954SCole Faust #include "tests/framework/datasets/Datasets.h"
42*c217d954SCole Faust #include "tests/validation/Validation.h"
43*c217d954SCole Faust #include "tests/validation/fixtures/GEMMFixture.h"
44*c217d954SCole Faust #include "tests/validation/fixtures/GEMMInterleave4x4Fixture.h"
45*c217d954SCole Faust #include "tests/validation/fixtures/GEMMTranspose1xWFixture.h"
46*c217d954SCole Faust
47*c217d954SCole Faust namespace arm_compute
48*c217d954SCole Faust {
49*c217d954SCole Faust namespace test
50*c217d954SCole Faust {
51*c217d954SCole Faust namespace validation
52*c217d954SCole Faust {
53*c217d954SCole Faust namespace
54*c217d954SCole Faust {
55*c217d954SCole Faust constexpr AbsoluteTolerance<float> tolerance_f(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for FP32 data types */
56*c217d954SCole Faust #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
57*c217d954SCole Faust RelativeTolerance<half_float::half> rel_tolerance_f16(half(0.2)); /**< Relative tolerance value for comparing reference's output against implementation's output for FP16 data types */
58*c217d954SCole Faust const AbsoluteTolerance<float> abs_tolerance_f16(0.2f); /**< Absolute tolerance value for comparing reference's output against implementation's output for FP16 data types */
59*c217d954SCole Faust constexpr float tolerance_num = 0.07f; /**< Tolerance number for FP16 data types */
60*c217d954SCole Faust #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
61*c217d954SCole Faust /** CNN data types */
62*c217d954SCole Faust const auto CNNDataTypes = framework::dataset::make("DataType",
63*c217d954SCole Faust {
64*c217d954SCole Faust #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
65*c217d954SCole Faust DataType::F16,
66*c217d954SCole Faust #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
67*c217d954SCole Faust DataType::F32,
68*c217d954SCole Faust });
69*c217d954SCole Faust
70*c217d954SCole Faust const auto data_interleave = framework::dataset::make("M", 8, 12) * framework::dataset::make("N", 8, 12);
71*c217d954SCole Faust const auto data_transpose = framework::dataset::make("M", 8, 14) * framework::dataset::make("N", 7, 14);
72*c217d954SCole Faust
73*c217d954SCole Faust /** Zero padding test */
74*c217d954SCole Faust template <typename FunctionType>
validate_zero_padding(unsigned int dim0_value,unsigned int dim1_value)75*c217d954SCole Faust bool validate_zero_padding(unsigned int dim0_value, unsigned int dim1_value)
76*c217d954SCole Faust {
77*c217d954SCole Faust const TensorShape in_shape(dim0_value, dim1_value);
78*c217d954SCole Faust TensorInfo in(in_shape, 1, DataType::U32);
79*c217d954SCole Faust TensorInfo dst;
80*c217d954SCole Faust
81*c217d954SCole Faust ARM_COMPUTE_EXPECT(in.is_resizable(), framework::LogLevel::ERRORS);
82*c217d954SCole Faust
83*c217d954SCole Faust // Validate zero-padding
84*c217d954SCole Faust FunctionType func;
85*c217d954SCole Faust
86*c217d954SCole Faust func.configure(&in, &dst);
87*c217d954SCole Faust
88*c217d954SCole Faust return in.padding().empty();
89*c217d954SCole Faust }
90*c217d954SCole Faust
91*c217d954SCole Faust /* Zero padding test for GEMM kernels */
validate_gemm_zero_padding(const TensorShape shape0,const TensorShape shape1)92*c217d954SCole Faust bool validate_gemm_zero_padding(const TensorShape shape0, const TensorShape shape1)
93*c217d954SCole Faust {
94*c217d954SCole Faust // Create tensors
95*c217d954SCole Faust TensorInfo in0(shape0, 1, DataType::F32);
96*c217d954SCole Faust TensorInfo in1(shape1, 1, DataType::F32);
97*c217d954SCole Faust TensorInfo dst;
98*c217d954SCole Faust
99*c217d954SCole Faust // Validate zero-padding
100*c217d954SCole Faust cpu::kernels::CpuGemmMatrixMultiplyKernel gemm;
101*c217d954SCole Faust gemm.configure(&in0, &in1, &dst, 1.0, false);
102*c217d954SCole Faust
103*c217d954SCole Faust return in0.padding().empty() && in1.padding().empty() && dst.padding().empty();
104*c217d954SCole Faust }
105*c217d954SCole Faust } // namespace
106*c217d954SCole Faust
107*c217d954SCole Faust TEST_SUITE(NEON)
TEST_SUITE(GEMM)108*c217d954SCole Faust TEST_SUITE(GEMM)
109*c217d954SCole Faust
110*c217d954SCole Faust /** Test case for memory injection in @ref cpu::CpuGemm.
111*c217d954SCole Faust *
112*c217d954SCole Faust * Configure the operator once and inject memory at run-time in multiple executions.
113*c217d954SCole Faust *
114*c217d954SCole Faust * Checks performed in order:
115*c217d954SCole Faust * - Both runs compute the same output
116*c217d954SCole Faust */
117*c217d954SCole Faust TEST_CASE(MemoryInjection, framework::DatasetMode::ALL)
118*c217d954SCole Faust {
119*c217d954SCole Faust auto gemm = std::make_unique<cpu::CpuGemm>();
120*c217d954SCole Faust const auto lhs_info = TensorInfo(TensorShape(3U, 3U), 1, DataType::F32);
121*c217d954SCole Faust const auto rhs_info = TensorInfo(TensorShape(4U, 3U), 1, DataType::F32);
122*c217d954SCole Faust const auto c_info = TensorInfo(TensorShape(4U, 3U), 1, DataType::F32);
123*c217d954SCole Faust auto dst_info = TensorInfo(TensorShape(4U, 3U), 1, DataType::F32);
124*c217d954SCole Faust const auto gemm_info = GEMMInfo{};
125*c217d954SCole Faust gemm->configure(&lhs_info, &rhs_info, &c_info, &dst_info, 1.f, 1.f, gemm_info);
126*c217d954SCole Faust
127*c217d954SCole Faust // telhs are newly created every call of this lambda function
128*c217d954SCole Faust auto lhs = create_tensor<Tensor>(lhs_info);
129*c217d954SCole Faust auto rhs = create_tensor<Tensor>(rhs_info);
130*c217d954SCole Faust auto c = create_tensor<Tensor>(c_info);
131*c217d954SCole Faust lhs.allocator()->allocate();
132*c217d954SCole Faust rhs.allocator()->allocate();
133*c217d954SCole Faust c.allocator()->allocate();
134*c217d954SCole Faust
135*c217d954SCole Faust ITensorPack run_pack{ { TensorType::ACL_SRC_0, &lhs }, { TensorType::ACL_SRC_1, &rhs }, { TensorType::ACL_SRC_2, &c } };
136*c217d954SCole Faust ITensorPack prep_pack{ { TensorType::ACL_SRC_1, &rhs }, { TensorType::ACL_SRC_2, &c } };
137*c217d954SCole Faust
138*c217d954SCole Faust auto mg = MemoryGroup{};
139*c217d954SCole Faust auto ws = manage_workspace<Tensor>(gemm->workspace(), mg, run_pack, prep_pack);
140*c217d954SCole Faust
141*c217d954SCole Faust auto run_conv = [&]() -> Tensor
142*c217d954SCole Faust {
143*c217d954SCole Faust auto dst = create_tensor<Tensor>(dst_info);
144*c217d954SCole Faust dst.allocator()->allocate();
145*c217d954SCole Faust run_pack.add_tensor(TensorType::ACL_DST, &dst);
146*c217d954SCole Faust
147*c217d954SCole Faust library->fill_tensor_value(Accessor(lhs), 1.f);
148*c217d954SCole Faust library->fill_tensor_value(Accessor(rhs), 2.f);
149*c217d954SCole Faust library->fill_tensor_value(Accessor(c), 3.f);
150*c217d954SCole Faust // This operator is configured once and captured by this lambda.
151*c217d954SCole Faust gemm->prepare(prep_pack);
152*c217d954SCole Faust gemm->run(run_pack);
153*c217d954SCole Faust return dst;
154*c217d954SCole Faust };
155*c217d954SCole Faust auto result_0 = run_conv();
156*c217d954SCole Faust auto result_1 = run_conv();
157*c217d954SCole Faust for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
158*c217d954SCole Faust {
159*c217d954SCole Faust ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
160*c217d954SCole Faust }
161*c217d954SCole Faust }
162*c217d954SCole Faust
163*c217d954SCole Faust /** Test case for memory injection in @ref NEGEMM.
164*c217d954SCole Faust *
165*c217d954SCole Faust * Make sure @ref NEGEMM still works through injecting the memory at configure time using the old API.
166*c217d954SCole Faust *
167*c217d954SCole Faust * Checks performed in order:
168*c217d954SCole Faust * - Both runs compute the same output
169*c217d954SCole Faust */
TEST_CASE(MultipleExecutionWithConfigure,framework::DatasetMode::ALL)170*c217d954SCole Faust TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
171*c217d954SCole Faust {
172*c217d954SCole Faust auto gemm = std::make_unique<NEGEMM>();
173*c217d954SCole Faust const auto lhs_info = TensorInfo(TensorShape(3U, 3U), 1, DataType::F32);
174*c217d954SCole Faust const auto rhs_info = TensorInfo(TensorShape(4U, 3U), 1, DataType::F32);
175*c217d954SCole Faust const auto c_info = TensorInfo(TensorShape(4U, 3U), 1, DataType::F32);
176*c217d954SCole Faust auto dst_info = TensorInfo(TensorShape(4U, 3U), 1, DataType::F32);
177*c217d954SCole Faust const auto gemm_info = GEMMInfo{};
178*c217d954SCole Faust auto run_conv = [&]()
179*c217d954SCole Faust {
180*c217d954SCole Faust auto lhs = create_tensor<Tensor>(lhs_info);
181*c217d954SCole Faust auto rhs = create_tensor<Tensor>(rhs_info);
182*c217d954SCole Faust auto c = create_tensor<Tensor>(c_info);
183*c217d954SCole Faust auto dst = create_tensor<Tensor>(dst_info);
184*c217d954SCole Faust gemm->configure(&lhs, &rhs, &c, &dst, 1.f, 1.f, gemm_info);
185*c217d954SCole Faust lhs.allocator()->allocate();
186*c217d954SCole Faust rhs.allocator()->allocate();
187*c217d954SCole Faust c.allocator()->allocate();
188*c217d954SCole Faust dst.allocator()->allocate();
189*c217d954SCole Faust library->fill_tensor_value(Accessor(lhs), 1.f);
190*c217d954SCole Faust library->fill_tensor_value(Accessor(rhs), 2.f);
191*c217d954SCole Faust library->fill_tensor_value(Accessor(c), 3.f);
192*c217d954SCole Faust gemm->run();
193*c217d954SCole Faust return dst;
194*c217d954SCole Faust };
195*c217d954SCole Faust auto result_0 = run_conv();
196*c217d954SCole Faust auto result_1 = run_conv();
197*c217d954SCole Faust for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
198*c217d954SCole Faust {
199*c217d954SCole Faust ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
200*c217d954SCole Faust }
201*c217d954SCole Faust }
202*c217d954SCole Faust
203*c217d954SCole Faust // *INDENT-OFF*
204*c217d954SCole Faust // clang-format off
205*c217d954SCole Faust DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
206*c217d954SCole Faust framework::dataset::make("LhsInfo", { TensorInfo(TensorShape(27U, 13U), 1, DataType::S32), // Unsupported data type
207*c217d954SCole Faust TensorInfo(TensorShape(27U, 13U), 1, DataType::F32),
208*c217d954SCole Faust }),
209*c217d954SCole Faust framework::dataset::make("RhsInfo",{ TensorInfo(TensorShape(8U, 27U), 1, DataType::S32),
210*c217d954SCole Faust TensorInfo(TensorShape(8U, 27U), 1, DataType::F32),
211*c217d954SCole Faust })),
212*c217d954SCole Faust framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(8U, 13U), 1, DataType::S32),
213*c217d954SCole Faust TensorInfo(TensorShape(8U, 13U), 1, DataType::F32),
214*c217d954SCole Faust })),
215*c217d954SCole Faust framework::dataset::make("Expected", { false, true })),
216*c217d954SCole Faust lhs_info, rhs_info, output_info, expected)
217*c217d954SCole Faust {
218*c217d954SCole Faust constexpr float alpha = 1.0;
219*c217d954SCole Faust constexpr float beta = 0.0;
220*c217d954SCole Faust const auto gemm_info = GEMMInfo();
221*c217d954SCole Faust bool is_valid = bool(NEGEMM::validate(&lhs_info.clone()->set_is_resizable(true), &rhs_info.clone()->set_is_resizable(true), nullptr, &output_info.clone()->set_is_resizable(true), alpha, beta, gemm_info));
222*c217d954SCole Faust ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
223*c217d954SCole Faust }
224*c217d954SCole Faust // clang-format on
225*c217d954SCole Faust // *INDENT-ON*
226*c217d954SCole Faust TEST_SUITE(KERNEL_SELECTION)
227*c217d954SCole Faust DATA_TEST_CASE(KernelSelection_mul_and_add, framework::DatasetMode::ALL,
228*c217d954SCole Faust combine(framework::dataset::make("CpuExt", std::string("NEON")),
229*c217d954SCole Faust framework::dataset::make("DataType", { DataType::F32,
230*c217d954SCole Faust DataType::F16
231*c217d954SCole Faust })),
232*c217d954SCole Faust cpu_ext, data_type)
233*c217d954SCole Faust {
234*c217d954SCole Faust using namespace cpu::kernels;
235*c217d954SCole Faust
236*c217d954SCole Faust cpuinfo::CpuIsaInfo cpu_isa{};
237*c217d954SCole Faust cpu_isa.neon = (cpu_ext == "NEON");
238*c217d954SCole Faust cpu_isa.fp16 = (data_type == DataType::F16);
239*c217d954SCole Faust
240*c217d954SCole Faust const auto *selected_impl_mul = CpuGemmMatrixMultiplyKernel::get_implementation(DataTypeISASelectorData{ data_type, cpu_isa }, cpu::KernelSelectionType::Preferred);
241*c217d954SCole Faust
242*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl_mul);
243*c217d954SCole Faust
244*c217d954SCole Faust std::string expected = lower_string(cpu_ext) + "_" + cpu_impl_dt(data_type) + "_gemm_matrix_mul";
245*c217d954SCole Faust std::string actual = selected_impl_mul->name;
246*c217d954SCole Faust
247*c217d954SCole Faust ARM_COMPUTE_EXPECT_EQUAL(expected, actual, framework::LogLevel::ERRORS);
248*c217d954SCole Faust
249*c217d954SCole Faust const auto *selected_impl_add = CpuGemmMatrixAdditionKernel::get_implementation(DataTypeISASelectorData{ data_type, cpu_isa }, cpu::KernelSelectionType::Preferred);
250*c217d954SCole Faust
251*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl_add);
252*c217d954SCole Faust
253*c217d954SCole Faust expected = lower_string(cpu_ext) + "_" + cpu_impl_dt(data_type) + "_gemm_matrix_add";
254*c217d954SCole Faust actual = selected_impl_add->name;
255*c217d954SCole Faust
256*c217d954SCole Faust ARM_COMPUTE_EXPECT_EQUAL(expected, actual, framework::LogLevel::ERRORS);
257*c217d954SCole Faust }
258*c217d954SCole Faust TEST_SUITE_END() // KERNEL_SELECTION
259*c217d954SCole Faust
260*c217d954SCole Faust TEST_SUITE(TRANSPOSE_1XW)
261*c217d954SCole Faust using CpuGemmTranspose1xW = NESynthetizeFunctionWithZeroConstantKernelBorder<cpu::kernels::CpuGemmTranspose1xWKernel>;
262*c217d954SCole Faust DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
263*c217d954SCole Faust framework::dataset::make("N", { 1, 23, 63, 101 }),
264*c217d954SCole Faust framework::dataset::make("K", { 1, 47, 29, 27 })),
265*c217d954SCole Faust n_value, k_value)
266*c217d954SCole Faust {
267*c217d954SCole Faust bool status = validate_zero_padding<CpuGemmTranspose1xW>(n_value, k_value);
268*c217d954SCole Faust ARM_COMPUTE_EXPECT(status, framework::LogLevel::ERRORS);
269*c217d954SCole Faust }
270*c217d954SCole Faust
271*c217d954SCole Faust TEST_SUITE(U32)
272*c217d954SCole Faust using CpuGemmTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, CpuGemmTranspose1xW, uint32_t>;
273*c217d954SCole Faust FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U32))
274*c217d954SCole Faust {
275*c217d954SCole Faust // Validate output
276*c217d954SCole Faust validate(Accessor(_target), _reference);
277*c217d954SCole Faust }
278*c217d954SCole Faust TEST_SUITE_END() // U32
279*c217d954SCole Faust
280*c217d954SCole Faust TEST_SUITE(U16)
281*c217d954SCole Faust using CpuGemmTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, CpuGemmTranspose1xW, uint16_t>;
282*c217d954SCole Faust FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U16))
283*c217d954SCole Faust {
284*c217d954SCole Faust // Validate output
285*c217d954SCole Faust validate(Accessor(_target), _reference);
286*c217d954SCole Faust }
287*c217d954SCole Faust TEST_SUITE_END() // U16
288*c217d954SCole Faust
289*c217d954SCole Faust TEST_SUITE(U8)
290*c217d954SCole Faust using CpuGemmTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, CpuGemmTranspose1xW, uint8_t>;
291*c217d954SCole Faust FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U8))
292*c217d954SCole Faust {
293*c217d954SCole Faust // Validate output
294*c217d954SCole Faust validate(Accessor(_target), _reference);
295*c217d954SCole Faust }
296*c217d954SCole Faust TEST_SUITE_END() // U8
297*c217d954SCole Faust
298*c217d954SCole Faust TEST_SUITE_END() // TRANSPOSE_1XW
299*c217d954SCole Faust
300*c217d954SCole Faust TEST_SUITE(INTERLEAVE_4X4)
301*c217d954SCole Faust using CpuGemmInterleave4x4 = NESynthetizeFunctionWithZeroConstantKernelBorder<cpu::kernels::CpuGemmInterleave4x4Kernel>;
302*c217d954SCole Faust
303*c217d954SCole Faust DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
304*c217d954SCole Faust framework::dataset::make("M", { 1, 23, 63, 101 }),
305*c217d954SCole Faust framework::dataset::make("K", { 1, 47, 29, 27 })),
306*c217d954SCole Faust m_value, k_value)
307*c217d954SCole Faust {
308*c217d954SCole Faust bool status = validate_zero_padding<cpu::kernels::CpuGemmInterleave4x4Kernel>(m_value, k_value);
309*c217d954SCole Faust ARM_COMPUTE_EXPECT(status, framework::LogLevel::ERRORS);
310*c217d954SCole Faust }
311*c217d954SCole Faust
312*c217d954SCole Faust TEST_SUITE(U32)
313*c217d954SCole Faust using CpuGemmInterleave4x4Fixture = GEMMInterleave4x4ValidationFixture<Tensor, Accessor, CpuGemmInterleave4x4, uint32_t>;
314*c217d954SCole Faust FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * framework::dataset::make("DataType", DataType::U32))
315*c217d954SCole Faust {
316*c217d954SCole Faust // Validate output
317*c217d954SCole Faust validate(Accessor(_target), _reference);
318*c217d954SCole Faust }
319*c217d954SCole Faust TEST_SUITE_END() // U32
320*c217d954SCole Faust
321*c217d954SCole Faust TEST_SUITE(U16)
322*c217d954SCole Faust using CpuGemmInterleave4x4Fixture = GEMMInterleave4x4ValidationFixture<Tensor, Accessor, CpuGemmInterleave4x4, uint16_t>;
323*c217d954SCole Faust FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * framework::dataset::make("DataType", DataType::U16))
324*c217d954SCole Faust {
325*c217d954SCole Faust // Validate output
326*c217d954SCole Faust validate(Accessor(_target), _reference);
327*c217d954SCole Faust }
328*c217d954SCole Faust TEST_SUITE_END() // U16
329*c217d954SCole Faust
330*c217d954SCole Faust TEST_SUITE(U8)
331*c217d954SCole Faust using CpuGemmInterleave4x4Fixture = GEMMInterleave4x4ValidationFixture<Tensor, Accessor, CpuGemmInterleave4x4, uint8_t>;
332*c217d954SCole Faust FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * framework::dataset::make("DataType", DataType::QASYMM8))
333*c217d954SCole Faust {
334*c217d954SCole Faust // Validate output
335*c217d954SCole Faust validate(Accessor(_target), _reference);
336*c217d954SCole Faust }
337*c217d954SCole Faust TEST_SUITE_END() // U8
338*c217d954SCole Faust
339*c217d954SCole Faust TEST_SUITE_END() // INTERLEAVE_4X4
340*c217d954SCole Faust
341*c217d954SCole Faust template <typename T>
342*c217d954SCole Faust using NEGEMMFixture = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T>;
343*c217d954SCole Faust
344*c217d954SCole Faust template <typename T>
345*c217d954SCole Faust using NEBatchedMatMulFixture = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T, true, false, false, false, false, true>;
346*c217d954SCole Faust
347*c217d954SCole Faust TEST_SUITE(Float)
348*c217d954SCole Faust DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(framework::dataset::make("In0", { TensorShape(21U, 13U),
349*c217d954SCole Faust TensorShape(31U, 1U),
350*c217d954SCole Faust TensorShape(31U, 1U),
351*c217d954SCole Faust TensorShape(8U, 2U),
352*c217d954SCole Faust TensorShape(38U, 12U),
353*c217d954SCole Faust TensorShape(32U, 1U)
354*c217d954SCole Faust }),
355*c217d954SCole Faust framework::dataset::make("In1", { TensorShape(33U, 21U),
356*c217d954SCole Faust TensorShape(23U, 31U),
357*c217d954SCole Faust TensorShape(23U, 31U),
358*c217d954SCole Faust TensorShape(16U, 8U),
359*c217d954SCole Faust TensorShape(21U, 38U),
360*c217d954SCole Faust TensorShape(17U, 32U)
361*c217d954SCole Faust })),
362*c217d954SCole Faust shape0, shape1)
363*c217d954SCole Faust {
364*c217d954SCole Faust bool status = validate_gemm_zero_padding(shape0, shape1);
365*c217d954SCole Faust ARM_COMPUTE_EXPECT(status, framework::LogLevel::ERRORS);
366*c217d954SCole Faust }
367*c217d954SCole Faust
368*c217d954SCole Faust #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
369*c217d954SCole Faust TEST_SUITE(FP16)
370*c217d954SCole Faust FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
371*c217d954SCole Faust framework::dataset::make("ReshapeWeights", { true, false })),
372*c217d954SCole Faust framework::dataset::make("DataType", DataType::F16)))
373*c217d954SCole Faust {
374*c217d954SCole Faust // Validate output
375*c217d954SCole Faust validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
376*c217d954SCole Faust }
377*c217d954SCole Faust
378*c217d954SCole Faust TEST_SUITE(BATCHED_MATMUL)
379*c217d954SCole Faust
380*c217d954SCole Faust FIXTURE_DATA_TEST_CASE(RunSmall, NEBatchedMatMulFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallBatchedMatMulDataset(),
381*c217d954SCole Faust framework::dataset::make("ReshapeWeights", { false })),
382*c217d954SCole Faust framework::dataset::make("DataType", DataType::F16)))
383*c217d954SCole Faust {
384*c217d954SCole Faust // Validate output
385*c217d954SCole Faust validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
386*c217d954SCole Faust }
387*c217d954SCole Faust TEST_SUITE_END()
388*c217d954SCole Faust
389*c217d954SCole Faust FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
390*c217d954SCole Faust framework::dataset::make("ReshapeWeights", { true, false })),
391*c217d954SCole Faust
392*c217d954SCole Faust framework::dataset::make("DataType", DataType::F16)))
393*c217d954SCole Faust {
394*c217d954SCole Faust // Validate output
395*c217d954SCole Faust validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
396*c217d954SCole Faust }
397*c217d954SCole Faust TEST_SUITE_END()
398*c217d954SCole Faust #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
399*c217d954SCole Faust
TEST_SUITE(FP32)400*c217d954SCole Faust TEST_SUITE(FP32)
401*c217d954SCole Faust FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
402*c217d954SCole Faust framework::dataset::make("ReshapeWeights", { true, false })),
403*c217d954SCole Faust
404*c217d954SCole Faust framework::dataset::make("DataType", DataType::F32)))
405*c217d954SCole Faust {
406*c217d954SCole Faust // Validate output
407*c217d954SCole Faust validate(Accessor(_target), _reference, tolerance_f);
408*c217d954SCole Faust }
409*c217d954SCole Faust FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
410*c217d954SCole Faust framework::dataset::make("ReshapeWeights", { true, false })),
411*c217d954SCole Faust
412*c217d954SCole Faust framework::dataset::make("DataType", DataType::F32)))
413*c217d954SCole Faust {
414*c217d954SCole Faust // Validate output
415*c217d954SCole Faust validate(Accessor(_target), _reference, tolerance_f);
416*c217d954SCole Faust }
417*c217d954SCole Faust
418*c217d954SCole Faust TEST_SUITE(BATCHED_MATMUL)
419*c217d954SCole Faust
TEST_SUITE(FP32)420*c217d954SCole Faust TEST_SUITE(FP32)
421*c217d954SCole Faust FIXTURE_DATA_TEST_CASE(RunSmall, NEBatchedMatMulFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallBatchedMatMulDataset(),
422*c217d954SCole Faust framework::dataset::make("ReshapeWeights", { false })),
423*c217d954SCole Faust framework::dataset::make("DataType", DataType::F32)))
424*c217d954SCole Faust {
425*c217d954SCole Faust // Validate output
426*c217d954SCole Faust validate(Accessor(_target), _reference, tolerance_f);
427*c217d954SCole Faust }
428*c217d954SCole Faust TEST_SUITE_END()
429*c217d954SCole Faust
430*c217d954SCole Faust TEST_SUITE_END()
431*c217d954SCole Faust
432*c217d954SCole Faust TEST_SUITE_END()
433*c217d954SCole Faust TEST_SUITE_END()
434*c217d954SCole Faust
435*c217d954SCole Faust TEST_SUITE_END()
436*c217d954SCole Faust TEST_SUITE_END()
437*c217d954SCole Faust } // namespace validation
438*c217d954SCole Faust } // namespace test
439*c217d954SCole Faust } // namespace arm_compute
440