xref: /aosp_15_r20/external/armnn/python/pyarmnn/examples/common/cv_utils.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1# Copyright © 2020-2022 Arm Ltd and Contributors. All rights reserved.
2# SPDX-License-Identifier: MIT
3
4"""
5This file contains helper functions for reading video/image data and
6 pre/postprocessing of video/image data using OpenCV.
7"""
8
9import os
10
11import cv2
12import numpy as np
13
14
def preprocess(frame: np.ndarray, input_data_type, input_data_shape: tuple, is_normalised: bool,
               keep_aspect_ratio: bool=True):
    """
    Converts a BGR frame into a tensor matching the model input layer.

    The frame is converted from BGR to RGB, resized to the model resolution,
    cast to the model data type and given a leading axis of size one.

    Args:
        frame: Captured frame from video.
        input_data_type: Data type of the model input layer. np.float32 selects
            a float32 tensor; any other value selects uint8.
        input_data_shape: Shape of the model input layer. Entries 1 and 2 are
            read as height and width.
        is_normalised: If true and the data type is np.float32, the values are
            divided by 255.
        keep_aspect_ratio: If true the frame is resized with
            resize_with_aspect_ratio(), otherwise it is resized directly to
            the model width and height.

    Returns:
        Input tensor.
    """
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    if not keep_aspect_ratio:
        # Plain resize straight to the model height/width
        target_height = input_data_shape[1]
        target_width = input_data_shape[2]
        scaled = cv2.resize(rgb_frame, (target_width, target_height))
    else:
        # Resize to model resolution while keeping the frame proportions
        scaled = resize_with_aspect_ratio(rgb_frame, input_data_shape)

    # Pick the tensor data type; only float input is ever normalised
    target_dtype = np.uint8
    if np.float32 == input_data_type:
        target_dtype = np.float32
        if is_normalised:
            scaled = scaled.astype("float32")/255

    # Convert data type, then add the leading axis
    typed = np.asarray(scaled, dtype=target_dtype)
    tensor = np.expand_dims(typed, axis=0)
    assert tensor.shape == input_data_shape

    return tensor
53
54
def resize_with_aspect_ratio(frame: np.ndarray, input_data_shape: tuple):
    """
    Resizes frame while maintaining aspect ratio, padding any empty space.

    The frame is scaled to fit inside the model input resolution and kept in
    the top-left corner; leftover space at the bottom or right is filled with
    black pixels.

    Args:
        frame: Captured frame.
        input_data_shape: Contains shape of model input layer. Must unpack
            into four values, the second and third being height and width.

    Returns:
        Frame resized to the size of model input layer.
    """
    frame_height, frame_width = frame.shape[:2]
    _, model_height, model_width, _ = input_data_shape

    aspect_ratio = frame_width / frame_height
    model_aspect_ratio = model_width / model_height

    # Compare against the model's own aspect ratio (not 1.0) so that the
    # resized frame always fits inside non-square model inputs as well.
    # For square inputs this is the same decision as comparing against 1.0.
    if aspect_ratio >= model_aspect_ratio:
        # Frame is relatively wider than the model input: width is the limit
        new_width = model_width
        new_height = int(model_width / aspect_ratio)
    else:
        # Frame is relatively taller than the model input: height is the limit
        new_height = model_height
        new_width = int(model_height * aspect_ratio)

    # Keep both sizes within [1, model size]: a zero-sized resize target or a
    # negative border width is rejected by OpenCV.
    new_height = min(model_height, max(1, new_height))
    new_width = min(model_width, max(1, new_width))
    b_padding = model_height - new_height
    r_padding = model_width - new_width

    # Resize and pad any empty space
    frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
    frame = cv2.copyMakeBorder(frame, top=0, bottom=b_padding, left=0, right=r_padding,
                               borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0])
    return frame
81
82
def create_video_writer(video: cv2.VideoCapture, video_path: str, output_path: str):
    """
    Creates a video writer object to write processed frames to file.

    The output file is named 'object_detection_demo' plus the extension of
    video_path. If that name already exists, a counter in brackets is added
    until an unused name is found. Codec, frame rate and frame size are read
    from the capture object.

    Args:
        video: Video capture object, contains information about data source.
        video_path: User-specified video file path.
        output_path: Optional path to save the processed video. When None the
            file name is used without a directory prefix.

    Returns:
        Video writer object.
    """
    _, ext = os.path.splitext(video_path)

    if output_path is not None:
        assert os.path.isdir(output_path), f'Output path is not a directory: {output_path}'
    output_dir = output_path if output_path is not None else str()

    # Find the first file name that is not taken yet
    i, filename = 0, os.path.join(output_dir, f'object_detection_demo{ext}')
    while os.path.exists(filename):
        i += 1
        filename = os.path.join(output_dir, f'object_detection_demo({i}){ext}')

    # Pass the frame rate through unmodified: truncating it to an integer
    # (e.g. 29.97 -> 29) makes the written video play at the wrong speed.
    video_writer = cv2.VideoWriter(filename=filename,
                                   fourcc=get_source_encoding_int(video),
                                   fps=video.get(cv2.CAP_PROP_FPS),
                                   frameSize=(int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
                                              int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))))
    return video_writer
