1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "ModelAccuracyChecker.hpp"
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <map>
11*89c4ff92SAndroid Build Coastguard Worker #include <vector>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker namespace armnnUtils
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker
ModelAccuracyChecker(const std::map<std::string,std::string> & validationLabels,const std::vector<LabelCategoryNames> & modelOutputLabels)16*89c4ff92SAndroid Build Coastguard Worker armnnUtils::ModelAccuracyChecker::ModelAccuracyChecker(const std::map<std::string, std::string>& validationLabels,
17*89c4ff92SAndroid Build Coastguard Worker const std::vector<LabelCategoryNames>& modelOutputLabels)
18*89c4ff92SAndroid Build Coastguard Worker : m_GroundTruthLabelSet(validationLabels)
19*89c4ff92SAndroid Build Coastguard Worker , m_ModelOutputLabels(modelOutputLabels)
20*89c4ff92SAndroid Build Coastguard Worker {}
21*89c4ff92SAndroid Build Coastguard Worker
GetAccuracy(unsigned int k)22*89c4ff92SAndroid Build Coastguard Worker float ModelAccuracyChecker::GetAccuracy(unsigned int k)
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker if (k > 10)
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(warning) << "Accuracy Tool only supports a maximum of Top 10 Accuracy. "
27*89c4ff92SAndroid Build Coastguard Worker "Printing Top 10 Accuracy result!";
28*89c4ff92SAndroid Build Coastguard Worker k = 10;
29*89c4ff92SAndroid Build Coastguard Worker }
30*89c4ff92SAndroid Build Coastguard Worker unsigned int total = 0;
31*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = k; i > 0; --i)
32*89c4ff92SAndroid Build Coastguard Worker {
33*89c4ff92SAndroid Build Coastguard Worker total += m_TopK[i];
34*89c4ff92SAndroid Build Coastguard Worker }
35*89c4ff92SAndroid Build Coastguard Worker return static_cast<float>(total * 100) / static_cast<float>(m_ImagesProcessed);
36*89c4ff92SAndroid Build Coastguard Worker }
37*89c4ff92SAndroid Build Coastguard Worker
// Split a string into tokens separated by a delimiter.
//
// @param originalString     The string to split.
// @param delimiter          The separator to split on; it may be longer than one character. If it is
//                           empty, no splitting happens and the whole string is treated as one token.
// @param includeEmptyToken  If true, empty tokens (between consecutive delimiters, or at either end
//                           of the string) are kept; otherwise they are dropped.
// @return The tokens, in the order they appear in originalString.
std::vector<std::string>
SplitBy(const std::string& originalString, const std::string& delimiter, bool includeEmptyToken)
{
    std::vector<std::string> tokens;
    std::size_t cur = 0;

    // An empty delimiter matches at every position without advancing `cur`, so the loop would
    // never end. Treat it as "no delimiter" instead.
    if (!delimiter.empty())
    {
        std::size_t next = 0;
        while ((next = originalString.find(delimiter, cur)) != std::string::npos)
        {
            // Skip empty tokens, unless explicitly asked to include them.
            if (next > cur || includeEmptyToken)
            {
                tokens.push_back(originalString.substr(cur, next - cur));
            }
            cur = next + delimiter.size();
        }
    }

    // The text after the last delimiter, or the whole string if no delimiter was found.
    // Skip it if it is empty, unless explicitly asked to include empty tokens.
    if (originalString.size() > cur || includeEmptyToken)
    {
        tokens.push_back(originalString.substr(cur));
    }
    return tokens;
}
62*89c4ff92SAndroid Build Coastguard Worker
63*89c4ff92SAndroid Build Coastguard Worker // Remove any preceding and trailing character specified in the characterSet.
Strip(const std::string & originalString,const std::string & characterSet)64*89c4ff92SAndroid Build Coastguard Worker std::string Strip(const std::string& originalString, const std::string& characterSet)
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(!characterSet.empty());
67*89c4ff92SAndroid Build Coastguard Worker const std::size_t firstFound = originalString.find_first_not_of(characterSet);
68*89c4ff92SAndroid Build Coastguard Worker const std::size_t lastFound = originalString.find_last_not_of(characterSet);
69*89c4ff92SAndroid Build Coastguard Worker // Return empty if the originalString is empty or the originalString contains only to-be-striped characters
70*89c4ff92SAndroid Build Coastguard Worker if (firstFound == std::string::npos || lastFound == std::string::npos)
71*89c4ff92SAndroid Build Coastguard Worker {
72*89c4ff92SAndroid Build Coastguard Worker return "";
73*89c4ff92SAndroid Build Coastguard Worker }
74*89c4ff92SAndroid Build Coastguard Worker return originalString.substr(firstFound, lastFound + 1 - firstFound);
75*89c4ff92SAndroid Build Coastguard Worker }
76*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnUtils