1# Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 2# SPDX-License-Identifier: MIT 3 4import numpy as np 5import urllib.request 6import cv2 7import network_executor_tflite 8import cv_utils 9 10 11def style_transfer_postprocess(preprocessed_frame: np.ndarray, image_shape: tuple): 12 """ 13 Resizes the output frame of style transfer network and changes the color back to original configuration 14 15 Args: 16 preprocessed_frame: A preprocessed frame after style transfer. 17 image_shape: Contains shape of the original frame before preprocessing. 18 19 Returns: 20 Resizing factor to scale coordinates according to image_shape. 21 """ 22 23 postprocessed_frame = np.squeeze(preprocessed_frame, axis=0) 24 # select original height and width from image_shape 25 frame_height = image_shape[0] 26 frame_width = image_shape[1] 27 postprocessed_frame = cv2.resize(postprocessed_frame, (frame_width, frame_height)).astype("float32") * 255 28 postprocessed_frame = cv2.cvtColor(postprocessed_frame, cv2.COLOR_RGB2BGR) 29 30 return postprocessed_frame 31 32 33def create_stylized_detection(style_transfer_executor, style_transfer_class, frame: np.ndarray, 34 detections: list, resize_factor, labels: dict): 35 """ 36 Perform style transfer on a detected class in a frame 37 38 Args: 39 style_transfer_executor: The style transfer executor 40 style_transfer_class: The class detected to change its style 41 frame: The original captured frame from video source. 42 detections: A list of detected objects in the form [class, [box positions], confidence]. 43 resize_factor: Resizing factor to scale box coordinates to output frame size. 44 labels: Dictionary of labels and colors keyed on the classification index. 45 """ 46 for detection in detections: 47 class_idx, box, confidence = [d for d in detection] 48 label = labels[class_idx][0] 49 if label.lower() == style_transfer_class.lower(): 50 # Obtain frame size and resized bounding box positions 51 frame_height, frame_width = frame.shape[:2] 52 x_min, y_min, x_max, y_max = [int(position * resize_factor) for position in box] 53 54 # Ensure box stays within the frame 55 x_min, y_min = max(0, x_min), max(0, y_min) 56 x_max, y_max = min(frame_width, x_max), min(frame_height, y_max) 57 58 # Crop only the detected object 59 cropped_frame = cv_utils.crop_bounding_box_object(frame, x_min, y_min, x_max, y_max) 60 61 # Run style_transfer on preprocessed_frame 62 stylized_frame = style_transfer_executor.run_style_transfer(cropped_frame) 63 64 # Paste stylized_frame on the original frame in the correct place 65 frame[int(y_min)+1:int(y_max), int(x_min)+1:int(x_max)] = stylized_frame 66 67 return frame 68 69 70class StyleTransfer: 71 72 def __init__(self, style_predict_model_path: str, style_transfer_model_path: str, 73 style_image: np.ndarray, backends: list, delegate_path: str): 74 """ 75 Creates an inference executor for style predict network, style transfer network, 76 list of backends and a style image. 77 78 Args: 79 style_predict_model_path: model which is used to create a style bottleneck 80 style_transfer_model_path: model which is used to create stylized frames 81 style_image: an image to create the style bottleneck 82 backends: List of backends to optimize network. 83 delegate_path: tflite delegate file path (.so). 84 """ 85 86 self.style_predict_executor = network_executor_tflite.TFLiteNetworkExecutor(style_predict_model_path, backends, 87 delegate_path) 88 self.style_transfer_executor = network_executor_tflite.TFLiteNetworkExecutor(style_transfer_model_path, 89 backends, 90 delegate_path) 91 self.style_bottleneck = self.run_style_predict(style_image) 92 93 def get_style_predict_executor_shape(self): 94 """ 95 Get the input shape of the initiated network. 96 97 Returns: 98 tuple: The Shape of the network input. 99 """ 100 return self.style_predict_executor.get_shape() 101 102 # Function to run create a style_bottleneck using preprocessed style image. 103 def run_style_predict(self, style_image): 104 """ 105 Creates bottleneck tensor for a given style image. 106 107 Args: 108 style_image: an image to create the style bottleneck 109 110 Returns: 111 style bottleneck tensor 112 """ 113 # The style image has to be preprocessed to (1, 256, 256, 3) 114 preprocessed_style_image = cv_utils.preprocess(style_image, self.style_predict_executor.get_data_type(), 115 self.style_predict_executor.get_shape(), True, keep_aspect_ratio=False) 116 # output[0] is the style bottleneck tensor 117 style_bottleneck = self.style_predict_executor.run([preprocessed_style_image])[0] 118 119 return style_bottleneck 120 121 # Run style transform on preprocessed style image 122 def run_style_transfer(self, content_image): 123 """ 124 Runs inference for given content_image and style bottleneck to create a stylized image. 125 126 Args: 127 content_image:a content image to stylize 128 """ 129 # The content image has to be preprocessed to (1, 384, 384, 3) 130 preprocessed_style_image = cv_utils.preprocess(content_image, np.float32, 131 self.style_transfer_executor.get_shape(), True, keep_aspect_ratio=False) 132 133 # Transform content image. output[0] is the stylized image 134 stylized_image = self.style_transfer_executor.run([preprocessed_style_image, self.style_bottleneck])[0] 135 136 post_stylized_image = style_transfer_postprocess(stylized_image, content_image.shape) 137 138 return post_stylized_image 139