#
# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
#
import argparse
from pathlib import Path
from typing import Union

import tflite_runtime.interpreter as tflite
from PIL import Image
import numpy as np


def check_args(args: argparse.Namespace):
    """Check the values used in the command-line have acceptable values

    args:
        - args: argparse.Namespace

    returns:
        - None

    raises:
        - FileNotFoundError: if passed files do not exist.
        - IOError: if files are of incorrect format.
        - ValueError: if an unsupported backend is requested.
    """
    input_image_p = args.input_image
    # Compare the extension case-insensitively so e.g. "photo.JPG" is accepted.
    if input_image_p.suffix.lower() not in (".png", ".jpg", ".jpeg"):
        raise IOError(
            "--input_image option should point to an image file of the "
            "format .jpg, .jpeg, .png"
        )
    if not input_image_p.exists():
        raise FileNotFoundError(f"Cannot find {input_image_p.name}")
    model_p = args.model_file
    if model_p.suffix != ".tflite":
        raise IOError("--model_file should point to a tflite file.")
    if not model_p.exists():
        raise FileNotFoundError(f"Cannot find {model_p.name}")
    label_mapping_p = args.label_file
    if label_mapping_p.suffix != ".txt":
        raise IOError("--label_file expects a .txt file.")
    if not label_mapping_p.exists():
        raise FileNotFoundError(f"Cannot find {label_mapping_p.name}")
    # NOTE: delegate_path is not existence-checked here; it is handed to
    # tflite.load_delegate, which may resolve a bare library name through the
    # system loader's search path.

    # check all args given in preferred backends make sense
    supported_backends = ["GpuAcc", "CpuAcc", "CpuRef"]
    if not all(backend in supported_backends for backend in args.preferred_backends):
        raise ValueError("Incorrect backends given. Please choose from "
                         "'GpuAcc', 'CpuAcc', 'CpuRef'.")

    return None


def load_image(image_path: Path, model_input_dims: Union[tuple, list], grayscale: bool):
    """load an image and put into correct format for the tensorflow lite model

    args:
        - image_path: pathlib.Path
        - model_input_dims: tuple (or array-like). (height,width)
        - grayscale: bool -> True for a single-channel (8-bit greyscale) image,
          False for a 3-channel RGB image.

    returns:
        - image: np.array (uint8) of shape (1, height, width, 3) for RGB, or
          (1, height, width, 1) when grayscale is True.
    """
    height, width = model_input_dims
    # Use a context manager so the file handle opened by PIL is released.
    with Image.open(image_path) as raw_image:
        # Normalise the colour mode: PNGs may be RGBA or palette ("P") and JPEGs
        # may be single-channel, which would otherwise yield the wrong number of
        # channels. "L" is 8-bit greyscale ("LA" would add an alpha band).
        image = raw_image.convert("L" if grayscale else "RGB")
        # PIL's resize takes (width, height).
        image = image.resize((width, height))

    image = np.asarray(image)
    if grayscale:
        # "L" images become 2-D arrays; add the trailing channel axis.
        image = np.expand_dims(image, axis=-1)
    # add the batch dimension
    image = np.expand_dims(image, axis=0)

    return image


def load_delegate(delegate_path: Path, backends: list):
    """load the armnn delegate.

    args:
        - delegate_path: pathlib.Path -> location of you libarmnnDelegate.so
        - backends: list -> list of backends you want to use in string format

    returns:
        - armnn_delegate: tflite.delegate
    """
    # create a comma separated string, e.g. "CpuAcc,CpuRef"
    backend_string = ",".join(backends)
    # load delegate; the library is given as a str because tflite loads it via
    # ctypes, which only accepts path-like objects from Python 3.12 onwards.
    armnn_delegate = tflite.load_delegate(
        library=str(delegate_path),
        options={"backends": backend_string, "logging-severity": "info"},
    )

    return armnn_delegate


def load_tf_model(model_path: Path, armnn_delegate: tflite.Delegate):
    """load a tflite model for use with the armnn delegate.

    args:
        - model_path: pathlib.Path
        - armnn_delegate: tflite.Delegate

    returns:
        - interpreter: tflite.Interpreter (tensors already allocated)
    """
    interpreter = tflite.Interpreter(
        model_path=model_path.as_posix(), experimental_delegates=[armnn_delegate]
    )
    interpreter.allocate_tensors()

    return interpreter


def run_inference(interpreter, input_image):
    """Run inference on a processed input image and return the output from
    inference.

    args:
        - interpreter: tflite_runtime.interpreter.Interpreter
        - input_image: np.array -> must match the model's first input tensor
          in shape and dtype (presumably uint8 for a quantized model; confirm
          for the model in use).

    returns:
        - output_data: np.array -> contents of the model's first output tensor.
    """
    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # Feed the processed image into the first input tensor and run the model.
    interpreter.set_tensor(input_details[0]["index"], input_image)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]["index"])

    return output_data


def create_mapping(label_mapping_p):
    """Creates a Python dictionary mapping an index to a label.

    label_mapping[idx] = label

    Each line of the file is one label; the line number (starting at 0) is its
    index. Line terminators are stripped from the labels.

    args:
        - label_mapping_p: pathlib.Path

    returns:
        - label_mapping: dict
    """
    with open(label_mapping_p, encoding="utf-8") as label_mapping_raw:
        # Strip the newline so printed predictions don't carry a stray line break.
        label_mapping = {
            idx: line.rstrip("\r\n") for idx, line in enumerate(label_mapping_raw)
        }

    return label_mapping


def process_output(output_data, label_mapping):
    """Process the output tensor into a label from the labelmapping file. Takes
    the index of the maximum value from the output array.

    args:
        - output_data: np.array -> first entry along axis 0 holds the scores.
        - label_mapping: dict

    returns:
        - str: labelmapping for max index.

    raises:
        - KeyError: if the winning index has no entry in label_mapping.
    """
    # np.argmax returns a numpy integer; use a plain int as the dict key.
    idx = int(np.argmax(output_data[0]))

    return label_mapping[idx]


def main(args):
    """Run the inference for options passed in the command line.

    args:
        - args: argparse.Namespace

    returns:
        - None
    """
    # sanity check on args
    check_args(args)
    # load in the armnn delegate
    armnn_delegate = load_delegate(args.delegate_path, args.preferred_backends)
    # load tflite model
    interpreter = load_tf_model(args.model_file, armnn_delegate)
    # get input shape for image resizing; assumes an NHWC input tensor
    # (batch, height, width, channels)
    input_shape = interpreter.get_input_details()[0]["shape"]
    height, width = input_shape[1], input_shape[2]
    # a single-channel input tensor expects a greyscale image
    grayscale = len(input_shape) == 4 and input_shape[3] == 1
    # load input image
    input_image = load_image(args.input_image, (height, width), grayscale)
    # get label mapping
    labelmapping = create_mapping(args.label_file)
    output_tensor = run_inference(interpreter, input_image)
    output_prediction = process_output(output_tensor, labelmapping)

    print("Prediction: ", output_prediction)

    return None


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--input_image", help="File path of image file", type=Path, required=True
    )
    parser.add_argument(
        "--model_file",
        help="File path of the model tflite file",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "--label_file",
        help="File path of model labelmapping file",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "--delegate_path",
        help="File path of ArmNN delegate file",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "--preferred_backends",
        help="list of backends in order of preference",
        type=str,
        nargs="+",
        required=False,
        default=["CpuAcc", "CpuRef"],
    )
    args = parser.parse_args()

    main(args)