# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

import os
import cv2
import numpy as np

from context import style_transfer
from context import cv_utils

# Shape used when preprocessing in test_style_transfer_postprocess, and the
# shape the style-predict executor is asserted to report in test_style_transfer.
STYLE_INPUT_SHAPE = (1, 256, 256, 3)


def _load_image(test_data_folder, file_name):
    """Read an image from the test data folder, failing with a clear message if it cannot be read.

    cv2.imread returns None rather than raising on a missing/unreadable file; without
    this check that None would surface later as an unrelated AttributeError/TypeError.
    """
    path = os.path.join(test_data_folder, file_name)
    image = cv2.imread(path)
    assert image is not None, f"could not read test image: {path}"
    return image


def _make_executor(test_data_folder, style_image):
    """Build a StyleTransfer executor from the models and delegate stored in the test data folder."""
    style_predict_model_path = os.path.join(test_data_folder, "style_predict.tflite")
    style_transfer_model_path = os.path.join(test_data_folder, "style_transfer.tflite")
    delegate_path = os.path.join(test_data_folder, "libarmnnDelegate.so")
    # A fresh list per call, so executors built by different tests never share it.
    backends = ["CpuAcc", "CpuRef"]
    return style_transfer.StyleTransfer(style_predict_model_path, style_transfer_model_path,
                                        style_image, backends, delegate_path)


def test_style_transfer_postprocess(test_data_folder):
    """Preprocess an image to STYLE_INPUT_SHAPE, then check postprocessing restores the original shape."""
    # NOTE(review): test_data_folder is presumably a fixture supplied by conftest.py — confirm.
    image = _load_image(test_data_folder, "messi5.jpg")
    original_shape = image.shape
    keep_aspect_ratio = False
    preprocessed_image = cv_utils.preprocess(image, np.float32, STYLE_INPUT_SHAPE, False, keep_aspect_ratio)
    assert preprocessed_image.shape == STYLE_INPUT_SHAPE

    postprocess_image = style_transfer.style_transfer_postprocess(preprocessed_image, original_shape)
    assert postprocess_image.shape == original_shape


def test_style_transfer(test_data_folder):
    """Constructing the executor with a style image exposes the expected style-predict input shape."""
    style_image = _load_image(test_data_folder, "messi5.jpg")
    style_transfer_executor = _make_executor(test_data_folder, style_image)

    assert style_transfer_executor.get_style_predict_executor_shape() == STYLE_INPUT_SHAPE


def test_run_style_transfer(test_data_folder):
    """Stylizing a content image returns an image with the same shape as the content image."""
    style_image = _load_image(test_data_folder, "messi5.jpg")
    content_image = _load_image(test_data_folder, "basketball1.png")
    style_transfer_executor = _make_executor(test_data_folder, style_image)

    stylized_image = style_transfer_executor.run_style_transfer(content_image)
    assert stylized_image.shape == content_image.shape


def test_create_stylized_detection(test_data_folder):
    """Stylizing a detection of class 'person' returns an image with the content image's shape."""
    style_image = _load_image(test_data_folder, "messi5.jpg")
    content_image = _load_image(test_data_folder, "basketball1.png")
    # NOTE(review): each entry presumably is (class id, normalized bounding box, score) — confirm
    # against create_stylized_detection.
    detections = [(0.0, [0.16745174, 0.15101701, 0.5371381, 0.74165875], 0.87597656)]
    # NOTE(review): values presumably are (label name, colour tuple) — confirm.
    labels = {0: ('person', (50.888902345757494, 129.61878417939724, 207.2891028294508)),
              1: ('bicycle', (55.055339686943654, 55.828708219750574, 43.550389695374676)),
              2: ('car', (95.33096265662336, 194.872841553212, 218.58516479057758))}
    style_transfer_executor = _make_executor(test_data_folder, style_image)

    # NOTE(review): the meaning of the literal 720 is not visible here — check the
    # create_stylized_detection signature before changing it.
    stylized_image = style_transfer.create_stylized_detection(style_transfer_executor, 'person', content_image,
                                                               detections, 720, labels)

    assert stylized_image.shape == content_image.shape