xref: /aosp_15_r20/external/armnn/tests/ExecuteNetwork/ExecuteNetwork.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ExecuteNetworkProgramOptions.hpp"
7 #include "ArmNNExecutor.hpp"
8 #if defined(ARMNN_TFLITE_DELEGATE)
9 #include "TfliteExecutor.hpp"
10 #endif
11 #include <armnn/Logging.hpp>
12 
13 
BuildExecutor(ProgramOptions & programOptions)14 std::unique_ptr<IExecutor> BuildExecutor(ProgramOptions& programOptions)
15 {
16     if (programOptions.m_ExNetParams.m_TfLiteExecutor == ExecuteNetworkParams::TfLiteExecutor::ArmNNTfLiteDelegate ||
17         programOptions.m_ExNetParams.m_TfLiteExecutor == ExecuteNetworkParams::TfLiteExecutor::TfliteInterpreter)
18     {
19 #if defined(ARMNN_TFLITE_DELEGATE)
20         return std::make_unique<TfLiteExecutor>(programOptions.m_ExNetParams, programOptions.m_RuntimeOptions);
21 #else
22         ARMNN_LOG(fatal) << "Not built with Arm NN Tensorflow-Lite delegate support.";
23         return nullptr;
24 #endif
25     }
26     else
27     {
28         return std::make_unique<ArmNNExecutor>(programOptions.m_ExNetParams, programOptions.m_RuntimeOptions);
29     }
30 }
31 
32 // MAIN
main(int argc,const char * argv[])33 int main(int argc, const char* argv[])
34 {
35     // Configures logging for both the ARMNN library and this test program.
36 #ifdef NDEBUG
37     armnn::LogSeverity level = armnn::LogSeverity::Info;
38 #else
39     armnn::LogSeverity level = armnn::LogSeverity::Debug;
40 #endif
41     armnn::ConfigureLogging(true, true, level);
42 
43 
44     // Get ExecuteNetwork parameters and runtime options from command line
45     // This might throw an InvalidArgumentException if the user provided invalid inputs
46     ProgramOptions programOptions;
47     try
48     {
49         programOptions.ParseOptions(argc, argv);
50     }
51     catch (const std::exception& e)
52     {
53         ARMNN_LOG(fatal) << e.what();
54         return EXIT_FAILURE;
55     }
56 
57     std::vector<const void*> outputResults;
58     std::unique_ptr<IExecutor> executor;
59     try
60     {
61         executor = BuildExecutor(programOptions);
62         if ((!executor) || (executor->m_constructionFailed))
63         {
64             return EXIT_FAILURE;
65         }
66     }
67     catch (const std::exception& e)
68     {
69         ARMNN_LOG(fatal) << e.what();
70         return EXIT_FAILURE;
71     }
72 
73 
74     executor->PrintNetworkInfo();
75     outputResults = executor->Execute();
76 
77     if (!programOptions.m_ExNetParams.m_ComparisonComputeDevices.empty() ||
78          programOptions.m_ExNetParams.m_CompareWithTflite)
79     {
80         ExecuteNetworkParams comparisonParams = programOptions.m_ExNetParams;
81         comparisonParams.m_ComputeDevices = programOptions.m_ExNetParams.m_ComparisonComputeDevices;
82 
83         if (programOptions.m_ExNetParams.m_CompareWithTflite)
84         {
85             comparisonParams.m_TfLiteExecutor = ExecuteNetworkParams::TfLiteExecutor::TfliteInterpreter;
86         }
87 
88         auto comparisonExecutor = BuildExecutor(programOptions);
89 
90         if (!comparisonExecutor)
91         {
92             return EXIT_FAILURE;
93         }
94 
95         comparisonExecutor->PrintNetworkInfo();
96         comparisonExecutor->Execute();
97 
98         comparisonExecutor->CompareAndPrintResult(outputResults);
99     }
100 }
101