//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <ostream>
#include <vector>

namespace yolov3 {
/** Non-Maxima Suppression (NMS) configuration metadata */
struct NMSConfig {
    unsigned int num_classes{0};      /**< Number of classes in the detected boxes */
    unsigned int num_boxes{0};        /**< Number of detected boxes */
    float confidence_threshold{0.8f}; /**< Inclusion confidence threshold for a box */
    float iou_threshold{0.8f};        /**< Inclusion threshold for Intersection-Over-Union */
};

/** Box representation structure
 *
 * All coordinates default to zero, so a default-constructed Box never holds
 * indeterminate values. Aggregate initialisation in declaration order
 * (`Box{xmin, xmax, ymin, ymax}`) remains valid in C++14 and later.
 */
struct Box {
    float xmin{0.0f}; /**< X coordinate of the lower-left corner */
    float xmax{0.0f}; /**< X coordinate of the upper-right corner */
    float ymin{0.0f}; /**< Y coordinate of the lower-left corner */
    float ymax{0.0f}; /**< Y coordinate of the upper-right corner */
};

/** Detection structure
 *
 * A default-constructed Detection has a zero box, zero confidence and an
 * empty class vector.
 */
struct Detection {
    Box box{};                  /**< Detection box */
    float confidence{0.0f};     /**< Confidence of detection */
    std::vector<float> classes; /**< Per-class probabilities */
};

/** Print identified yolo detections
 *
 * @param[in, out] os         Output stream to print to
 * @param[in]      detections Detections to print
 */
void print_detection(std::ostream& os,
                     const std::vector<Detection>& detections);

/** Compare a detection object with a vector of float values
 *
 * @param[in] detection Detection object
 * @param[in] expected  Vector of expected float values
 *
 * @return Boolean to represent if they match or not
 */
bool compare_detection(const Detection& detection,
                       const std::vector<float>& expected);

/** Perform Non-Maxima Suppression on a list of given detections
 *
 * NOTE(review): the expected layout of @p detected_boxes (presumably
 * config.num_boxes records of box coordinates, confidence and
 * config.num_classes class scores) is defined by the implementation,
 * not this header — confirm there before relying on it.
 *
 * @param[in] config         Configuration metadata for NMS
 * @param[in] detected_boxes Detected boxes
 *
 * @return A vector with the final detections
 */
std::vector<Detection> nms(const NMSConfig& config,
                           const std::vector<float>& detected_boxes);
} // namespace yolov3