xref: /aosp_15_r20/external/armnn/python/pyarmnn/examples/object_detection/style_transfer.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
2# SPDX-License-Identifier: MIT
3
4import numpy as np
5import urllib.request
6import cv2
7import network_executor_tflite
8import cv_utils
9
10
def style_transfer_postprocess(preprocessed_frame: np.ndarray, image_shape: tuple):
    """
        Converts the output tensor of the style transfer network back into an image frame.

        The leading batch axis is removed, the frame is resized to the original frame size,
        multiplied by 255, clipped to [0, 255] and converted from RGB back to BGR channel order.

        Args:
            preprocessed_frame: Output of the style transfer network; axis 0 must be a batch axis of size 1.
            image_shape: Shape of the original frame before preprocessing. Only image_shape[0] (height)
                and image_shape[1] (width) are used.

        Returns:
            np.ndarray: The stylized frame as float32 with values in [0, 255], sized to
            (image_shape[0], image_shape[1]), in BGR channel order.
    """
    # Drop the batch axis (np.squeeze raises if axis 0 is not of size 1).
    postprocessed_frame = np.squeeze(preprocessed_frame, axis=0)

    # Select original height and width from image_shape; cv2.resize takes the target size as (width, height).
    frame_height, frame_width = image_shape[0], image_shape[1]
    # The * 255 scaling presumes the network output lies in [0, 1] -- TODO confirm against the model.
    postprocessed_frame = cv2.resize(postprocessed_frame, (frame_width, frame_height)).astype("float32") * 255

    # Clip so values slightly outside [0, 1] cannot overflow/wrap when the result is later
    # written into an 8-bit frame.
    postprocessed_frame = np.clip(postprocessed_frame, 0.0, 255.0)

    postprocessed_frame = cv2.cvtColor(postprocessed_frame, cv2.COLOR_RGB2BGR)

    return postprocessed_frame
30    return postprocessed_frame
31
32
def create_stylized_detection(style_transfer_executor, style_transfer_class, frame: np.ndarray,
                              detections: list, resize_factor, labels: dict):
    """
        Perform style transfer on every detection of a given class and paste the result back into the frame.

        The frame is modified in place and is also returned.

        Args:
            style_transfer_executor: The style transfer executor (must provide run_style_transfer()).
            style_transfer_class: Label of the class whose detections are stylized (compared case-insensitively).
            frame: The original captured frame from video source.
            detections: A list of detected objects in the form [class, [box positions], confidence].
            resize_factor: Resizing factor to scale box coordinates to output frame size.
            labels: Dictionary of labels and colors keyed on the classification index.

        Returns:
            np.ndarray: The same frame object, with each matching, non-empty detection region
            replaced by its stylized version.
    """
    target_class = style_transfer_class.lower()
    # Only pixel values are written below, so the frame size is constant for the whole call.
    frame_height, frame_width = frame.shape[:2]

    for class_idx, box, _confidence in detections:
        label = labels[class_idx][0]
        if label.lower() != target_class:
            continue

        # Scale box positions to the output frame size
        x_min, y_min, x_max, y_max = [int(position * resize_factor) for position in box]

        # Ensure box stays within the frame
        x_min, y_min = max(0, x_min), max(0, y_min)
        x_max, y_max = min(frame_width, x_max), min(frame_height, y_max)

        # Crop only the detected object
        cropped_frame = cv_utils.crop_bounding_box_object(frame, x_min, y_min, x_max, y_max)

        # A degenerate or off-frame box gives an empty crop. Stylizing it would end in
        # style_transfer_postprocess calling cv2.resize with a zero-sized target, which raises.
        if cropped_frame.size == 0:
            continue

        # Run style transfer on the cropped object
        stylized_frame = style_transfer_executor.run_style_transfer(cropped_frame)

        # Paste stylized_frame on the original frame in the correct place
        frame[y_min + 1:y_max, x_min + 1:x_max] = stylized_frame

    return frame
67    return frame
68
69
class StyleTransfer:
    """
        Holds two TFLite network executors: a style predict network that turns a style image into a
        style bottleneck tensor, and a style transfer network that applies that bottleneck to content images.
    """

    def __init__(self, style_predict_model_path: str, style_transfer_model_path: str,
                 style_image: np.ndarray, backends: list, delegate_path: str):
        """
        Creates an inference executor for style predict network, style transfer network,
        list of backends and a style image.

        The style bottleneck for style_image is computed once here and reused by every
        run_style_transfer() call.

        Args:
            style_predict_model_path: model which is used to create a style bottleneck
            style_transfer_model_path: model which is used to create stylized frames
            style_image: an image to create the style bottleneck
            backends: List of backends to optimize network.
            delegate_path: tflite delegate file path (.so).
        """

        self.style_predict_executor = network_executor_tflite.TFLiteNetworkExecutor(style_predict_model_path, backends,
                                                                                    delegate_path)
        self.style_transfer_executor = network_executor_tflite.TFLiteNetworkExecutor(style_transfer_model_path,
                                                                                     backends,
                                                                                     delegate_path)
        self.style_bottleneck = self.run_style_predict(style_image)

    def get_style_predict_executor_shape(self):
        """
            Get the input shape of the style predict network.

            Returns:
                tuple: The Shape of the network input.
        """
        return self.style_predict_executor.get_shape()

    def run_style_predict(self, style_image):
        """
            Creates bottleneck tensor for a given style image.

            Args:
                style_image: an image to create the style bottleneck

            Returns:
                style bottleneck tensor (output[0] of the style predict network)
        """
        # The style image has to be preprocessed to the network input shape, e.g. (1, 256, 256, 3)
        preprocessed_style_image = cv_utils.preprocess(style_image, self.style_predict_executor.get_data_type(),
                                                       self.style_predict_executor.get_shape(), True,
                                                       keep_aspect_ratio=False)
        # output[0] is the style bottleneck tensor
        style_bottleneck = self.style_predict_executor.run([preprocessed_style_image])[0]

        return style_bottleneck

    def run_style_transfer(self, content_image):
        """
            Runs inference for given content_image and the stored style bottleneck to create a stylized image.

            Args:
                content_image: a content image to stylize

            Returns:
                np.ndarray: the stylized image, resized back to content_image's height and width and
                converted by style_transfer_postprocess.
        """
        # The content image has to be preprocessed to the network input shape, e.g. (1, 384, 384, 3).
        # Use the network's own input data type (as run_style_predict does) rather than hard-coding float32.
        preprocessed_content_image = cv_utils.preprocess(content_image, self.style_transfer_executor.get_data_type(),
                                                         self.style_transfer_executor.get_shape(), True,
                                                         keep_aspect_ratio=False)

        # Transform content image. output[0] is the stylized image
        stylized_image = self.style_transfer_executor.run([preprocessed_content_image, self.style_bottleneck])[0]

        return style_transfer_postprocess(stylized_image, content_image.shape)
138        return post_stylized_image
139