# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

import os

import numpy as np
import tflite_runtime.interpreter as tflite


def _random_tensor_data(shape, dtype):
    """Return a random array of the given shape and dtype.

    Integer dtypes are sampled uniformly over their full representable range.
    Other dtypes are sampled from [0, 1) and then cast to ``dtype``.
    Uses the global ``np.random`` state so callers that seed it get
    reproducible data.
    """
    dtype = np.dtype(dtype)
    if np.issubdtype(dtype, np.integer):
        info = np.iinfo(dtype)
        # ``high`` is exclusive, so add one to include the dtype's maximum.
        return np.random.randint(info.min, int(info.max) + 1, size=shape, dtype=dtype)
    return np.random.random_sample(shape).astype(dtype)


def run_mock_model(delegate, test_data_folder):
    """Run ``mock_model.tflite`` once through ``delegate`` on random input.

    This is a smoke test: it checks that the model can be loaded with the
    delegate, have its tensors allocated, and be invoked.

    Args:
        delegate: Delegate object passed to the interpreter via
            ``experimental_delegates``.
        test_data_folder: Directory containing ``mock_model.tflite``.

    Returns:
        List of output arrays, one per model output, in the order reported by
        ``get_output_details()``.
    """
    model_path = os.path.join(test_data_folder, 'mock_model.tflite')
    interpreter = tflite.Interpreter(model_path=model_path,
                                     experimental_delegates=[delegate])
    interpreter.allocate_tensors()

    # Only the first input is fed.
    # NOTE(review): assumes the mock model has a single input; confirm.
    input_detail = interpreter.get_input_details()[0]
    # Previously float samples in [0, 1) were cast to uint8, which yielded an
    # all-zero tensor. Sample over the input's actual dtype range instead.
    input_data = _random_tensor_data(input_detail['shape'], input_detail['dtype'])
    interpreter.set_tensor(input_detail['index'], input_data)

    interpreter.invoke()

    return [interpreter.get_tensor(detail['index'])
            for detail in interpreter.get_output_details()]


def run_inference(test_data_folder, model_filename, inputs, delegates=None):
    """Load a TFLite model, feed it ``inputs``, invoke it and return its outputs.

    Args:
        test_data_folder: Directory containing the model file.
        model_filename: Model file name, relative to ``test_data_folder``.
        inputs: Sequence of arrays, one per model input, in the order reported
            by ``get_input_details()``.
        delegates: Optional list of delegates passed to the interpreter via
            ``experimental_delegates``.

    Returns:
        List of output arrays, one per model output, in the order reported by
        ``get_output_details()``.

    Raises:
        ValueError: If the number of ``inputs`` differs from the number of
            model inputs.
    """
    model_path = os.path.join(test_data_folder, model_filename)
    interpreter = tflite.Interpreter(model_path=model_path,
                                     experimental_delegates=delegates)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Fail loudly instead of raising an opaque IndexError (too many inputs) or
    # silently leaving input tensors unset (too few).
    if len(inputs) != len(input_details):
        raise ValueError('Model expects {} input(s) but {} were given'.format(
            len(input_details), len(inputs)))

    for detail, data in zip(input_details, inputs):
        interpreter.set_tensor(detail['index'], data)

    interpreter.invoke()

    return [interpreter.get_tensor(detail['index']) for detail in output_details]


def compare_outputs(outputs, expected_outputs, *, rtol=0.0, atol=0.0):
    """Assert that ``outputs`` match ``expected_outputs`` element-wise.

    For each output pair, checks the shape, the dtype and the values. By
    default values must be exactly equal. Pass a non-zero ``rtol`` and/or
    ``atol`` to compare with ``np.allclose`` semantics instead.

    Args:
        outputs: Sequence of arrays produced by the model.
        expected_outputs: Sequence of reference arrays.
        rtol: Relative tolerance for value comparison (keyword-only).
        atol: Absolute tolerance for value comparison (keyword-only).

    Raises:
        AssertionError: On any count, shape, dtype or value mismatch.
    """
    assert len(outputs) == len(expected_outputs), 'Incorrect number of outputs'
    for i, (output, expected) in enumerate(zip(outputs, expected_outputs)):
        assert output.shape == expected.shape, 'Incorrect output shape on output#{}'.format(i)
        assert output.dtype == expected.dtype, 'Incorrect output data type on output#{}'.format(i)
        # The previous check compared ``output.all()`` with ``expected.all()``.
        # That compares two booleans, not the array values.
        if rtol == 0 and atol == 0:
            values_match = np.array_equal(output, expected)
        else:
            values_match = np.allclose(output, expected, rtol=rtol, atol=atol)
        assert values_match, 'Incorrect output value on output#{}'.format(i)