xref: /aosp_15_r20/external/armnn/tests/InferenceTest.inl (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker//
2*89c4ff92SAndroid Build Coastguard Worker// Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker// SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker//
#include "InferenceTest.hpp"

#include <armnn/Utils.hpp>
#include <armnn/utility/Assert.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnnUtils/TContainer.hpp>

#include "CxxoptsUtils.hpp"

#include <cxxopts/cxxopts.hpp>
#include <fmt/format.h>

#include <algorithm>
#include <array>
#include <chrono>
#include <cstddef>
#include <exception>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <map>
#include <utility>
22*89c4ff92SAndroid Build Coastguard Worker
23*89c4ff92SAndroid Build Coastguard Workerusing namespace std;
24*89c4ff92SAndroid Build Coastguard Workerusing namespace std::chrono;
25*89c4ff92SAndroid Build Coastguard Workerusing namespace armnn::test;
26*89c4ff92SAndroid Build Coastguard Worker
27*89c4ff92SAndroid Build Coastguard Workernamespace armnn
28*89c4ff92SAndroid Build Coastguard Worker{
29*89c4ff92SAndroid Build Coastguard Workernamespace test
30*89c4ff92SAndroid Build Coastguard Worker{
31*89c4ff92SAndroid Build Coastguard Worker
32*89c4ff92SAndroid Build Coastguard Workertemplate <typename TTestCaseDatabase, typename TModel>
33*89c4ff92SAndroid Build Coastguard WorkerClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
34*89c4ff92SAndroid Build Coastguard Worker    int& numInferencesRef,
35*89c4ff92SAndroid Build Coastguard Worker    int& numCorrectInferencesRef,
36*89c4ff92SAndroid Build Coastguard Worker    const std::vector<unsigned int>& validationPredictions,
37*89c4ff92SAndroid Build Coastguard Worker    std::vector<unsigned int>* validationPredictionsOut,
38*89c4ff92SAndroid Build Coastguard Worker    TModel& model,
39*89c4ff92SAndroid Build Coastguard Worker    unsigned int testCaseId,
40*89c4ff92SAndroid Build Coastguard Worker    unsigned int label,
41*89c4ff92SAndroid Build Coastguard Worker    std::vector<typename TModel::DataType> modelInput)
42*89c4ff92SAndroid Build Coastguard Worker    : InferenceModelTestCase<TModel>(
43*89c4ff92SAndroid Build Coastguard Worker            model, testCaseId, std::vector<armnnUtils::TContainer>{ modelInput }, { model.GetOutputSize() })
44*89c4ff92SAndroid Build Coastguard Worker    , m_Label(label)
45*89c4ff92SAndroid Build Coastguard Worker    , m_QuantizationParams(model.GetQuantizationParams())
46*89c4ff92SAndroid Build Coastguard Worker    , m_NumInferencesRef(numInferencesRef)
47*89c4ff92SAndroid Build Coastguard Worker    , m_NumCorrectInferencesRef(numCorrectInferencesRef)
48*89c4ff92SAndroid Build Coastguard Worker    , m_ValidationPredictions(validationPredictions)
49*89c4ff92SAndroid Build Coastguard Worker    , m_ValidationPredictionsOut(validationPredictionsOut)
50*89c4ff92SAndroid Build Coastguard Worker{
51*89c4ff92SAndroid Build Coastguard Worker}
52*89c4ff92SAndroid Build Coastguard Worker
53*89c4ff92SAndroid Build Coastguard Workerstruct ClassifierResultProcessor
54*89c4ff92SAndroid Build Coastguard Worker{
55*89c4ff92SAndroid Build Coastguard Worker    using ResultMap = std::map<float,int>;
56*89c4ff92SAndroid Build Coastguard Worker
57*89c4ff92SAndroid Build Coastguard Worker    ClassifierResultProcessor(float scale, int offset)
58*89c4ff92SAndroid Build Coastguard Worker        : m_Scale(scale)
59*89c4ff92SAndroid Build Coastguard Worker        , m_Offset(offset)
60*89c4ff92SAndroid Build Coastguard Worker    {}
61*89c4ff92SAndroid Build Coastguard Worker
62*89c4ff92SAndroid Build Coastguard Worker    void operator()(const std::vector<float>& values)
63*89c4ff92SAndroid Build Coastguard Worker    {
64*89c4ff92SAndroid Build Coastguard Worker        SortPredictions(values, [](float value)
65*89c4ff92SAndroid Build Coastguard Worker                                {
66*89c4ff92SAndroid Build Coastguard Worker                                    return value;
67*89c4ff92SAndroid Build Coastguard Worker                                });
68*89c4ff92SAndroid Build Coastguard Worker    }
69*89c4ff92SAndroid Build Coastguard Worker
70*89c4ff92SAndroid Build Coastguard Worker    void operator()(const std::vector<int8_t>& values)
71*89c4ff92SAndroid Build Coastguard Worker    {
72*89c4ff92SAndroid Build Coastguard Worker        SortPredictions(values, [](int8_t value)
73*89c4ff92SAndroid Build Coastguard Worker        {
74*89c4ff92SAndroid Build Coastguard Worker            return value;
75*89c4ff92SAndroid Build Coastguard Worker        });
76*89c4ff92SAndroid Build Coastguard Worker    }
77*89c4ff92SAndroid Build Coastguard Worker
78*89c4ff92SAndroid Build Coastguard Worker    void operator()(const std::vector<uint8_t>& values)
79*89c4ff92SAndroid Build Coastguard Worker    {
80*89c4ff92SAndroid Build Coastguard Worker        auto& scale = m_Scale;
81*89c4ff92SAndroid Build Coastguard Worker        auto& offset = m_Offset;
82*89c4ff92SAndroid Build Coastguard Worker        SortPredictions(values, [&scale, &offset](uint8_t value)
83*89c4ff92SAndroid Build Coastguard Worker                                {
84*89c4ff92SAndroid Build Coastguard Worker                                    return armnn::Dequantize(value, scale, offset);
85*89c4ff92SAndroid Build Coastguard Worker                                });
86*89c4ff92SAndroid Build Coastguard Worker    }
87*89c4ff92SAndroid Build Coastguard Worker
88*89c4ff92SAndroid Build Coastguard Worker    void operator()(const std::vector<int>& values)
89*89c4ff92SAndroid Build Coastguard Worker    {
90*89c4ff92SAndroid Build Coastguard Worker        IgnoreUnused(values);
91*89c4ff92SAndroid Build Coastguard Worker        ARMNN_ASSERT_MSG(false, "Non-float predictions output not supported.");
92*89c4ff92SAndroid Build Coastguard Worker    }
93*89c4ff92SAndroid Build Coastguard Worker
94*89c4ff92SAndroid Build Coastguard Worker    ResultMap& GetResultMap() { return m_ResultMap; }
95*89c4ff92SAndroid Build Coastguard Worker
96*89c4ff92SAndroid Build Coastguard Workerprivate:
97*89c4ff92SAndroid Build Coastguard Worker    template<typename Container, typename Delegate>
98*89c4ff92SAndroid Build Coastguard Worker    void SortPredictions(const Container& c, Delegate delegate)
99*89c4ff92SAndroid Build Coastguard Worker    {
100*89c4ff92SAndroid Build Coastguard Worker        int index = 0;
101*89c4ff92SAndroid Build Coastguard Worker        for (const auto& value : c)
102*89c4ff92SAndroid Build Coastguard Worker        {
103*89c4ff92SAndroid Build Coastguard Worker            int classification = index++;
104*89c4ff92SAndroid Build Coastguard Worker            // Take the first class with each probability
105*89c4ff92SAndroid Build Coastguard Worker            // This avoids strange results when looping over batched results produced
106*89c4ff92SAndroid Build Coastguard Worker            // with identical test data.
107*89c4ff92SAndroid Build Coastguard Worker            ResultMap::iterator lb = m_ResultMap.lower_bound(value);
108*89c4ff92SAndroid Build Coastguard Worker
109*89c4ff92SAndroid Build Coastguard Worker            if (lb == m_ResultMap.end() || !m_ResultMap.key_comp()(value, lb->first))
110*89c4ff92SAndroid Build Coastguard Worker            {
111*89c4ff92SAndroid Build Coastguard Worker                // If the key is not already in the map, insert it.
112*89c4ff92SAndroid Build Coastguard Worker                m_ResultMap.insert(lb, ResultMap::value_type(delegate(value), classification));
113*89c4ff92SAndroid Build Coastguard Worker            }
114*89c4ff92SAndroid Build Coastguard Worker        }
115*89c4ff92SAndroid Build Coastguard Worker    }
116*89c4ff92SAndroid Build Coastguard Worker
117*89c4ff92SAndroid Build Coastguard Worker    ResultMap m_ResultMap;
118*89c4ff92SAndroid Build Coastguard Worker
119*89c4ff92SAndroid Build Coastguard Worker    float m_Scale=0.0f;
120*89c4ff92SAndroid Build Coastguard Worker    int m_Offset=0;
121*89c4ff92SAndroid Build Coastguard Worker};
122*89c4ff92SAndroid Build Coastguard Worker
123*89c4ff92SAndroid Build Coastguard Workertemplate <typename TTestCaseDatabase, typename TModel>
124*89c4ff92SAndroid Build Coastguard WorkerTestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
125*89c4ff92SAndroid Build Coastguard Worker{
126*89c4ff92SAndroid Build Coastguard Worker    auto& output = this->GetOutputs()[0];
127*89c4ff92SAndroid Build Coastguard Worker    const auto testCaseId = this->GetTestCaseId();
128*89c4ff92SAndroid Build Coastguard Worker
129*89c4ff92SAndroid Build Coastguard Worker    ClassifierResultProcessor resultProcessor(m_QuantizationParams.first, m_QuantizationParams.second);
130*89c4ff92SAndroid Build Coastguard Worker    mapbox::util::apply_visitor(resultProcessor, output);
131*89c4ff92SAndroid Build Coastguard Worker
132*89c4ff92SAndroid Build Coastguard Worker    ARMNN_LOG(info) << "= Prediction values for test #" << testCaseId;
133*89c4ff92SAndroid Build Coastguard Worker    auto it = resultProcessor.GetResultMap().rbegin();
134*89c4ff92SAndroid Build Coastguard Worker    for (int i=0; i<5 && it != resultProcessor.GetResultMap().rend(); ++i)
135*89c4ff92SAndroid Build Coastguard Worker    {
136*89c4ff92SAndroid Build Coastguard Worker        ARMNN_LOG(info) << "Top(" << (i+1) << ") prediction is " << it->second <<
137*89c4ff92SAndroid Build Coastguard Worker          " with value: " << (it->first);
138*89c4ff92SAndroid Build Coastguard Worker        ++it;
139*89c4ff92SAndroid Build Coastguard Worker    }
140*89c4ff92SAndroid Build Coastguard Worker
141*89c4ff92SAndroid Build Coastguard Worker    unsigned int prediction = 0;
142*89c4ff92SAndroid Build Coastguard Worker    mapbox::util::apply_visitor([&](auto&& value)
143*89c4ff92SAndroid Build Coastguard Worker                         {
144*89c4ff92SAndroid Build Coastguard Worker                             prediction = armnn::numeric_cast<unsigned int>(
145*89c4ff92SAndroid Build Coastguard Worker                                     std::distance(value.begin(), std::max_element(value.begin(), value.end())));
146*89c4ff92SAndroid Build Coastguard Worker                         },
147*89c4ff92SAndroid Build Coastguard Worker                         output);
148*89c4ff92SAndroid Build Coastguard Worker
149*89c4ff92SAndroid Build Coastguard Worker    // If we're just running the defaultTestCaseIds, each one must be classified correctly.
150*89c4ff92SAndroid Build Coastguard Worker    if (params.m_IterationCount == 0 && prediction != m_Label)
151*89c4ff92SAndroid Build Coastguard Worker    {
152*89c4ff92SAndroid Build Coastguard Worker        ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
153*89c4ff92SAndroid Build Coastguard Worker            " is incorrect (should be " << m_Label << ")";
154*89c4ff92SAndroid Build Coastguard Worker        return TestCaseResult::Failed;
155*89c4ff92SAndroid Build Coastguard Worker    }
156*89c4ff92SAndroid Build Coastguard Worker
157*89c4ff92SAndroid Build Coastguard Worker    // If a validation file was provided as input, it checks that the prediction matches.
158*89c4ff92SAndroid Build Coastguard Worker    if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
159*89c4ff92SAndroid Build Coastguard Worker    {
160*89c4ff92SAndroid Build Coastguard Worker        ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
161*89c4ff92SAndroid Build Coastguard Worker            " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
162*89c4ff92SAndroid Build Coastguard Worker        return TestCaseResult::Failed;
163*89c4ff92SAndroid Build Coastguard Worker    }
164*89c4ff92SAndroid Build Coastguard Worker
165*89c4ff92SAndroid Build Coastguard Worker    // If a validation file was requested as output, it stores the predictions.
166*89c4ff92SAndroid Build Coastguard Worker    if (m_ValidationPredictionsOut)
167*89c4ff92SAndroid Build Coastguard Worker    {
168*89c4ff92SAndroid Build Coastguard Worker        m_ValidationPredictionsOut->push_back(prediction);
169*89c4ff92SAndroid Build Coastguard Worker    }
170*89c4ff92SAndroid Build Coastguard Worker
171*89c4ff92SAndroid Build Coastguard Worker    // Updates accuracy stats.
172*89c4ff92SAndroid Build Coastguard Worker    m_NumInferencesRef++;
173*89c4ff92SAndroid Build Coastguard Worker    if (prediction == m_Label)
174*89c4ff92SAndroid Build Coastguard Worker    {
175*89c4ff92SAndroid Build Coastguard Worker        m_NumCorrectInferencesRef++;
176*89c4ff92SAndroid Build Coastguard Worker    }
177*89c4ff92SAndroid Build Coastguard Worker
178*89c4ff92SAndroid Build Coastguard Worker    return TestCaseResult::Ok;
179*89c4ff92SAndroid Build Coastguard Worker}
180*89c4ff92SAndroid Build Coastguard Worker
181*89c4ff92SAndroid Build Coastguard Workertemplate <typename TDatabase, typename InferenceModel>
182*89c4ff92SAndroid Build Coastguard Workertemplate <typename TConstructDatabaseCallable, typename TConstructModelCallable>
183*89c4ff92SAndroid Build Coastguard WorkerClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
184*89c4ff92SAndroid Build Coastguard Worker    TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
185*89c4ff92SAndroid Build Coastguard Worker    : m_ConstructModel(constructModel)
186*89c4ff92SAndroid Build Coastguard Worker    , m_ConstructDatabase(constructDatabase)
187*89c4ff92SAndroid Build Coastguard Worker    , m_NumInferences(0)
188*89c4ff92SAndroid Build Coastguard Worker    , m_NumCorrectInferences(0)
189*89c4ff92SAndroid Build Coastguard Worker{
190*89c4ff92SAndroid Build Coastguard Worker}
191*89c4ff92SAndroid Build Coastguard Worker
192*89c4ff92SAndroid Build Coastguard Workertemplate <typename TDatabase, typename InferenceModel>
193*89c4ff92SAndroid Build Coastguard Workervoid ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
194*89c4ff92SAndroid Build Coastguard Worker    cxxopts::Options& options, std::vector<std::string>& required)
195*89c4ff92SAndroid Build Coastguard Worker{
196*89c4ff92SAndroid Build Coastguard Worker    options
197*89c4ff92SAndroid Build Coastguard Worker        .allow_unrecognised_options()
198*89c4ff92SAndroid Build Coastguard Worker        .add_options()
199*89c4ff92SAndroid Build Coastguard Worker            ("validation-file-in",
200*89c4ff92SAndroid Build Coastguard Worker             "Reads expected predictions from the given file and confirms they match the actual predictions.",
201*89c4ff92SAndroid Build Coastguard Worker             cxxopts::value<std::string>(m_ValidationFileIn)->default_value(""))
202*89c4ff92SAndroid Build Coastguard Worker            ("validation-file-out", "Predictions are saved to the given file for later use via --validation-file-in.",
203*89c4ff92SAndroid Build Coastguard Worker             cxxopts::value<std::string>(m_ValidationFileOut)->default_value(""))
204*89c4ff92SAndroid Build Coastguard Worker            ("d,data-dir", "Path to directory containing test data", cxxopts::value<std::string>(m_DataDir));
205*89c4ff92SAndroid Build Coastguard Worker
206*89c4ff92SAndroid Build Coastguard Worker    required.emplace_back("data-dir"); //add to required arguments to check
207*89c4ff92SAndroid Build Coastguard Worker
208*89c4ff92SAndroid Build Coastguard Worker    InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions, required);
209*89c4ff92SAndroid Build Coastguard Worker}
210*89c4ff92SAndroid Build Coastguard Worker
211*89c4ff92SAndroid Build Coastguard Workertemplate <typename TDatabase, typename InferenceModel>
212*89c4ff92SAndroid Build Coastguard Workerbool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions(
213*89c4ff92SAndroid Build Coastguard Worker        const InferenceTestOptions& commonOptions)
214*89c4ff92SAndroid Build Coastguard Worker{
215*89c4ff92SAndroid Build Coastguard Worker    if (!ValidateDirectory(m_DataDir))
216*89c4ff92SAndroid Build Coastguard Worker    {
217*89c4ff92SAndroid Build Coastguard Worker        return false;
218*89c4ff92SAndroid Build Coastguard Worker    }
219*89c4ff92SAndroid Build Coastguard Worker
220*89c4ff92SAndroid Build Coastguard Worker    ReadPredictions();
221*89c4ff92SAndroid Build Coastguard Worker
222*89c4ff92SAndroid Build Coastguard Worker    m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
223*89c4ff92SAndroid Build Coastguard Worker    if (!m_Model)
224*89c4ff92SAndroid Build Coastguard Worker    {
225*89c4ff92SAndroid Build Coastguard Worker        return false;
226*89c4ff92SAndroid Build Coastguard Worker    }
227*89c4ff92SAndroid Build Coastguard Worker
228*89c4ff92SAndroid Build Coastguard Worker    m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
229*89c4ff92SAndroid Build Coastguard Worker    if (!m_Database)
230*89c4ff92SAndroid Build Coastguard Worker    {
231*89c4ff92SAndroid Build Coastguard Worker        return false;
232*89c4ff92SAndroid Build Coastguard Worker    }
233*89c4ff92SAndroid Build Coastguard Worker
234*89c4ff92SAndroid Build Coastguard Worker    return true;
235*89c4ff92SAndroid Build Coastguard Worker}
236*89c4ff92SAndroid Build Coastguard Worker
237*89c4ff92SAndroid Build Coastguard Workertemplate <typename TDatabase, typename InferenceModel>
238*89c4ff92SAndroid Build Coastguard Workerstd::unique_ptr<IInferenceTestCase>
239*89c4ff92SAndroid Build Coastguard WorkerClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
240*89c4ff92SAndroid Build Coastguard Worker{
241*89c4ff92SAndroid Build Coastguard Worker    std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
242*89c4ff92SAndroid Build Coastguard Worker    if (testCaseData == nullptr)
243*89c4ff92SAndroid Build Coastguard Worker    {
244*89c4ff92SAndroid Build Coastguard Worker        return nullptr;
245*89c4ff92SAndroid Build Coastguard Worker    }
246*89c4ff92SAndroid Build Coastguard Worker
247*89c4ff92SAndroid Build Coastguard Worker    return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
248*89c4ff92SAndroid Build Coastguard Worker        m_NumInferences,
249*89c4ff92SAndroid Build Coastguard Worker        m_NumCorrectInferences,
250*89c4ff92SAndroid Build Coastguard Worker        m_ValidationPredictions,
251*89c4ff92SAndroid Build Coastguard Worker        m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
252*89c4ff92SAndroid Build Coastguard Worker        *m_Model,
253*89c4ff92SAndroid Build Coastguard Worker        testCaseId,
254*89c4ff92SAndroid Build Coastguard Worker        testCaseData->m_Label,
255*89c4ff92SAndroid Build Coastguard Worker        std::move(testCaseData->m_InputImage));
256*89c4ff92SAndroid Build Coastguard Worker}
257*89c4ff92SAndroid Build Coastguard Worker
258*89c4ff92SAndroid Build Coastguard Workertemplate <typename TDatabase, typename InferenceModel>
259*89c4ff92SAndroid Build Coastguard Workerbool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
260*89c4ff92SAndroid Build Coastguard Worker{
261*89c4ff92SAndroid Build Coastguard Worker    const double accuracy = armnn::numeric_cast<double>(m_NumCorrectInferences) /
262*89c4ff92SAndroid Build Coastguard Worker        armnn::numeric_cast<double>(m_NumInferences);
263*89c4ff92SAndroid Build Coastguard Worker    ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
264*89c4ff92SAndroid Build Coastguard Worker
265*89c4ff92SAndroid Build Coastguard Worker    // If a validation file was requested as output, the predictions are saved to it.
266*89c4ff92SAndroid Build Coastguard Worker    if (!m_ValidationFileOut.empty())
267*89c4ff92SAndroid Build Coastguard Worker    {
268*89c4ff92SAndroid Build Coastguard Worker        std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
269*89c4ff92SAndroid Build Coastguard Worker        if (validationFileOut.good())
270*89c4ff92SAndroid Build Coastguard Worker        {
271*89c4ff92SAndroid Build Coastguard Worker            for (const unsigned int prediction : m_ValidationPredictionsOut)
272*89c4ff92SAndroid Build Coastguard Worker            {
273*89c4ff92SAndroid Build Coastguard Worker                validationFileOut << prediction << std::endl;
274*89c4ff92SAndroid Build Coastguard Worker            }
275*89c4ff92SAndroid Build Coastguard Worker        }
276*89c4ff92SAndroid Build Coastguard Worker        else
277*89c4ff92SAndroid Build Coastguard Worker        {
278*89c4ff92SAndroid Build Coastguard Worker            ARMNN_LOG(error) << "Failed to open output validation file: " << m_ValidationFileOut;
279*89c4ff92SAndroid Build Coastguard Worker            return false;
280*89c4ff92SAndroid Build Coastguard Worker        }
281*89c4ff92SAndroid Build Coastguard Worker    }
282*89c4ff92SAndroid Build Coastguard Worker
283*89c4ff92SAndroid Build Coastguard Worker    return true;
284*89c4ff92SAndroid Build Coastguard Worker}
285*89c4ff92SAndroid Build Coastguard Worker
286*89c4ff92SAndroid Build Coastguard Workertemplate <typename TDatabase, typename InferenceModel>
287*89c4ff92SAndroid Build Coastguard Workervoid ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
288*89c4ff92SAndroid Build Coastguard Worker{
289*89c4ff92SAndroid Build Coastguard Worker    // Reads the expected predictions from the input validation file (if provided).
290*89c4ff92SAndroid Build Coastguard Worker    if (!m_ValidationFileIn.empty())
291*89c4ff92SAndroid Build Coastguard Worker    {
292*89c4ff92SAndroid Build Coastguard Worker        std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
293*89c4ff92SAndroid Build Coastguard Worker        if (validationFileIn.good())
294*89c4ff92SAndroid Build Coastguard Worker        {
295*89c4ff92SAndroid Build Coastguard Worker            while (!validationFileIn.eof())
296*89c4ff92SAndroid Build Coastguard Worker            {
297*89c4ff92SAndroid Build Coastguard Worker                unsigned int i;
298*89c4ff92SAndroid Build Coastguard Worker                validationFileIn >> i;
299*89c4ff92SAndroid Build Coastguard Worker                m_ValidationPredictions.emplace_back(i);
300*89c4ff92SAndroid Build Coastguard Worker            }
301*89c4ff92SAndroid Build Coastguard Worker        }
302*89c4ff92SAndroid Build Coastguard Worker        else
303*89c4ff92SAndroid Build Coastguard Worker        {
304*89c4ff92SAndroid Build Coastguard Worker            throw armnn::Exception(fmt::format("Failed to open input validation file: {}"
305*89c4ff92SAndroid Build Coastguard Worker                , m_ValidationFileIn));
306*89c4ff92SAndroid Build Coastguard Worker        }
307*89c4ff92SAndroid Build Coastguard Worker    }
308*89c4ff92SAndroid Build Coastguard Worker}
309*89c4ff92SAndroid Build Coastguard Worker
310*89c4ff92SAndroid Build Coastguard Workertemplate<typename TConstructTestCaseProvider>
311*89c4ff92SAndroid Build Coastguard Workerint InferenceTestMain(int argc,
312*89c4ff92SAndroid Build Coastguard Worker    char* argv[],
313*89c4ff92SAndroid Build Coastguard Worker    const std::vector<unsigned int>& defaultTestCaseIds,
314*89c4ff92SAndroid Build Coastguard Worker    TConstructTestCaseProvider constructTestCaseProvider)
315*89c4ff92SAndroid Build Coastguard Worker{
316*89c4ff92SAndroid Build Coastguard Worker    // Configures logging for both the ARMNN library and this test program.
317*89c4ff92SAndroid Build Coastguard Worker#ifdef NDEBUG
318*89c4ff92SAndroid Build Coastguard Worker    armnn::LogSeverity level = armnn::LogSeverity::Info;
319*89c4ff92SAndroid Build Coastguard Worker#else
320*89c4ff92SAndroid Build Coastguard Worker    armnn::LogSeverity level = armnn::LogSeverity::Debug;
321*89c4ff92SAndroid Build Coastguard Worker#endif
322*89c4ff92SAndroid Build Coastguard Worker    armnn::ConfigureLogging(true, true, level);
323*89c4ff92SAndroid Build Coastguard Worker
324*89c4ff92SAndroid Build Coastguard Worker    try
325*89c4ff92SAndroid Build Coastguard Worker    {
326*89c4ff92SAndroid Build Coastguard Worker        std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
327*89c4ff92SAndroid Build Coastguard Worker        if (!testCaseProvider)
328*89c4ff92SAndroid Build Coastguard Worker        {
329*89c4ff92SAndroid Build Coastguard Worker            return 1;
330*89c4ff92SAndroid Build Coastguard Worker        }
331*89c4ff92SAndroid Build Coastguard Worker
332*89c4ff92SAndroid Build Coastguard Worker        InferenceTestOptions inferenceTestOptions;
333*89c4ff92SAndroid Build Coastguard Worker        if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
334*89c4ff92SAndroid Build Coastguard Worker        {
335*89c4ff92SAndroid Build Coastguard Worker            return 1;
336*89c4ff92SAndroid Build Coastguard Worker        }
337*89c4ff92SAndroid Build Coastguard Worker
338*89c4ff92SAndroid Build Coastguard Worker        const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
339*89c4ff92SAndroid Build Coastguard Worker        return success ? 0 : 1;
340*89c4ff92SAndroid Build Coastguard Worker    }
341*89c4ff92SAndroid Build Coastguard Worker    catch (armnn::Exception const& e)
342*89c4ff92SAndroid Build Coastguard Worker    {
343*89c4ff92SAndroid Build Coastguard Worker        ARMNN_LOG(fatal) << "Armnn Error: " << e.what();
344*89c4ff92SAndroid Build Coastguard Worker        return 1;
345*89c4ff92SAndroid Build Coastguard Worker    }
346*89c4ff92SAndroid Build Coastguard Worker}
347*89c4ff92SAndroid Build Coastguard Worker
348*89c4ff92SAndroid Build Coastguard Worker//
349*89c4ff92SAndroid Build Coastguard Worker// This function allows us to create a classifier inference test based on:
350*89c4ff92SAndroid Build Coastguard Worker//  - a model file name
351*89c4ff92SAndroid Build Coastguard Worker//  - which can be a binary or a text file for protobuf formats
352*89c4ff92SAndroid Build Coastguard Worker//  - an input tensor name
353*89c4ff92SAndroid Build Coastguard Worker//  - an output tensor name
354*89c4ff92SAndroid Build Coastguard Worker//  - a set of test case ids
355*89c4ff92SAndroid Build Coastguard Worker//  - a callback method which creates an object that can return images
356*89c4ff92SAndroid Build Coastguard Worker//    called 'Database' in these tests
357*89c4ff92SAndroid Build Coastguard Worker//  - and an input tensor shape
358*89c4ff92SAndroid Build Coastguard Worker//
359*89c4ff92SAndroid Build Coastguard Workertemplate<typename TDatabase,
360*89c4ff92SAndroid Build Coastguard Worker         typename TParser,
361*89c4ff92SAndroid Build Coastguard Worker         typename TConstructDatabaseCallable>
362*89c4ff92SAndroid Build Coastguard Workerint ClassifierInferenceTestMain(int argc,
363*89c4ff92SAndroid Build Coastguard Worker                                char* argv[],
364*89c4ff92SAndroid Build Coastguard Worker                                const char* modelFilename,
365*89c4ff92SAndroid Build Coastguard Worker                                bool isModelBinary,
366*89c4ff92SAndroid Build Coastguard Worker                                const char* inputBindingName,
367*89c4ff92SAndroid Build Coastguard Worker                                const char* outputBindingName,
368*89c4ff92SAndroid Build Coastguard Worker                                const std::vector<unsigned int>& defaultTestCaseIds,
369*89c4ff92SAndroid Build Coastguard Worker                                TConstructDatabaseCallable constructDatabase,
370*89c4ff92SAndroid Build Coastguard Worker                                const armnn::TensorShape* inputTensorShape)
371*89c4ff92SAndroid Build Coastguard Worker
372*89c4ff92SAndroid Build Coastguard Worker{
373*89c4ff92SAndroid Build Coastguard Worker    ARMNN_ASSERT(modelFilename);
374*89c4ff92SAndroid Build Coastguard Worker    ARMNN_ASSERT(inputBindingName);
375*89c4ff92SAndroid Build Coastguard Worker    ARMNN_ASSERT(outputBindingName);
376*89c4ff92SAndroid Build Coastguard Worker
377*89c4ff92SAndroid Build Coastguard Worker    return InferenceTestMain(argc, argv, defaultTestCaseIds,
378*89c4ff92SAndroid Build Coastguard Worker        [=]
379*89c4ff92SAndroid Build Coastguard Worker        ()
380*89c4ff92SAndroid Build Coastguard Worker        {
381*89c4ff92SAndroid Build Coastguard Worker            using InferenceModel = InferenceModel<TParser, typename TDatabase::DataType>;
382*89c4ff92SAndroid Build Coastguard Worker            using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;
383*89c4ff92SAndroid Build Coastguard Worker
384*89c4ff92SAndroid Build Coastguard Worker            return make_unique<TestCaseProvider>(constructDatabase,
385*89c4ff92SAndroid Build Coastguard Worker                [&]
386*89c4ff92SAndroid Build Coastguard Worker                (const InferenceTestOptions &commonOptions,
387*89c4ff92SAndroid Build Coastguard Worker                 typename InferenceModel::CommandLineOptions modelOptions)
388*89c4ff92SAndroid Build Coastguard Worker                {
389*89c4ff92SAndroid Build Coastguard Worker                    if (!ValidateDirectory(modelOptions.m_ModelDir))
390*89c4ff92SAndroid Build Coastguard Worker                    {
391*89c4ff92SAndroid Build Coastguard Worker                        return std::unique_ptr<InferenceModel>();
392*89c4ff92SAndroid Build Coastguard Worker                    }
393*89c4ff92SAndroid Build Coastguard Worker
394*89c4ff92SAndroid Build Coastguard Worker                    typename InferenceModel::Params modelParams;
395*89c4ff92SAndroid Build Coastguard Worker                    modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
396*89c4ff92SAndroid Build Coastguard Worker                    modelParams.m_InputBindings  = { inputBindingName };
397*89c4ff92SAndroid Build Coastguard Worker                    modelParams.m_OutputBindings = { outputBindingName };
398*89c4ff92SAndroid Build Coastguard Worker
399*89c4ff92SAndroid Build Coastguard Worker                    if (inputTensorShape)
400*89c4ff92SAndroid Build Coastguard Worker                    {
401*89c4ff92SAndroid Build Coastguard Worker                        modelParams.m_InputShapes.push_back(*inputTensorShape);
402*89c4ff92SAndroid Build Coastguard Worker                    }
403*89c4ff92SAndroid Build Coastguard Worker
404*89c4ff92SAndroid Build Coastguard Worker                    modelParams.m_IsModelBinary = isModelBinary;
405*89c4ff92SAndroid Build Coastguard Worker                    modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds();
406*89c4ff92SAndroid Build Coastguard Worker                    modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
407*89c4ff92SAndroid Build Coastguard Worker                    modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
408*89c4ff92SAndroid Build Coastguard Worker
409*89c4ff92SAndroid Build Coastguard Worker                    return std::make_unique<InferenceModel>(modelParams,
410*89c4ff92SAndroid Build Coastguard Worker                                                            commonOptions.m_EnableProfiling,
411*89c4ff92SAndroid Build Coastguard Worker                                                            commonOptions.m_DynamicBackendsPath);
412*89c4ff92SAndroid Build Coastguard Worker            });
413*89c4ff92SAndroid Build Coastguard Worker        });
414*89c4ff92SAndroid Build Coastguard Worker}
415*89c4ff92SAndroid Build Coastguard Worker
416*89c4ff92SAndroid Build Coastguard Worker} // namespace test
417*89c4ff92SAndroid Build Coastguard Worker} // namespace armnn
418