1*89c4ff92SAndroid Build Coastguard Worker// 2*89c4ff92SAndroid Build Coastguard Worker// Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker// SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker// 5*89c4ff92SAndroid Build Coastguard Worker#include "InferenceTest.hpp" 6*89c4ff92SAndroid Build Coastguard Worker 7*89c4ff92SAndroid Build Coastguard Worker#include <armnn/Utils.hpp> 8*89c4ff92SAndroid Build Coastguard Worker#include <armnn/utility/Assert.hpp> 9*89c4ff92SAndroid Build Coastguard Worker#include <armnn/utility/NumericCast.hpp> 10*89c4ff92SAndroid Build Coastguard Worker#include <armnnUtils/TContainer.hpp> 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker#include "CxxoptsUtils.hpp" 13*89c4ff92SAndroid Build Coastguard Worker 14*89c4ff92SAndroid Build Coastguard Worker#include <cxxopts/cxxopts.hpp> 15*89c4ff92SAndroid Build Coastguard Worker#include <fmt/format.h> 16*89c4ff92SAndroid Build Coastguard Worker 17*89c4ff92SAndroid Build Coastguard Worker#include <fstream> 18*89c4ff92SAndroid Build Coastguard Worker#include <iostream> 19*89c4ff92SAndroid Build Coastguard Worker#include <iomanip> 20*89c4ff92SAndroid Build Coastguard Worker#include <array> 21*89c4ff92SAndroid Build Coastguard Worker#include <chrono> 22*89c4ff92SAndroid Build Coastguard Worker 23*89c4ff92SAndroid Build Coastguard Workerusing namespace std; 24*89c4ff92SAndroid Build Coastguard Workerusing namespace std::chrono; 25*89c4ff92SAndroid Build Coastguard Workerusing namespace armnn::test; 26*89c4ff92SAndroid Build Coastguard Worker 27*89c4ff92SAndroid Build Coastguard Workernamespace armnn 28*89c4ff92SAndroid Build Coastguard Worker{ 29*89c4ff92SAndroid Build Coastguard Workernamespace test 30*89c4ff92SAndroid Build Coastguard Worker{ 31*89c4ff92SAndroid Build Coastguard Worker 32*89c4ff92SAndroid Build Coastguard Workertemplate <typename TTestCaseDatabase, typename TModel> 33*89c4ff92SAndroid Build Coastguard WorkerClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase( 34*89c4ff92SAndroid Build Coastguard Worker int& numInferencesRef, 35*89c4ff92SAndroid Build Coastguard Worker int& numCorrectInferencesRef, 36*89c4ff92SAndroid Build Coastguard Worker const std::vector<unsigned int>& validationPredictions, 37*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int>* validationPredictionsOut, 38*89c4ff92SAndroid Build Coastguard Worker TModel& model, 39*89c4ff92SAndroid Build Coastguard Worker unsigned int testCaseId, 40*89c4ff92SAndroid Build Coastguard Worker unsigned int label, 41*89c4ff92SAndroid Build Coastguard Worker std::vector<typename TModel::DataType> modelInput) 42*89c4ff92SAndroid Build Coastguard Worker : InferenceModelTestCase<TModel>( 43*89c4ff92SAndroid Build Coastguard Worker model, testCaseId, std::vector<armnnUtils::TContainer>{ modelInput }, { model.GetOutputSize() }) 44*89c4ff92SAndroid Build Coastguard Worker , m_Label(label) 45*89c4ff92SAndroid Build Coastguard Worker , m_QuantizationParams(model.GetQuantizationParams()) 46*89c4ff92SAndroid Build Coastguard Worker , m_NumInferencesRef(numInferencesRef) 47*89c4ff92SAndroid Build Coastguard Worker , m_NumCorrectInferencesRef(numCorrectInferencesRef) 48*89c4ff92SAndroid Build Coastguard Worker , m_ValidationPredictions(validationPredictions) 49*89c4ff92SAndroid Build Coastguard Worker , m_ValidationPredictionsOut(validationPredictionsOut) 50*89c4ff92SAndroid Build Coastguard Worker{ 51*89c4ff92SAndroid Build Coastguard Worker} 52*89c4ff92SAndroid Build Coastguard Worker 53*89c4ff92SAndroid Build Coastguard Workerstruct ClassifierResultProcessor 54*89c4ff92SAndroid Build Coastguard Worker{ 55*89c4ff92SAndroid Build Coastguard Worker using ResultMap = std::map<float,int>; 56*89c4ff92SAndroid Build Coastguard Worker 57*89c4ff92SAndroid Build Coastguard Worker ClassifierResultProcessor(float scale, int offset) 58*89c4ff92SAndroid Build Coastguard Worker : m_Scale(scale) 59*89c4ff92SAndroid Build Coastguard Worker , m_Offset(offset) 60*89c4ff92SAndroid Build Coastguard Worker {} 61*89c4ff92SAndroid Build Coastguard Worker 62*89c4ff92SAndroid Build Coastguard Worker void operator()(const std::vector<float>& values) 63*89c4ff92SAndroid Build Coastguard Worker { 64*89c4ff92SAndroid Build Coastguard Worker SortPredictions(values, [](float value) 65*89c4ff92SAndroid Build Coastguard Worker { 66*89c4ff92SAndroid Build Coastguard Worker return value; 67*89c4ff92SAndroid Build Coastguard Worker }); 68*89c4ff92SAndroid Build Coastguard Worker } 69*89c4ff92SAndroid Build Coastguard Worker 70*89c4ff92SAndroid Build Coastguard Worker void operator()(const std::vector<int8_t>& values) 71*89c4ff92SAndroid Build Coastguard Worker { 72*89c4ff92SAndroid Build Coastguard Worker SortPredictions(values, [](int8_t value) 73*89c4ff92SAndroid Build Coastguard Worker { 74*89c4ff92SAndroid Build Coastguard Worker return value; 75*89c4ff92SAndroid Build Coastguard Worker }); 76*89c4ff92SAndroid Build Coastguard Worker } 77*89c4ff92SAndroid Build Coastguard Worker 78*89c4ff92SAndroid Build Coastguard Worker void operator()(const std::vector<uint8_t>& values) 79*89c4ff92SAndroid Build Coastguard Worker { 80*89c4ff92SAndroid Build Coastguard Worker auto& scale = m_Scale; 81*89c4ff92SAndroid Build Coastguard Worker auto& offset = m_Offset; 82*89c4ff92SAndroid Build Coastguard Worker SortPredictions(values, [&scale, &offset](uint8_t value) 83*89c4ff92SAndroid Build Coastguard Worker { 84*89c4ff92SAndroid Build Coastguard Worker return armnn::Dequantize(value, scale, offset); 85*89c4ff92SAndroid Build Coastguard Worker }); 86*89c4ff92SAndroid Build Coastguard Worker } 87*89c4ff92SAndroid Build Coastguard Worker 88*89c4ff92SAndroid Build Coastguard Worker void operator()(const std::vector<int>& values) 89*89c4ff92SAndroid Build Coastguard Worker { 90*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(values); 91*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Non-float predictions output not supported."); 92*89c4ff92SAndroid Build Coastguard Worker } 93*89c4ff92SAndroid Build Coastguard Worker 94*89c4ff92SAndroid Build Coastguard Worker ResultMap& GetResultMap() { return m_ResultMap; } 95*89c4ff92SAndroid Build Coastguard Worker 96*89c4ff92SAndroid Build Coastguard Workerprivate: 97*89c4ff92SAndroid Build Coastguard Worker template<typename Container, typename Delegate> 98*89c4ff92SAndroid Build Coastguard Worker void SortPredictions(const Container& c, Delegate delegate) 99*89c4ff92SAndroid Build Coastguard Worker { 100*89c4ff92SAndroid Build Coastguard Worker int index = 0; 101*89c4ff92SAndroid Build Coastguard Worker for (const auto& value : c) 102*89c4ff92SAndroid Build Coastguard Worker { 103*89c4ff92SAndroid Build Coastguard Worker int classification = index++; 104*89c4ff92SAndroid Build Coastguard Worker // Take the first class with each probability 105*89c4ff92SAndroid Build Coastguard Worker // This avoids strange results when looping over batched results produced 106*89c4ff92SAndroid Build Coastguard Worker // with identical test data. 107*89c4ff92SAndroid Build Coastguard Worker ResultMap::iterator lb = m_ResultMap.lower_bound(value); 108*89c4ff92SAndroid Build Coastguard Worker 109*89c4ff92SAndroid Build Coastguard Worker if (lb == m_ResultMap.end() || !m_ResultMap.key_comp()(value, lb->first)) 110*89c4ff92SAndroid Build Coastguard Worker { 111*89c4ff92SAndroid Build Coastguard Worker // If the key is not already in the map, insert it. 112*89c4ff92SAndroid Build Coastguard Worker m_ResultMap.insert(lb, ResultMap::value_type(delegate(value), classification)); 113*89c4ff92SAndroid Build Coastguard Worker } 114*89c4ff92SAndroid Build Coastguard Worker } 115*89c4ff92SAndroid Build Coastguard Worker } 116*89c4ff92SAndroid Build Coastguard Worker 117*89c4ff92SAndroid Build Coastguard Worker ResultMap m_ResultMap; 118*89c4ff92SAndroid Build Coastguard Worker 119*89c4ff92SAndroid Build Coastguard Worker float m_Scale=0.0f; 120*89c4ff92SAndroid Build Coastguard Worker int m_Offset=0; 121*89c4ff92SAndroid Build Coastguard Worker}; 122*89c4ff92SAndroid Build Coastguard Worker 123*89c4ff92SAndroid Build Coastguard Workertemplate <typename TTestCaseDatabase, typename TModel> 124*89c4ff92SAndroid Build Coastguard WorkerTestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params) 125*89c4ff92SAndroid Build Coastguard Worker{ 126*89c4ff92SAndroid Build Coastguard Worker auto& output = this->GetOutputs()[0]; 127*89c4ff92SAndroid Build Coastguard Worker const auto testCaseId = this->GetTestCaseId(); 128*89c4ff92SAndroid Build Coastguard Worker 129*89c4ff92SAndroid Build Coastguard Worker ClassifierResultProcessor resultProcessor(m_QuantizationParams.first, m_QuantizationParams.second); 130*89c4ff92SAndroid Build Coastguard Worker mapbox::util::apply_visitor(resultProcessor, output); 131*89c4ff92SAndroid Build Coastguard Worker 132*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "= Prediction values for test #" << testCaseId; 133*89c4ff92SAndroid Build Coastguard Worker auto it = resultProcessor.GetResultMap().rbegin(); 134*89c4ff92SAndroid Build Coastguard Worker for (int i=0; i<5 && it != resultProcessor.GetResultMap().rend(); ++i) 135*89c4ff92SAndroid Build Coastguard Worker { 136*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Top(" << (i+1) << ") prediction is " << it->second << 137*89c4ff92SAndroid Build Coastguard Worker " with value: " << (it->first); 138*89c4ff92SAndroid Build Coastguard Worker ++it; 139*89c4ff92SAndroid Build Coastguard Worker } 140*89c4ff92SAndroid Build Coastguard Worker 141*89c4ff92SAndroid Build Coastguard Worker unsigned int prediction = 0; 142*89c4ff92SAndroid Build Coastguard Worker mapbox::util::apply_visitor([&](auto&& value) 143*89c4ff92SAndroid Build Coastguard Worker { 144*89c4ff92SAndroid Build Coastguard Worker prediction = armnn::numeric_cast<unsigned int>( 145*89c4ff92SAndroid Build Coastguard Worker std::distance(value.begin(), std::max_element(value.begin(), value.end()))); 146*89c4ff92SAndroid Build Coastguard Worker }, 147*89c4ff92SAndroid Build Coastguard Worker output); 148*89c4ff92SAndroid Build Coastguard Worker 149*89c4ff92SAndroid Build Coastguard Worker // If we're just running the defaultTestCaseIds, each one must be classified correctly. 150*89c4ff92SAndroid Build Coastguard Worker if (params.m_IterationCount == 0 && prediction != m_Label) 151*89c4ff92SAndroid Build Coastguard Worker { 152*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" << 153*89c4ff92SAndroid Build Coastguard Worker " is incorrect (should be " << m_Label << ")"; 154*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Failed; 155*89c4ff92SAndroid Build Coastguard Worker } 156*89c4ff92SAndroid Build Coastguard Worker 157*89c4ff92SAndroid Build Coastguard Worker // If a validation file was provided as input, it checks that the prediction matches. 158*89c4ff92SAndroid Build Coastguard Worker if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId]) 159*89c4ff92SAndroid Build Coastguard Worker { 160*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" << 161*89c4ff92SAndroid Build Coastguard Worker " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")"; 162*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Failed; 163*89c4ff92SAndroid Build Coastguard Worker } 164*89c4ff92SAndroid Build Coastguard Worker 165*89c4ff92SAndroid Build Coastguard Worker // If a validation file was requested as output, it stores the predictions. 166*89c4ff92SAndroid Build Coastguard Worker if (m_ValidationPredictionsOut) 167*89c4ff92SAndroid Build Coastguard Worker { 168*89c4ff92SAndroid Build Coastguard Worker m_ValidationPredictionsOut->push_back(prediction); 169*89c4ff92SAndroid Build Coastguard Worker } 170*89c4ff92SAndroid Build Coastguard Worker 171*89c4ff92SAndroid Build Coastguard Worker // Updates accuracy stats. 172*89c4ff92SAndroid Build Coastguard Worker m_NumInferencesRef++; 173*89c4ff92SAndroid Build Coastguard Worker if (prediction == m_Label) 174*89c4ff92SAndroid Build Coastguard Worker { 175*89c4ff92SAndroid Build Coastguard Worker m_NumCorrectInferencesRef++; 176*89c4ff92SAndroid Build Coastguard Worker } 177*89c4ff92SAndroid Build Coastguard Worker 178*89c4ff92SAndroid Build Coastguard Worker return TestCaseResult::Ok; 179*89c4ff92SAndroid Build Coastguard Worker} 180*89c4ff92SAndroid Build Coastguard Worker 181*89c4ff92SAndroid Build Coastguard Workertemplate <typename TDatabase, typename InferenceModel> 182*89c4ff92SAndroid Build Coastguard Workertemplate <typename TConstructDatabaseCallable, typename TConstructModelCallable> 183*89c4ff92SAndroid Build Coastguard WorkerClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider( 184*89c4ff92SAndroid Build Coastguard Worker TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel) 185*89c4ff92SAndroid Build Coastguard Worker : m_ConstructModel(constructModel) 186*89c4ff92SAndroid Build Coastguard Worker , m_ConstructDatabase(constructDatabase) 187*89c4ff92SAndroid Build Coastguard Worker , m_NumInferences(0) 188*89c4ff92SAndroid Build Coastguard Worker , m_NumCorrectInferences(0) 189*89c4ff92SAndroid Build Coastguard Worker{ 190*89c4ff92SAndroid Build Coastguard Worker} 191*89c4ff92SAndroid Build Coastguard Worker 192*89c4ff92SAndroid Build Coastguard Workertemplate <typename TDatabase, typename InferenceModel> 193*89c4ff92SAndroid Build Coastguard Workervoid ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions( 194*89c4ff92SAndroid Build Coastguard Worker cxxopts::Options& options, std::vector<std::string>& required) 195*89c4ff92SAndroid Build Coastguard Worker{ 196*89c4ff92SAndroid Build Coastguard Worker options 197*89c4ff92SAndroid Build Coastguard Worker .allow_unrecognised_options() 198*89c4ff92SAndroid Build Coastguard Worker .add_options() 199*89c4ff92SAndroid Build Coastguard Worker ("validation-file-in", 200*89c4ff92SAndroid Build Coastguard Worker "Reads expected predictions from the given file and confirms they match the actual predictions.", 201*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::string>(m_ValidationFileIn)->default_value("")) 202*89c4ff92SAndroid Build Coastguard Worker ("validation-file-out", "Predictions are saved to the given file for later use via --validation-file-in.", 203*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::string>(m_ValidationFileOut)->default_value("")) 204*89c4ff92SAndroid Build Coastguard Worker ("d,data-dir", "Path to directory containing test data", cxxopts::value<std::string>(m_DataDir)); 205*89c4ff92SAndroid Build Coastguard Worker 206*89c4ff92SAndroid Build Coastguard Worker required.emplace_back("data-dir"); //add to required arguments to check 207*89c4ff92SAndroid Build Coastguard Worker 208*89c4ff92SAndroid Build Coastguard Worker InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions, required); 209*89c4ff92SAndroid Build Coastguard Worker} 210*89c4ff92SAndroid Build Coastguard Worker 211*89c4ff92SAndroid Build Coastguard Workertemplate <typename TDatabase, typename InferenceModel> 212*89c4ff92SAndroid Build Coastguard Workerbool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions( 213*89c4ff92SAndroid Build Coastguard Worker const InferenceTestOptions& commonOptions) 214*89c4ff92SAndroid Build Coastguard Worker{ 215*89c4ff92SAndroid Build Coastguard Worker if (!ValidateDirectory(m_DataDir)) 216*89c4ff92SAndroid Build Coastguard Worker { 217*89c4ff92SAndroid Build Coastguard Worker return false; 218*89c4ff92SAndroid Build Coastguard Worker } 219*89c4ff92SAndroid Build Coastguard Worker 220*89c4ff92SAndroid Build Coastguard Worker ReadPredictions(); 221*89c4ff92SAndroid Build Coastguard Worker 222*89c4ff92SAndroid Build Coastguard Worker m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions); 223*89c4ff92SAndroid Build Coastguard Worker if (!m_Model) 224*89c4ff92SAndroid Build Coastguard Worker { 225*89c4ff92SAndroid Build Coastguard Worker return false; 226*89c4ff92SAndroid Build Coastguard Worker } 227*89c4ff92SAndroid Build Coastguard Worker 228*89c4ff92SAndroid Build Coastguard Worker m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model)); 229*89c4ff92SAndroid Build Coastguard Worker if (!m_Database) 230*89c4ff92SAndroid Build Coastguard Worker { 231*89c4ff92SAndroid Build Coastguard Worker return false; 232*89c4ff92SAndroid Build Coastguard Worker } 233*89c4ff92SAndroid Build Coastguard Worker 234*89c4ff92SAndroid Build Coastguard Worker return true; 235*89c4ff92SAndroid Build Coastguard Worker} 236*89c4ff92SAndroid Build Coastguard Worker 237*89c4ff92SAndroid Build Coastguard Workertemplate <typename TDatabase, typename InferenceModel> 238*89c4ff92SAndroid Build Coastguard Workerstd::unique_ptr<IInferenceTestCase> 239*89c4ff92SAndroid Build Coastguard WorkerClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId) 240*89c4ff92SAndroid Build Coastguard Worker{ 241*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId); 242*89c4ff92SAndroid Build Coastguard Worker if (testCaseData == nullptr) 243*89c4ff92SAndroid Build Coastguard Worker { 244*89c4ff92SAndroid Build Coastguard Worker return nullptr; 245*89c4ff92SAndroid Build Coastguard Worker } 246*89c4ff92SAndroid Build Coastguard Worker 247*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>( 248*89c4ff92SAndroid Build Coastguard Worker m_NumInferences, 249*89c4ff92SAndroid Build Coastguard Worker m_NumCorrectInferences, 250*89c4ff92SAndroid Build Coastguard Worker m_ValidationPredictions, 251*89c4ff92SAndroid Build Coastguard Worker m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut, 252*89c4ff92SAndroid Build Coastguard Worker *m_Model, 253*89c4ff92SAndroid Build Coastguard Worker testCaseId, 254*89c4ff92SAndroid Build Coastguard Worker testCaseData->m_Label, 255*89c4ff92SAndroid Build Coastguard Worker std::move(testCaseData->m_InputImage)); 256*89c4ff92SAndroid Build Coastguard Worker} 257*89c4ff92SAndroid Build Coastguard Worker 258*89c4ff92SAndroid Build Coastguard Workertemplate <typename TDatabase, typename InferenceModel> 259*89c4ff92SAndroid Build Coastguard Workerbool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished() 260*89c4ff92SAndroid Build Coastguard Worker{ 261*89c4ff92SAndroid Build Coastguard Worker const double accuracy = armnn::numeric_cast<double>(m_NumCorrectInferences) / 262*89c4ff92SAndroid Build Coastguard Worker armnn::numeric_cast<double>(m_NumInferences); 263*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy; 264*89c4ff92SAndroid Build Coastguard Worker 265*89c4ff92SAndroid Build Coastguard Worker // If a validation file was requested as output, the predictions are saved to it. 266*89c4ff92SAndroid Build Coastguard Worker if (!m_ValidationFileOut.empty()) 267*89c4ff92SAndroid Build Coastguard Worker { 268*89c4ff92SAndroid Build Coastguard Worker std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out); 269*89c4ff92SAndroid Build Coastguard Worker if (validationFileOut.good()) 270*89c4ff92SAndroid Build Coastguard Worker { 271*89c4ff92SAndroid Build Coastguard Worker for (const unsigned int prediction : m_ValidationPredictionsOut) 272*89c4ff92SAndroid Build Coastguard Worker { 273*89c4ff92SAndroid Build Coastguard Worker validationFileOut << prediction << std::endl; 274*89c4ff92SAndroid Build Coastguard Worker } 275*89c4ff92SAndroid Build Coastguard Worker } 276*89c4ff92SAndroid Build Coastguard Worker else 277*89c4ff92SAndroid Build Coastguard Worker { 278*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Failed to open output validation file: " << m_ValidationFileOut; 279*89c4ff92SAndroid Build Coastguard Worker return false; 280*89c4ff92SAndroid Build Coastguard Worker } 281*89c4ff92SAndroid Build Coastguard Worker } 282*89c4ff92SAndroid Build Coastguard Worker 283*89c4ff92SAndroid Build Coastguard Worker return true; 284*89c4ff92SAndroid Build Coastguard Worker} 285*89c4ff92SAndroid Build Coastguard Worker 286*89c4ff92SAndroid Build Coastguard Workertemplate <typename TDatabase, typename InferenceModel> 287*89c4ff92SAndroid Build Coastguard Workervoid ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions() 288*89c4ff92SAndroid Build Coastguard Worker{ 289*89c4ff92SAndroid Build Coastguard Worker // Reads the expected predictions from the input validation file (if provided). 290*89c4ff92SAndroid Build Coastguard Worker if (!m_ValidationFileIn.empty()) 291*89c4ff92SAndroid Build Coastguard Worker { 292*89c4ff92SAndroid Build Coastguard Worker std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in); 293*89c4ff92SAndroid Build Coastguard Worker if (validationFileIn.good()) 294*89c4ff92SAndroid Build Coastguard Worker { 295*89c4ff92SAndroid Build Coastguard Worker while (!validationFileIn.eof()) 296*89c4ff92SAndroid Build Coastguard Worker { 297*89c4ff92SAndroid Build Coastguard Worker unsigned int i; 298*89c4ff92SAndroid Build Coastguard Worker validationFileIn >> i; 299*89c4ff92SAndroid Build Coastguard Worker m_ValidationPredictions.emplace_back(i); 300*89c4ff92SAndroid Build Coastguard Worker } 301*89c4ff92SAndroid Build Coastguard Worker } 302*89c4ff92SAndroid Build Coastguard Worker else 303*89c4ff92SAndroid Build Coastguard Worker { 304*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(fmt::format("Failed to open input validation file: {}" 305*89c4ff92SAndroid Build Coastguard Worker , m_ValidationFileIn)); 306*89c4ff92SAndroid Build Coastguard Worker } 307*89c4ff92SAndroid Build Coastguard Worker } 308*89c4ff92SAndroid Build Coastguard Worker} 309*89c4ff92SAndroid Build Coastguard Worker 310*89c4ff92SAndroid Build Coastguard Workertemplate<typename TConstructTestCaseProvider> 311*89c4ff92SAndroid Build Coastguard Workerint InferenceTestMain(int argc, 312*89c4ff92SAndroid Build Coastguard Worker char* argv[], 313*89c4ff92SAndroid Build Coastguard Worker const std::vector<unsigned int>& defaultTestCaseIds, 314*89c4ff92SAndroid Build Coastguard Worker TConstructTestCaseProvider constructTestCaseProvider) 315*89c4ff92SAndroid Build Coastguard Worker{ 316*89c4ff92SAndroid Build Coastguard Worker // Configures logging for both the ARMNN library and this test program. 317*89c4ff92SAndroid Build Coastguard Worker#ifdef NDEBUG 318*89c4ff92SAndroid Build Coastguard Worker armnn::LogSeverity level = armnn::LogSeverity::Info; 319*89c4ff92SAndroid Build Coastguard Worker#else 320*89c4ff92SAndroid Build Coastguard Worker armnn::LogSeverity level = armnn::LogSeverity::Debug; 321*89c4ff92SAndroid Build Coastguard Worker#endif 322*89c4ff92SAndroid Build Coastguard Worker armnn::ConfigureLogging(true, true, level); 323*89c4ff92SAndroid Build Coastguard Worker 324*89c4ff92SAndroid Build Coastguard Worker try 325*89c4ff92SAndroid Build Coastguard Worker { 326*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider(); 327*89c4ff92SAndroid Build Coastguard Worker if (!testCaseProvider) 328*89c4ff92SAndroid Build Coastguard Worker { 329*89c4ff92SAndroid Build Coastguard Worker return 1; 330*89c4ff92SAndroid Build Coastguard Worker } 331*89c4ff92SAndroid Build Coastguard Worker 332*89c4ff92SAndroid Build Coastguard Worker InferenceTestOptions inferenceTestOptions; 333*89c4ff92SAndroid Build Coastguard Worker if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions)) 334*89c4ff92SAndroid Build Coastguard Worker { 335*89c4ff92SAndroid Build Coastguard Worker return 1; 336*89c4ff92SAndroid Build Coastguard Worker } 337*89c4ff92SAndroid Build Coastguard Worker 338*89c4ff92SAndroid Build Coastguard Worker const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider); 339*89c4ff92SAndroid Build Coastguard Worker return success ? 0 : 1; 340*89c4ff92SAndroid Build Coastguard Worker } 341*89c4ff92SAndroid Build Coastguard Worker catch (armnn::Exception const& e) 342*89c4ff92SAndroid Build Coastguard Worker { 343*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "Armnn Error: " << e.what(); 344*89c4ff92SAndroid Build Coastguard Worker return 1; 345*89c4ff92SAndroid Build Coastguard Worker } 346*89c4ff92SAndroid Build Coastguard Worker} 347*89c4ff92SAndroid Build Coastguard Worker 348*89c4ff92SAndroid Build Coastguard Worker// 349*89c4ff92SAndroid Build Coastguard Worker// This function allows us to create a classifier inference test based on: 350*89c4ff92SAndroid Build Coastguard Worker// - a model file name 351*89c4ff92SAndroid Build Coastguard Worker// - which can be a binary or a text file for protobuf formats 352*89c4ff92SAndroid Build Coastguard Worker// - an input tensor name 353*89c4ff92SAndroid Build Coastguard Worker// - an output tensor name 354*89c4ff92SAndroid Build Coastguard Worker// - a set of test case ids 355*89c4ff92SAndroid Build Coastguard Worker// - a callback method which creates an object that can return images 356*89c4ff92SAndroid Build Coastguard Worker// called 'Database' in these tests 357*89c4ff92SAndroid Build Coastguard Worker// - and an input tensor shape 358*89c4ff92SAndroid Build Coastguard Worker// 359*89c4ff92SAndroid Build Coastguard Workertemplate<typename TDatabase, 360*89c4ff92SAndroid Build Coastguard Worker typename TParser, 361*89c4ff92SAndroid Build Coastguard Worker typename TConstructDatabaseCallable> 362*89c4ff92SAndroid Build Coastguard Workerint ClassifierInferenceTestMain(int argc, 363*89c4ff92SAndroid Build Coastguard Worker char* argv[], 364*89c4ff92SAndroid Build Coastguard Worker const char* modelFilename, 365*89c4ff92SAndroid Build Coastguard Worker bool isModelBinary, 366*89c4ff92SAndroid Build Coastguard Worker const char* inputBindingName, 367*89c4ff92SAndroid Build Coastguard Worker const char* outputBindingName, 368*89c4ff92SAndroid Build Coastguard Worker const std::vector<unsigned int>& defaultTestCaseIds, 369*89c4ff92SAndroid Build Coastguard Worker TConstructDatabaseCallable constructDatabase, 370*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape* inputTensorShape) 371*89c4ff92SAndroid Build Coastguard Worker 372*89c4ff92SAndroid Build Coastguard Worker{ 373*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(modelFilename); 374*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(inputBindingName); 375*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(outputBindingName); 376*89c4ff92SAndroid Build Coastguard Worker 377*89c4ff92SAndroid Build Coastguard Worker return InferenceTestMain(argc, argv, defaultTestCaseIds, 378*89c4ff92SAndroid Build Coastguard Worker [=] 379*89c4ff92SAndroid Build Coastguard Worker () 380*89c4ff92SAndroid Build Coastguard Worker { 381*89c4ff92SAndroid Build Coastguard Worker using InferenceModel = InferenceModel<TParser, typename TDatabase::DataType>; 382*89c4ff92SAndroid Build Coastguard Worker using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>; 383*89c4ff92SAndroid Build Coastguard Worker 384*89c4ff92SAndroid Build Coastguard Worker return make_unique<TestCaseProvider>(constructDatabase, 385*89c4ff92SAndroid Build Coastguard Worker [&] 386*89c4ff92SAndroid Build Coastguard Worker (const InferenceTestOptions &commonOptions, 387*89c4ff92SAndroid Build Coastguard Worker typename InferenceModel::CommandLineOptions modelOptions) 388*89c4ff92SAndroid Build Coastguard Worker { 389*89c4ff92SAndroid Build Coastguard Worker if (!ValidateDirectory(modelOptions.m_ModelDir)) 390*89c4ff92SAndroid Build Coastguard Worker { 391*89c4ff92SAndroid Build Coastguard Worker return std::unique_ptr<InferenceModel>(); 392*89c4ff92SAndroid Build Coastguard Worker } 393*89c4ff92SAndroid Build Coastguard Worker 394*89c4ff92SAndroid Build Coastguard Worker typename InferenceModel::Params modelParams; 395*89c4ff92SAndroid Build Coastguard Worker modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename; 396*89c4ff92SAndroid Build Coastguard Worker modelParams.m_InputBindings = { inputBindingName }; 397*89c4ff92SAndroid Build Coastguard Worker modelParams.m_OutputBindings = { outputBindingName }; 398*89c4ff92SAndroid Build Coastguard Worker 399*89c4ff92SAndroid Build Coastguard Worker if (inputTensorShape) 400*89c4ff92SAndroid Build Coastguard Worker { 401*89c4ff92SAndroid Build Coastguard Worker modelParams.m_InputShapes.push_back(*inputTensorShape); 402*89c4ff92SAndroid Build Coastguard Worker } 403*89c4ff92SAndroid Build Coastguard Worker 404*89c4ff92SAndroid Build Coastguard Worker modelParams.m_IsModelBinary = isModelBinary; 405*89c4ff92SAndroid Build Coastguard Worker modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds(); 406*89c4ff92SAndroid Build Coastguard Worker modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel; 407*89c4ff92SAndroid Build Coastguard Worker modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode; 408*89c4ff92SAndroid Build Coastguard Worker 409*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<InferenceModel>(modelParams, 410*89c4ff92SAndroid Build Coastguard Worker commonOptions.m_EnableProfiling, 411*89c4ff92SAndroid Build Coastguard Worker commonOptions.m_DynamicBackendsPath); 412*89c4ff92SAndroid Build Coastguard Worker }); 413*89c4ff92SAndroid Build Coastguard Worker }); 414*89c4ff92SAndroid Build Coastguard Worker} 415*89c4ff92SAndroid Build Coastguard Worker 416*89c4ff92SAndroid Build Coastguard Worker} // namespace test 417*89c4ff92SAndroid Build Coastguard Worker} // namespace armnn 418