// xref: /aosp_15_r20/external/armnn/samples/ObjectDetection/test/PipelineTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <stdexcept>
#include <string>

#include <catch.hpp>
#include <opencv2/opencv.hpp>

#include "ObjectDetectionPipeline.hpp"
#include "Types.hpp"
9 
GetResourceFilePath(const std::string & filename)10 static std::string GetResourceFilePath(const std::string& filename)
11 {
12     std::string testResources = TEST_RESOURCE_DIR;
13     if (0 == testResources.size())
14     {
15         throw "Invalid test resources directory provided";
16     }
17     else
18     {
19         if(testResources.back() != '/')
20         {
21             return testResources + "/" + filename;
22         }
23         else
24         {
25             return testResources + filename;
26         }
27     }
28 }
29 
30 TEST_CASE("Test Network Execution SSD_MOBILE")
31 {
32     std::string testResources = TEST_RESOURCE_DIR;
33     REQUIRE(testResources != "");
34     // Create the network options
35     common::PipelineOptions options;
36     options.m_ModelFilePath = GetResourceFilePath("ssd_mobilenet_v1.tflite");
37     options.m_ModelName = "SSD_MOBILE";
38     options.m_backends = {"CpuAcc", "CpuRef"};
39 
40     od::IPipelinePtr objectDetectionPipeline = od::CreatePipeline(options);
41 
42     common::InferenceResults<float> results;
43     cv::Mat processed;
44     cv::Mat inputFrame = cv::imread(GetResourceFilePath("basketball1.png"), cv::IMREAD_COLOR);
45     cv::cvtColor(inputFrame, inputFrame, cv::COLOR_BGR2RGB);
46 
47     objectDetectionPipeline->PreProcessing(inputFrame, processed);
48 
49     CHECK(processed.type() == CV_8UC3);
50     CHECK(processed.cols == 300);
51     CHECK(processed.rows == 300);
52 
53     objectDetectionPipeline->Inference(processed, results);
54     objectDetectionPipeline->PostProcessing(results,
__anon384230110102(od::DetectedObjects detects) 55                                             [](od::DetectedObjects detects) -> void {
56                                                 CHECK(detects.size() == 2);
57                                                 CHECK(detects[0].GetLabel() == "0");
58                                             });
59 
60 }
61