1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #include "ModelAccuracyChecker.hpp" 6 #include <armnnUtils/TContainer.hpp> 7 8 #include <doctest/doctest.h> 9 10 #include <iostream> 11 #include <string> 12 13 using namespace armnnUtils; 14 15 namespace { 16 struct TestHelper 17 { GetValidationLabelSet__anon05e136c50111::TestHelper18 const std::map<std::string, std::string> GetValidationLabelSet() 19 { 20 std::map<std::string, std::string> validationLabelSet; 21 validationLabelSet.insert(std::make_pair("val_01.JPEG", "goldfinch")); 22 validationLabelSet.insert(std::make_pair("val_02.JPEG", "magpie")); 23 validationLabelSet.insert(std::make_pair("val_03.JPEG", "brambling")); 24 validationLabelSet.insert(std::make_pair("val_04.JPEG", "robin")); 25 validationLabelSet.insert(std::make_pair("val_05.JPEG", "indigo bird")); 26 validationLabelSet.insert(std::make_pair("val_06.JPEG", "ostrich")); 27 validationLabelSet.insert(std::make_pair("val_07.JPEG", "jay")); 28 validationLabelSet.insert(std::make_pair("val_08.JPEG", "snowbird")); 29 validationLabelSet.insert(std::make_pair("val_09.JPEG", "house finch")); 30 validationLabelSet.insert(std::make_pair("val_09.JPEG", "bulbul")); 31 32 return validationLabelSet; 33 } GetModelOutputLabels__anon05e136c50111::TestHelper34 const std::vector<armnnUtils::LabelCategoryNames> GetModelOutputLabels() 35 { 36 const std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels = 37 { 38 {"ostrich", "Struthio camelus"}, 39 {"brambling", "Fringilla montifringilla"}, 40 {"goldfinch", "Carduelis carduelis"}, 41 {"house finch", "linnet", "Carpodacus mexicanus"}, 42 {"junco", "snowbird"}, 43 {"indigo bunting", "indigo finch", "indigo bird", "Passerina cyanea"}, 44 {"robin", "American robin", "Turdus migratorius"}, 45 {"bulbul"}, 46 {"jay"}, 47 {"magpie"} 48 }; 49 return modelOutputLabels; 50 } 51 }; 52 } 53 54 TEST_SUITE("ModelAccuracyCheckerTest") 55 { 56 57 TEST_CASE_FIXTURE(TestHelper, "TestFloat32OutputTensorAccuracy") 58 { 59 ModelAccuracyChecker checker(GetValidationLabelSet(), GetModelOutputLabels()); 60 61 // Add image 1 and check accuracy 62 std::vector<float> inferenceOutputVector1 = {0.05f, 0.10f, 0.70f, 0.15f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; 63 armnnUtils::TContainer inference1Container(inferenceOutputVector1); 64 std::vector<armnnUtils::TContainer> outputTensor1; 65 outputTensor1.push_back(inference1Container); 66 67 std::string imageName = "val_01.JPEG"; 68 checker.AddImageResult<armnnUtils::TContainer>(imageName, outputTensor1); 69 70 // Top 1 Accuracy 71 float totalAccuracy = checker.GetAccuracy(1); 72 CHECK(totalAccuracy == 100.0f); 73 74 // Add image 2 and check accuracy 75 std::vector<float> inferenceOutputVector2 = {0.10f, 0.0f, 0.0f, 0.0f, 0.05f, 0.70f, 0.0f, 0.0f, 0.0f, 0.15f}; 76 armnnUtils::TContainer inference2Container(inferenceOutputVector2); 77 std::vector<armnnUtils::TContainer> outputTensor2; 78 outputTensor2.push_back(inference2Container); 79 80 imageName = "val_02.JPEG"; 81 checker.AddImageResult<armnnUtils::TContainer>(imageName, outputTensor2); 82 83 // Top 1 Accuracy 84 totalAccuracy = checker.GetAccuracy(1); 85 CHECK(totalAccuracy == 50.0f); 86 87 // Top 2 Accuracy 88 totalAccuracy = checker.GetAccuracy(2); 89 CHECK(totalAccuracy == 100.0f); 90 91 // Add image 3 and check accuracy 92 std::vector<float> inferenceOutputVector3 = {0.0f, 0.10f, 0.0f, 0.0f, 0.05f, 0.70f, 0.0f, 0.0f, 0.0f, 0.15f}; 93 armnnUtils::TContainer inference3Container(inferenceOutputVector3); 94 std::vector<armnnUtils::TContainer> outputTensor3; 95 outputTensor3.push_back(inference3Container); 96 97 imageName = "val_03.JPEG"; 98 checker.AddImageResult<armnnUtils::TContainer>(imageName, outputTensor3); 99 100 // Top 1 Accuracy 101 totalAccuracy = checker.GetAccuracy(1); 102 CHECK(totalAccuracy == 33.3333321f); 103 104 // Top 2 Accuracy 105 totalAccuracy = checker.GetAccuracy(2); 106 CHECK(totalAccuracy == 66.6666641f); 107 108 // Top 3 Accuracy 109 totalAccuracy = checker.GetAccuracy(3); 110 CHECK(totalAccuracy == 100.0f); 111 } 112 113 } 114