111
112
def init_video_file_capture(video_path: str, output_path: str):
    """
    Creates a video capture object from a video file.

    Args:
        video_path: User-specified video file path.
        output_path: Optional path to save the processed video.

    Returns:
        Video capture object to capture frames, video writer object to write processed
        frames to file, plus total frame count of video source to iterate through.

    Raises:
        FileNotFoundError: If video_path does not exist.
        RuntimeError: If the video file could not be opened.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f'Video file not found for: {video_path}')
    video = cv2.VideoCapture(video_path)
    # isOpened is a method: it has to be called, the bare attribute is always truthy
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture from file: {video_path}')

    try:
        video_writer = create_video_writer(video, video_path, output_path)
    except Exception:
        # Do not leave the capture open when the writer cannot be created
        video.release()
        raise
    iter_frame_count = range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))
    return video, video_writer, iter_frame_count
134
135
def init_video_stream_capture(video_source: int):
    """
    Creates a video capture object from a device.

    Args:
        video_source: Device index used to read video stream.

    Returns:
        Video capture object used to capture frames from a video stream.

    Raises:
        RuntimeError: If the device could not be opened.
    """
    video = cv2.VideoCapture(video_source)
    # isOpened is a method: it has to be called, the bare attribute is always truthy
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture for device with index: {video_source}')
    print('Processing video stream. Press \'Esc\' key to exit the demo.')
    return video
151
152
def draw_bounding_boxes(frame: np.ndarray, detections: list, resize_factor, labels: dict):
    """
    Draws a box, class label and confidence score for every detection.

    The frame is modified in place; nothing is returned.

    Args:
        frame: The original captured frame from video source.
        detections: A list of detected objects in the form [class, [box positions], confidence].
        resize_factor: Resizing factor to scale box coordinates to output frame size.
        labels: Dictionary of labels and colors keyed on the classification index.
    """
    for class_idx, box, confidence in detections:
        label_entry = labels[class_idx]
        class_name = label_entry[0].capitalize()
        color = label_entry[1]

        # Scale the box to the output frame and keep it inside the frame
        height, width = frame.shape[:2]
        left, top, right, bottom = (int(coord * resize_factor) for coord in box)
        left = max(0, left)
        top = max(0, top)
        right = min(width, right)
        bottom = min(height, bottom)

        # Outline of the detected object
        cv2.rectangle(frame, (left, top), (right, bottom), color, 2)

        # Label text; black on bright box colors, white on dark ones
        caption = f'{class_name} {confidence * 100:.1f}%'
        if sum(color) > 200:
            caption_color = (0, 0, 0)
        else:
            caption_color = (255, 255, 255)

        text_width, text_height = cv2.getTextSize(caption, cv2.FONT_HERSHEY_DUPLEX, 1, 1)[0][:2]
        label_right = left + int(0.55 * text_width)

        # Near the top edge the label goes inside the box, otherwise above it,
        # so it always stays on-screen
        if top < 25:
            label_top_left = (left, top)
            label_bottom_right = (label_right, top + text_height)
            text_origin = (left + 5, top + 16)
        else:
            label_top_left = (left, top - text_height)
            label_bottom_right = (label_right, top)
            text_origin = (left + 5, top - 5)

        # Filled label background, then the label and confidence value
        cv2.rectangle(frame, label_top_left, label_bottom_right, color, -1)
        cv2.putText(frame, caption, text_origin, cv2.FONT_HERSHEY_DUPLEX, 0.50,
                    caption_color, 1, cv2.LINE_AA)
193
194
def get_source_encoding_int(video_capture):
    """
    Reads the FOURCC property of a capture object as an integer.

    Args:
        video_capture: Object whose get() method is queried with cv2.CAP_PROP_FOURCC.

    Returns:
        The property value converted to int.
    """
    fourcc_value = video_capture.get(cv2.CAP_PROP_FOURCC)
    return int(fourcc_value)
197
198
def crop_bounding_box_object(input_frame: np.ndarray, x_min: float, y_min: float, x_max: float, y_max: float):
    """
    Creates a cropped image based on x and y coordinates.

    Coordinates are truncated to integers. The crop starts one pixel after
    the minimum coordinates and stops before the maximum coordinates.
    Coordinates outside the frame are clamped to the frame border.

    Args:
        input_frame: Image to crop
        x_min, y_min, x_max, y_max: Coordinates of the bounding box

    Returns:
        Cropped image
    """
    # Adding +1 to exclude the bounding box pixels.
    # Clamp to 0: a negative index would otherwise count from the far edge of
    # the frame and select the wrong region (or nothing at all). Values past
    # the far edge need no clamping, slicing already stops at the border.
    top = max(0, int(y_min) + 1)
    bottom = max(0, int(y_max))
    left = max(0, int(x_min) + 1)
    right = max(0, int(x_max))
    cropped_image = input_frame[top:bottom, left:right]
    return cropped_image
213