xref: /aosp_15_r20/external/armnn/tests/MultipleNetworksCifar10/MultipleNetworksCifar10.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnn/ArmNN.hpp"
7 #include "armnn/Utils.hpp"
8 #include "armnn/INetwork.hpp"
9 #include "armnnTfLiteParser/TfLiteParser.hpp"
10 #include "../Cifar10Database.hpp"
11 #include "../InferenceTest.hpp"
12 #include "../InferenceModel.hpp"
13 
14 #include <cxxopts/cxxopts.hpp>
15 
16 #include <iostream>
17 #include <chrono>
18 #include <vector>
19 #include <array>
20 
21 
22 using namespace std;
23 using namespace std::chrono;
24 using namespace armnn::test;
25 
main(int argc,char * argv[])26 int main(int argc, char* argv[])
27 {
28 #ifdef NDEBUG
29     armnn::LogSeverity level = armnn::LogSeverity::Info;
30 #else
31     armnn::LogSeverity level = armnn::LogSeverity::Debug;
32 #endif
33 
34     try
35     {
36         // Configures logging for both the ARMNN library and this test program.
37         armnn::ConfigureLogging(true, true, level);
38 
39         std::vector<armnn::BackendId> computeDevice;
40         std::string modelDir;
41         std::string dataDir;
42 
43         const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
44                                           + armnn::BackendRegistryInstance().GetBackendIdsAsString();
45 
46         cxxopts::Options in_options("MultipleNetworksCifar10",
47                                     "Run multiple networks inference tests using Cifar-10 data.");
48 
49         try
50         {
51             // Adds generic options needed for all inference tests.
52             in_options.add_options()
53                 ("h,help", "Display help messages")
54                 ("m,model-dir", "Path to directory containing the Cifar10 model file",
55                  cxxopts::value<std::string>(modelDir))
56                 ("c,compute", backendsMessage.c_str(),
57                  cxxopts::value<std::vector<armnn::BackendId>>(computeDevice)->default_value("CpuAcc,CpuRef"))
58                 ("d,data-dir", "Path to directory containing the Cifar10 test data",
59                  cxxopts::value<std::string>(dataDir));
60 
61             auto result = in_options.parse(argc, argv);
62 
63             if(result.count("help") > 0)
64             {
65                 std::cout << in_options.help() << std::endl;
66                 return EXIT_FAILURE;
67             }
68 
69             //ensure mandatory parameters given
70             std::string mandatorySingleParameters[] = {"model-dir", "data-dir"};
71             for (auto param : mandatorySingleParameters)
72             {
73                 if(result.count(param) > 0)
74                 {
75                     std::string dir = result[param].as<std::string>();
76 
77                     if(!ValidateDirectory(dir)) {
78                         return EXIT_FAILURE;
79                     }
80                 } else {
81                     std::cerr << "Parameter \'--" << param << "\' is required but missing." << std::endl;
82                     return EXIT_FAILURE;
83                 }
84             }
85         }
86         catch (const cxxopts::OptionException& e)
87         {
88             std::cerr << e.what() << std::endl << in_options.help() << std::endl;
89             return EXIT_FAILURE;
90         }
91 
92         fs::path modelPath = fs::path(modelDir + "/cifar10_tf.prototxt");
93 
94         // Create runtime
95         // This will also load dynamic backend in case that the dynamic backend path is specified
96         armnn::IRuntime::CreationOptions options;
97         armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
98 
99         // Check if the requested backend are all valid
100         std::string invalidBackends;
101         if (!CheckRequestedBackendsAreValid(computeDevice, armnn::Optional<std::string&>(invalidBackends)))
102         {
103             ARMNN_LOG(fatal) << "The list of preferred devices contains invalid backend IDs: "
104                              << invalidBackends;
105             return EXIT_FAILURE;
106         }
107 
108         // Loads networks.
109         armnn::Status status;
110         struct Net
111         {
112             Net(armnn::NetworkId netId,
113                 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& in,
114                 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& out)
115             : m_Network(netId)
116             , m_InputBindingInfo(in)
117             , m_OutputBindingInfo(out)
118             {}
119 
120             armnn::NetworkId m_Network;
121             std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_InputBindingInfo;
122             std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_OutputBindingInfo;
123         };
124         std::vector<Net> networks;
125 
126         armnnTfLiteParser::ITfLiteParserPtr parser(armnnTfLiteParser::ITfLiteParserPtr::Create());
127 
128         const int networksCount = 4;
129         for (int i = 0; i < networksCount; ++i)
130         {
131             // Creates a network from a file on the disk.
132             armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile(modelPath.c_str(), {}, { "prob" });
133 
134             // Optimizes the network.
135             armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
136             try
137             {
138                 optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
139             }
140             catch (const armnn::Exception& e)
141             {
142                 std::stringstream message;
143                 message << "armnn::Exception ("<<e.what()<<") caught from optimize.";
144                 ARMNN_LOG(fatal) << message.str();
145                 return EXIT_FAILURE;
146             }
147 
148             // Loads the network into the runtime.
149             armnn::NetworkId networkId;
150             status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
151             if (status == armnn::Status::Failure)
152             {
153                 ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to load network";
154                 return EXIT_FAILURE;
155             }
156 
157             networks.emplace_back(networkId,
158                 parser->GetNetworkInputBindingInfo("data"),
159                 parser->GetNetworkOutputBindingInfo("prob"));
160         }
161 
162         // Loads a test case and tests inference.
163         if (!ValidateDirectory(dataDir))
164         {
165             return EXIT_FAILURE;
166         }
167         Cifar10Database cifar10(dataDir);
168 
169         for (unsigned int i = 0; i < 3; ++i)
170         {
171             // Loads test case data (including image data).
172             std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i);
173 
174             // Tests inference.
175             std::vector<TContainer> outputs;
176             outputs.reserve(networksCount);
177 
178             for (unsigned int j = 0; j < networksCount; ++j)
179             {
180                 outputs.push_back(std::vector<float>(10));
181             }
182 
183             for (unsigned int k = 0; k < networksCount; ++k)
184             {
185                 std::vector<armnn::BindingPointInfo> inputBindings  = { networks[k].m_InputBindingInfo  };
186                 std::vector<armnn::BindingPointInfo> outputBindings = { networks[k].m_OutputBindingInfo };
187 
188                 std::vector<TContainer> inputDataContainers = { testCaseData->m_InputImage };
189                 std::vector<TContainer> outputDataContainers = { outputs[k] };
190 
191                 status = runtime->EnqueueWorkload(networks[k].m_Network,
192                     armnnUtils::MakeInputTensors(inputBindings, inputDataContainers),
193                     armnnUtils::MakeOutputTensors(outputBindings, outputDataContainers));
194                 if (status == armnn::Status::Failure)
195                 {
196                     ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to enqueue workload";
197                     return EXIT_FAILURE;
198                 }
199             }
200 
201             // Compares outputs.
202             std::vector<float> output0 = mapbox::util::get<std::vector<float>>(outputs[0]);
203 
204             for (unsigned int k = 1; k < networksCount; ++k)
205             {
206                 std::vector<float> outputK = mapbox::util::get<std::vector<float>>(outputs[k]);
207 
208                 if (!std::equal(output0.begin(), output0.end(), outputK.begin(), outputK.end()))
209                 {
210                     ARMNN_LOG(error) << "Multiple networks inference failed!";
211                     return EXIT_FAILURE;
212                 }
213             }
214         }
215 
216         ARMNN_LOG(info) << "Multiple networks inference ran successfully!";
217         return EXIT_SUCCESS;
218     }
219     catch (const armnn::Exception& e)
220     {
221         // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
222         // exception of type std::length_error.
223         // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
224         std::cerr << "Armnn Error: " << e.what() << std::endl;
225         return EXIT_FAILURE;
226     }
227     catch (const std::exception& e)
228     {
229         // Coverity fix: various boost exceptions can be thrown by methods called by this test.
230         std::cerr << "WARNING: MultipleNetworksCifar10: An error has occurred when running the "
231                      "multiple networks inference tests: " << e.what() << std::endl;
232         return EXIT_FAILURE;
233     }
234 }
235