xref: /aosp_15_r20/external/armnn/python/pyarmnn/examples/object_detection/ssd.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
2# SPDX-License-Identifier: MIT
3
4"""
5Contains functions specific to decoding and processing inference results for SSD Mobilenet V1 models.
6"""
7
8import cv2
9import numpy as np
10
11
def ssd_processing(output: np.ndarray, confidence_threshold=0.60):
    """
    Gets class, bounding box positions and confidence from the four outputs of the SSD model.

    Args:
        output: Sequence of the four network outputs, in the order
            (box positions, classes, confidences, number of detections).
            Each has a leading batch dimension; only batch entry 0 is used.
        confidence_threshold: Only detections whose confidence is strictly
            greater than this value are kept.

    Returns:
        A list of detected objects, each a tuple (class, box, confidence), where
        box is a new array in the order [x_min, y_min, x_max, y_max]. The input
        tensors are not modified.

    Raises:
        RuntimeError: If output does not contain exactly four tensors.
    """
    if len(output) != 4:
        raise RuntimeError('Number of outputs from SSD model does not equal 4')

    # Drop the batch dimension from each output tensor.
    position, classification, confidence, num_detections = [index[0] for index in output]

    # Never index past the end of the score tensor, even if the reported count is larger.
    count = min(int(num_detections), len(confidence))

    detections = []
    for i in range(count):
        if confidence[i] > confidence_threshold:
            # The model's box order is [y_min, x_min, y_max, x_max]; reorder to
            # [x_min, y_min, x_max, y_max]. Fancy indexing yields a copy, so the
            # caller's position tensor is left untouched (a slice view + in-place
            # swap would corrupt it and make repeated calls inconsistent).
            box = position[i, [1, 0, 3, 2]]
            detections.append((classification[i], box, confidence[i]))
    return detections
38
39
def ssd_resize_factor(video: cv2.VideoCapture):
    """
    Return the multiplier that maps SSD box coordinates onto the output frame.

    The factor is the larger of the capture's frame height and frame width.
    NOTE(review): this presumably assumes box coordinates are normalized and
    the network input was letterboxed to a square; confirm against the caller.

    Args:
        video: Video capture object describing the data source.

    Returns:
        The larger frame dimension, as reported by ``video.get``.
    """
    # Query height first, then width, matching the original call order.
    dimensions = (
        video.get(cv2.CAP_PROP_FRAME_HEIGHT),
        video.get(cv2.CAP_PROP_FRAME_WIDTH),
    )
    return max(dimensions)
54