1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker #include "armnnTestUtils/TensorHelpers.hpp"
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Utils.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <reference/RefWorkloadFactory.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <reference/test/RefWorkloadFactoryHelper.hpp>
13*89c4ff92SAndroid Build Coastguard Worker
14*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/test/WorkloadFactoryHelper.hpp>
15*89c4ff92SAndroid Build Coastguard Worker
16*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/LayerTestResult.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/TensorCopyUtils.hpp>
18*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/WorkloadTestUtils.hpp>
19*89c4ff92SAndroid Build Coastguard Worker
20*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
21*89c4ff92SAndroid Build Coastguard Worker
ConfigureLoggingTest()22*89c4ff92SAndroid Build Coastguard Worker inline void ConfigureLoggingTest()
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker // Configures logging for both the ARMNN library and this test program.
25*89c4ff92SAndroid Build Coastguard Worker armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
26*89c4ff92SAndroid Build Coastguard Worker }
27*89c4ff92SAndroid Build Coastguard Worker
// The ARMNN_AUTO_* and ARMNN_COMPARE_REF_* macros below require the caller to have defined the type
// alias FactoryType in the enclosing scope, e.g. with one of the following alias declarations:
//
// using FactoryType = armnn::RefWorkloadFactory;
// using FactoryType = armnn::ClWorkloadFactory;
// using FactoryType = armnn::NeonWorkloadFactory;
33*89c4ff92SAndroid Build Coastguard Worker
34*89c4ff92SAndroid Build Coastguard Worker /// Executes CHECK_MESSAGE on CompareTensors() return value so that the predicate_result message is reported.
35*89c4ff92SAndroid Build Coastguard Worker /// If the test reports itself as not supported then the tensors are not compared.
36*89c4ff92SAndroid Build Coastguard Worker /// Additionally this checks that the supportedness reported by the test matches the name of the test.
37*89c4ff92SAndroid Build Coastguard Worker /// Unsupported tests must be 'tagged' by including "UNSUPPORTED" in their name.
38*89c4ff92SAndroid Build Coastguard Worker /// This is useful because it clarifies that the feature being tested is not actually supported
39*89c4ff92SAndroid Build Coastguard Worker /// (a passed test with the name of a feature would imply that feature was supported).
40*89c4ff92SAndroid Build Coastguard Worker /// If support is added for a feature, the test case will fail because the name incorrectly contains UNSUPPORTED.
41*89c4ff92SAndroid Build Coastguard Worker /// If support is removed for a feature, the test case will fail because the name doesn't contain UNSUPPORTED.
42*89c4ff92SAndroid Build Coastguard Worker template <typename T, std::size_t n>
CompareTestResultIfSupported(const std::string & testName,const LayerTestResult<T,n> & testResult)43*89c4ff92SAndroid Build Coastguard Worker void CompareTestResultIfSupported(const std::string& testName, const LayerTestResult<T, n>& testResult)
44*89c4ff92SAndroid Build Coastguard Worker {
45*89c4ff92SAndroid Build Coastguard Worker bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
46*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(testNameIndicatesUnsupported != testResult.m_Supported,
47*89c4ff92SAndroid Build Coastguard Worker "The test name does not match the supportedness it is reporting");
48*89c4ff92SAndroid Build Coastguard Worker if (testResult.m_Supported)
49*89c4ff92SAndroid Build Coastguard Worker {
50*89c4ff92SAndroid Build Coastguard Worker auto result = CompareTensors(testResult.m_ActualData,
51*89c4ff92SAndroid Build Coastguard Worker testResult.m_ExpectedData,
52*89c4ff92SAndroid Build Coastguard Worker testResult.m_ActualShape,
53*89c4ff92SAndroid Build Coastguard Worker testResult.m_ExpectedShape,
54*89c4ff92SAndroid Build Coastguard Worker testResult.m_CompareBoolean);
55*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(result.m_Result, result.m_Message.str());
56*89c4ff92SAndroid Build Coastguard Worker }
57*89c4ff92SAndroid Build Coastguard Worker }
58*89c4ff92SAndroid Build Coastguard Worker
59*89c4ff92SAndroid Build Coastguard Worker template <typename T, std::size_t n>
CompareTestResultIfSupported(const std::string & testName,const std::vector<LayerTestResult<T,n>> & testResult)60*89c4ff92SAndroid Build Coastguard Worker void CompareTestResultIfSupported(const std::string& testName, const std::vector<LayerTestResult<T, n>>& testResult)
61*89c4ff92SAndroid Build Coastguard Worker {
62*89c4ff92SAndroid Build Coastguard Worker bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
63*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < testResult.size(); ++i)
64*89c4ff92SAndroid Build Coastguard Worker {
65*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(testNameIndicatesUnsupported != testResult[i].m_Supported,
66*89c4ff92SAndroid Build Coastguard Worker "The test name does not match the supportedness it is reporting");
67*89c4ff92SAndroid Build Coastguard Worker if (testResult[i].m_Supported)
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker auto result = CompareTensors(testResult[i].m_ActualData,
70*89c4ff92SAndroid Build Coastguard Worker testResult[i].m_ExpectedData,
71*89c4ff92SAndroid Build Coastguard Worker testResult[i].m_ActualShape,
72*89c4ff92SAndroid Build Coastguard Worker testResult[i].m_ExpectedShape);
73*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(result.m_Result, result.m_Message.str());
74*89c4ff92SAndroid Build Coastguard Worker }
75*89c4ff92SAndroid Build Coastguard Worker }
76*89c4ff92SAndroid Build Coastguard Worker }
77*89c4ff92SAndroid Build Coastguard Worker
78*89c4ff92SAndroid Build Coastguard Worker template<typename FactoryType, typename TFuncPtr, typename... Args>
RunTestFunction(const char * testName,TFuncPtr testFunction,Args...args)79*89c4ff92SAndroid Build Coastguard Worker void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<armnn::IProfiler> profiler = std::make_unique<armnn::IProfiler>();
82*89c4ff92SAndroid Build Coastguard Worker armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
83*89c4ff92SAndroid Build Coastguard Worker
84*89c4ff92SAndroid Build Coastguard Worker auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
85*89c4ff92SAndroid Build Coastguard Worker FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
86*89c4ff92SAndroid Build Coastguard Worker
87*89c4ff92SAndroid Build Coastguard Worker auto testResult = (*testFunction)(workloadFactory, memoryManager, args...);
88*89c4ff92SAndroid Build Coastguard Worker CompareTestResultIfSupported(testName, testResult);
89*89c4ff92SAndroid Build Coastguard Worker
90*89c4ff92SAndroid Build Coastguard Worker armnn::ProfilerManager::GetInstance().RegisterProfiler(nullptr);
91*89c4ff92SAndroid Build Coastguard Worker }
92*89c4ff92SAndroid Build Coastguard Worker
93*89c4ff92SAndroid Build Coastguard Worker
94*89c4ff92SAndroid Build Coastguard Worker template<typename FactoryType, typename TFuncPtr, typename... Args>
RunTestFunctionUsingTensorHandleFactory(const char * testName,TFuncPtr testFunction,Args...args)95*89c4ff92SAndroid Build Coastguard Worker void RunTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args)
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<armnn::IProfiler> profiler = std::make_unique<armnn::IProfiler>();
98*89c4ff92SAndroid Build Coastguard Worker armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
99*89c4ff92SAndroid Build Coastguard Worker
100*89c4ff92SAndroid Build Coastguard Worker auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
101*89c4ff92SAndroid Build Coastguard Worker FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
102*89c4ff92SAndroid Build Coastguard Worker
103*89c4ff92SAndroid Build Coastguard Worker auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
104*89c4ff92SAndroid Build Coastguard Worker
105*89c4ff92SAndroid Build Coastguard Worker auto testResult = (*testFunction)(workloadFactory, memoryManager, tensorHandleFactory, args...);
106*89c4ff92SAndroid Build Coastguard Worker CompareTestResultIfSupported(testName, testResult);
107*89c4ff92SAndroid Build Coastguard Worker
108*89c4ff92SAndroid Build Coastguard Worker armnn::ProfilerManager::GetInstance().RegisterProfiler(nullptr);
109*89c4ff92SAndroid Build Coastguard Worker }
110*89c4ff92SAndroid Build Coastguard Worker
/// Declares a doctest TEST_CASE named #TestName whose body just calls TestFunction() with no
/// arguments. Unlike the ARMNN_AUTO_* macros, it does not use FactoryType, does not create a
/// workload factory, and performs no result comparison itself.
#define ARMNN_SIMPLE_TEST_CASE(TestName, TestFunction) \
    TEST_CASE(#TestName) \
    { \
        TestFunction(); \
    }
116*89c4ff92SAndroid Build Coastguard Worker
// The ARMNN_AUTO_* macros declare a doctest test case named #TestName that runs TestFunction on the
// backend given by the caller-defined alias FactoryType and checks the result. Any extra macro
// arguments are forwarded to TestFunction. `##__VA_ARGS__` is a compiler extension (supported by
// GCC and Clang) that removes the preceding comma when no extra arguments are supplied.

/// Plain test case: TestFunction(workloadFactory, memoryManager, extraArgs...) via RunTestFunction.
#define ARMNN_AUTO_TEST_CASE(TestName, TestFunction, ...) \
    TEST_CASE(#TestName) \
    { \
        RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
    }

/// As ARMNN_AUTO_TEST_CASE, but the test case body runs inside the doctest fixture `Fixture`.
#define ARMNN_AUTO_TEST_FIXTURE(TestName, Fixture, TestFunction, ...) \
    TEST_CASE_FIXTURE(Fixture, #TestName) \
    { \
        RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
    }

/// Test case whose TestFunction also takes a tensor handle factory:
/// TestFunction(workloadFactory, memoryManager, tensorHandleFactory, extraArgs...)
/// via RunTestFunctionUsingTensorHandleFactory.
#define ARMNN_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \
    TEST_CASE(#TestName) \
    { \
        RunTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
    }

/// As ARMNN_AUTO_TEST_CASE_WITH_THF, but the test case body runs inside the doctest fixture `Fixture`.
#define ARMNN_AUTO_TEST_FIXTURE_WITH_THF(TestName, Fixture, TestFunction, ...) \
    TEST_CASE_FIXTURE(Fixture, #TestName) \
    { \
        RunTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
    }
140*89c4ff92SAndroid Build Coastguard Worker
141*89c4ff92SAndroid Build Coastguard Worker template<typename FactoryType, typename TFuncPtr, typename... Args>
CompareRefTestFunction(const char * testName,TFuncPtr testFunction,Args...args)142*89c4ff92SAndroid Build Coastguard Worker void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
143*89c4ff92SAndroid Build Coastguard Worker {
144*89c4ff92SAndroid Build Coastguard Worker auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
145*89c4ff92SAndroid Build Coastguard Worker FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
146*89c4ff92SAndroid Build Coastguard Worker
147*89c4ff92SAndroid Build Coastguard Worker armnn::RefWorkloadFactory refWorkloadFactory;
148*89c4ff92SAndroid Build Coastguard Worker
149*89c4ff92SAndroid Build Coastguard Worker auto testResult = (*testFunction)(workloadFactory, memoryManager, refWorkloadFactory, args...);
150*89c4ff92SAndroid Build Coastguard Worker CompareTestResultIfSupported(testName, testResult);
151*89c4ff92SAndroid Build Coastguard Worker }
152*89c4ff92SAndroid Build Coastguard Worker
153*89c4ff92SAndroid Build Coastguard Worker template<typename FactoryType, typename TFuncPtr, typename... Args>
CompareRefTestFunctionUsingTensorHandleFactory(const char * testName,TFuncPtr testFunction,Args...args)154*89c4ff92SAndroid Build Coastguard Worker void CompareRefTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args)
155*89c4ff92SAndroid Build Coastguard Worker {
156*89c4ff92SAndroid Build Coastguard Worker auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
157*89c4ff92SAndroid Build Coastguard Worker FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
158*89c4ff92SAndroid Build Coastguard Worker auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
159*89c4ff92SAndroid Build Coastguard Worker
160*89c4ff92SAndroid Build Coastguard Worker armnn::RefWorkloadFactory refWorkloadFactory;
161*89c4ff92SAndroid Build Coastguard Worker auto refMemoryManager = WorkloadFactoryHelper<armnn::RefWorkloadFactory>::GetMemoryManager();
162*89c4ff92SAndroid Build Coastguard Worker auto refTensorHandleFactory = RefWorkloadFactoryHelper::GetTensorHandleFactory(refMemoryManager);
163*89c4ff92SAndroid Build Coastguard Worker
164*89c4ff92SAndroid Build Coastguard Worker auto testResult = (*testFunction)(
165*89c4ff92SAndroid Build Coastguard Worker workloadFactory, memoryManager, refWorkloadFactory, tensorHandleFactory, refTensorHandleFactory, args...);
166*89c4ff92SAndroid Build Coastguard Worker CompareTestResultIfSupported(testName, testResult);
167*89c4ff92SAndroid Build Coastguard Worker }
168*89c4ff92SAndroid Build Coastguard Worker
// The ARMNN_COMPARE_REF_* macros declare a doctest test case named #TestName that runs TestFunction
// with both the caller-defined FactoryType backend and an armnn::RefWorkloadFactory, then checks the
// result. Any extra macro arguments are forwarded to TestFunction. `##__VA_ARGS__` is a compiler
// extension (supported by GCC and Clang) that removes the preceding comma when no extra arguments
// are supplied.

/// TestFunction(workloadFactory, memoryManager, refWorkloadFactory, extraArgs...) via CompareRefTestFunction.
#define ARMNN_COMPARE_REF_AUTO_TEST_CASE(TestName, TestFunction, ...) \
    TEST_CASE(#TestName) \
    { \
        CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
    }

/// TestFunction(workloadFactory, memoryManager, refWorkloadFactory, tensorHandleFactory,
/// refTensorHandleFactory, extraArgs...) via CompareRefTestFunctionUsingTensorHandleFactory.
#define ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \
    TEST_CASE(#TestName) \
    { \
        CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
    }

/// As ARMNN_COMPARE_REF_AUTO_TEST_CASE, but the test case body runs inside the doctest fixture `Fixture`.
#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(TestName, Fixture, TestFunction, ...) \
    TEST_CASE_FIXTURE(Fixture, #TestName) \
    { \
        CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
    }

/// As ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF, but the test case body runs inside the doctest
/// fixture `Fixture`.
#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE_WITH_THF(TestName, Fixture, TestFunction, ...) \
    TEST_CASE_FIXTURE(Fixture, #TestName) \
    { \
        CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
    }
192