# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

import os
from typing import List, Tuple

import pyarmnn as ann
import numpy as np


class ArmnnNetworkExecutor:
    """Loads a TFLite model with Arm NN, optimizes it for a list of
    backends, and exposes a simple inference interface over the result.
    """

    def __init__(self, model_file: str, backends: list):
        """
        Creates an inference executor for a given network and a list of backends.

        Args:
            model_file: User-specified model file.
            backends: List of backends to optimize network.

        Raises:
            FileNotFoundError: If `model_file` does not exist.
            ValueError: If the model file type is not supported.
        """
        self.model_file = model_file
        self.backends = backends
        self.network_id, self.runtime, self.input_binding_info, self.output_binding_info = self.create_network()
        # Output tensors are created once here and reused for every run() call.
        self.output_tensors = ann.make_output_tensors(self.output_binding_info)

    def run(self, input_data_list: list) -> List[np.ndarray]:
        """
        Creates input tensors from input data and executes inference with the loaded network.

        Args:
            input_data_list: List of input frames, one per network input.

        Returns:
            list: Inference results as a list of ndarrays.
        """
        input_tensors = ann.make_input_tensors(self.input_binding_info, input_data_list)
        self.runtime.EnqueueWorkload(self.network_id, input_tensors, self.output_tensors)
        return ann.workload_tensors_to_ndarray(self.output_tensors)

    def create_network(self):
        """
        Creates a network based on the model file and a list of backends.

        Returns:
            net_id: Unique ID of the network to run.
            runtime: Runtime context for executing inference.
            input_binding_info: Contains essential information about the model input.
            output_binding_info: Used to map output tensor and its memory.

        Raises:
            FileNotFoundError: If the model file does not exist.
            ValueError: If the model file extension is not '.tflite'.
        """
        if not os.path.exists(self.model_file):
            raise FileNotFoundError(f'Model file not found for: {self.model_file}')

        _, ext = os.path.splitext(self.model_file)
        # Extension check is case-insensitive so e.g. '.TFLITE' is also accepted.
        if ext.lower() != '.tflite':
            raise ValueError("Supplied model file type is not supported. Supported types are [ tflite ]")
        parser = ann.ITfLiteParser()

        network = parser.CreateNetworkFromBinaryFile(self.model_file)

        # Specify backends to optimize network
        preferred_backends = [ann.BackendId(b) for b in self.backends]

        # Select appropriate device context and optimize the network for that device
        options = ann.CreationOptions()
        runtime = ann.IRuntime(options)
        opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(),
                                             ann.OptimizerOptions())
        print(f'Preferred backends: {self.backends}\n{runtime.GetDeviceSpec()}\n'
              f'Optimization warnings: {messages}')

        # Load the optimized network onto the Runtime device
        net_id, _ = runtime.LoadNetwork(opt_network)

        # Get input and output binding information from the last subgraph.
        graph_id = parser.GetSubgraphCount() - 1
        input_binding_info = [parser.GetNetworkInputBindingInfo(graph_id, name)
                              for name in parser.GetSubgraphInputTensorNames(graph_id)]
        output_binding_info = [parser.GetNetworkOutputBindingInfo(graph_id, name)
                               for name in parser.GetSubgraphOutputTensorNames(graph_id)]
        return net_id, runtime, input_binding_info, output_binding_info

    def get_data_type(self):
        """
        Get the input data type of the initiated network.

        Returns:
            numpy data type, or None if the Arm NN data type is not one of
            Float32 / QAsymmU8 / QAsymmS8.
        """
        # Query once instead of once per comparison.
        data_type = self.input_binding_info[0][1].GetDataType()
        if data_type == ann.DataType_Float32:
            return np.float32
        if data_type == ann.DataType_QAsymmU8:
            return np.uint8
        if data_type == ann.DataType_QAsymmS8:
            return np.int8
        return None

    def get_shape(self) -> Tuple[int, ...]:
        """
        Get the input shape of the initiated network.

        Returns:
            tuple: The shape of the network input.
        """
        return tuple(self.input_binding_info[0][1].GetShape())

    def get_input_quantization_scale(self, idx):
        """
        Get the input quantization scale of the initiated network.

        Args:
            idx: Index of the network input.

        Returns:
            The quantization scale of the network input.
        """
        return self.input_binding_info[idx][1].GetQuantizationScale()

    def get_input_quantization_offset(self, idx):
        """
        Get the input quantization offset of the initiated network.

        Args:
            idx: Index of the network input.

        Returns:
            The quantization offset of the network input.
        """
        return self.input_binding_info[idx][1].GetQuantizationOffset()

    def is_output_quantized(self, idx):
        """
        Check whether a given output tensor is quantized.

        Args:
            idx: Index of the network output.

        Returns:
            True if the output is quantized and False otherwise.
        """
        return self.output_binding_info[idx][1].IsQuantized()

    def get_output_quantization_scale(self, idx):
        """
        Get the output quantization scale of the initiated network.

        Args:
            idx: Index of the network output.

        Returns:
            The quantization scale of the network output.
        """
        return self.output_binding_info[idx][1].GetQuantizationScale()

    def get_output_quantization_offset(self, idx):
        """
        Get the output quantization offset of the initiated network.

        Args:
            idx: Index of the network output.

        Returns:
            The quantization offset of the network output.
        """
        return self.output_binding_info[idx][1].GetQuantizationOffset()