xref: /aosp_15_r20/external/armnn/tests/ExecuteNetwork/ExecuteNetwork.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ExecuteNetworkProgramOptions.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "ArmNNExecutor.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TFLITE_DELEGATE)
9*89c4ff92SAndroid Build Coastguard Worker #include "TfliteExecutor.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #endif
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker 
BuildExecutor(ProgramOptions & programOptions)14*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IExecutor> BuildExecutor(ProgramOptions& programOptions)
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker     if (programOptions.m_ExNetParams.m_TfLiteExecutor == ExecuteNetworkParams::TfLiteExecutor::ArmNNTfLiteDelegate ||
17*89c4ff92SAndroid Build Coastguard Worker         programOptions.m_ExNetParams.m_TfLiteExecutor == ExecuteNetworkParams::TfLiteExecutor::TfliteInterpreter)
18*89c4ff92SAndroid Build Coastguard Worker     {
19*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TFLITE_DELEGATE)
20*89c4ff92SAndroid Build Coastguard Worker         return std::make_unique<TfLiteExecutor>(programOptions.m_ExNetParams, programOptions.m_RuntimeOptions);
21*89c4ff92SAndroid Build Coastguard Worker #else
22*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(fatal) << "Not built with Arm NN Tensorflow-Lite delegate support.";
23*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
24*89c4ff92SAndroid Build Coastguard Worker #endif
25*89c4ff92SAndroid Build Coastguard Worker     }
26*89c4ff92SAndroid Build Coastguard Worker     else
27*89c4ff92SAndroid Build Coastguard Worker     {
28*89c4ff92SAndroid Build Coastguard Worker         return std::make_unique<ArmNNExecutor>(programOptions.m_ExNetParams, programOptions.m_RuntimeOptions);
29*89c4ff92SAndroid Build Coastguard Worker     }
30*89c4ff92SAndroid Build Coastguard Worker }
31*89c4ff92SAndroid Build Coastguard Worker 
32*89c4ff92SAndroid Build Coastguard Worker // MAIN
main(int argc,const char * argv[])33*89c4ff92SAndroid Build Coastguard Worker int main(int argc, const char* argv[])
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker     // Configures logging for both the ARMNN library and this test program.
36*89c4ff92SAndroid Build Coastguard Worker #ifdef NDEBUG
37*89c4ff92SAndroid Build Coastguard Worker     armnn::LogSeverity level = armnn::LogSeverity::Info;
38*89c4ff92SAndroid Build Coastguard Worker #else
39*89c4ff92SAndroid Build Coastguard Worker     armnn::LogSeverity level = armnn::LogSeverity::Debug;
40*89c4ff92SAndroid Build Coastguard Worker #endif
41*89c4ff92SAndroid Build Coastguard Worker     armnn::ConfigureLogging(true, true, level);
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker     // Get ExecuteNetwork parameters and runtime options from command line
45*89c4ff92SAndroid Build Coastguard Worker     // This might throw an InvalidArgumentException if the user provided invalid inputs
46*89c4ff92SAndroid Build Coastguard Worker     ProgramOptions programOptions;
47*89c4ff92SAndroid Build Coastguard Worker     try
48*89c4ff92SAndroid Build Coastguard Worker     {
49*89c4ff92SAndroid Build Coastguard Worker         programOptions.ParseOptions(argc, argv);
50*89c4ff92SAndroid Build Coastguard Worker     }
51*89c4ff92SAndroid Build Coastguard Worker     catch (const std::exception& e)
52*89c4ff92SAndroid Build Coastguard Worker     {
53*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(fatal) << e.what();
54*89c4ff92SAndroid Build Coastguard Worker         return EXIT_FAILURE;
55*89c4ff92SAndroid Build Coastguard Worker     }
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker     std::vector<const void*> outputResults;
58*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<IExecutor> executor;
59*89c4ff92SAndroid Build Coastguard Worker     try
60*89c4ff92SAndroid Build Coastguard Worker     {
61*89c4ff92SAndroid Build Coastguard Worker         executor = BuildExecutor(programOptions);
62*89c4ff92SAndroid Build Coastguard Worker         if ((!executor) || (executor->m_constructionFailed))
63*89c4ff92SAndroid Build Coastguard Worker         {
64*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
65*89c4ff92SAndroid Build Coastguard Worker         }
66*89c4ff92SAndroid Build Coastguard Worker     }
67*89c4ff92SAndroid Build Coastguard Worker     catch (const std::exception& e)
68*89c4ff92SAndroid Build Coastguard Worker     {
69*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(fatal) << e.what();
70*89c4ff92SAndroid Build Coastguard Worker         return EXIT_FAILURE;
71*89c4ff92SAndroid Build Coastguard Worker     }
72*89c4ff92SAndroid Build Coastguard Worker 
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker     executor->PrintNetworkInfo();
75*89c4ff92SAndroid Build Coastguard Worker     outputResults = executor->Execute();
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker     if (!programOptions.m_ExNetParams.m_ComparisonComputeDevices.empty() ||
78*89c4ff92SAndroid Build Coastguard Worker          programOptions.m_ExNetParams.m_CompareWithTflite)
79*89c4ff92SAndroid Build Coastguard Worker     {
80*89c4ff92SAndroid Build Coastguard Worker         ExecuteNetworkParams comparisonParams = programOptions.m_ExNetParams;
81*89c4ff92SAndroid Build Coastguard Worker         comparisonParams.m_ComputeDevices = programOptions.m_ExNetParams.m_ComparisonComputeDevices;
82*89c4ff92SAndroid Build Coastguard Worker 
83*89c4ff92SAndroid Build Coastguard Worker         if (programOptions.m_ExNetParams.m_CompareWithTflite)
84*89c4ff92SAndroid Build Coastguard Worker         {
85*89c4ff92SAndroid Build Coastguard Worker             comparisonParams.m_TfLiteExecutor = ExecuteNetworkParams::TfLiteExecutor::TfliteInterpreter;
86*89c4ff92SAndroid Build Coastguard Worker         }
87*89c4ff92SAndroid Build Coastguard Worker 
88*89c4ff92SAndroid Build Coastguard Worker         auto comparisonExecutor = BuildExecutor(programOptions);
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker         if (!comparisonExecutor)
91*89c4ff92SAndroid Build Coastguard Worker         {
92*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
93*89c4ff92SAndroid Build Coastguard Worker         }
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker         comparisonExecutor->PrintNetworkInfo();
96*89c4ff92SAndroid Build Coastguard Worker         comparisonExecutor->Execute();
97*89c4ff92SAndroid Build Coastguard Worker 
98*89c4ff92SAndroid Build Coastguard Worker         comparisonExecutor->CompareAndPrintResult(outputResults);
99*89c4ff92SAndroid Build Coastguard Worker     }
100*89c4ff92SAndroid Build Coastguard Worker }
101