# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

"""Contains helper functions that can be used across the example apps."""

import os
import errno
from pathlib import Path

import numpy as np
import datetime


def dict_labels(labels_file_path: str, include_rgb=False) -> dict:
    """Creates a dictionary of labels from the input labels file.

    Each line of the file becomes one entry, keyed by its zero-based line
    number.

    Args:
        labels_file_path: Path to file containing labels to map model outputs.
        include_rgb: Adds randomly generated RGB values to the values of the
            dictionary. Used for plotting bounding boxes of different colours.

    Returns:
        Dictionary with classification indices for keys and labels for values.
        When `include_rgb` is set, each value is a `(label, (r, g, b))` tuple
        whose colour components are random floats in the range [0, 255).

    Raises:
        FileNotFoundError:
            Provided `labels_file_path` does not exist.
    """
    labels_file = Path(labels_file_path)
    if not labels_file.is_file():
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), labels_file_path
        )

    labels = {}
    with open(labels_file, "r") as f:
        for idx, line in enumerate(f):
            label = line.strip("\n")
            if include_rgb:
                labels[idx] = label, tuple(np.random.random(size=3) * 255)
            else:
                labels[idx] = label
    return labels


def prepare_input_data(audio_data, input_data_type, input_quant_scale, input_quant_offset, mfcc_preprocessor):
    """
    Takes a block of audio data, extracts the MFCC features and, when the model
    input is not float32, quantizes the resulting feature array.

    Args:
        audio_data: The audio data to process
        input_data_type: The model's input data type
        input_quant_scale: The model's quantization scale
        input_quant_offset: The model's quantization offset
        mfcc_preprocessor: The mfcc preprocessor instance; its
            `extract_features` method is called with `audio_data`
    Returns:
        input_data: The prepared input data
    """
    input_data = mfcc_preprocessor.extract_features(audio_data)
    if input_data_type != np.float32:
        input_data = quantize_input(input_data, input_data_type, input_quant_scale, input_quant_offset)
    return input_data


def quantize_input(data, input_data_type, input_quant_scale, input_quant_offset):
    """Quantize the float input to (u)int8 ready for inputting to model.

    Computes `data / input_quant_scale + input_quant_offset`, saturates the
    result to the representable range of `input_data_type` and casts it. The
    input array is left untouched; a new array is returned.

    Args:
        data: 2-D NumPy array of float values to quantize.
        input_data_type: Target type, either `np.int8` or `np.uint8`.
        input_quant_scale: The model's quantization scale.
        input_quant_offset: The model's quantization offset.

    Returns:
        A new array of dtype `input_data_type` with the same shape as `data`.

    Raises:
        RuntimeError: `data` does not have exactly 2 dimensions.
        ValueError: `input_data_type` is neither `np.int8` nor `np.uint8`.
    """
    if data.ndim != 2:
        raise RuntimeError("Audio data must have 2 dimensions for quantization")

    if (input_data_type != np.int8) and (input_data_type != np.uint8):
        raise ValueError("Could not quantize data to required data type")

    type_info = np.iinfo(input_data_type)

    # Vectorized and out-of-place, so the caller's float features are not
    # overwritten with quantized values.
    scaled = data / input_quant_scale + input_quant_offset
    # Saturate before casting; the cast itself truncates toward zero.
    return np.clip(scaled, type_info.min, type_info.max).astype(input_data_type)


def dequantize_output(data, is_output_quantized, output_quant_scale, output_quant_offset):
    """Dequantize the (u)int8 output to float.

    Args:
        data: Model output as a NumPy array.
        is_output_quantized: When falsy, `data` is returned unchanged.
        output_quant_scale: The model's output quantization scale.
        output_quant_offset: The model's output quantization offset.

    Returns:
        `data` itself when the output is not quantized, otherwise a new float
        array holding `(data - output_quant_offset) * output_quant_scale`.

    Raises:
        RuntimeError: The output is quantized but `data` is not 2-dimensional.
    """
    if not is_output_quantized:
        return data

    if data.ndim != 2:
        raise RuntimeError("Data must have 2 dimensions for quantization")

    # Convert to float first so the subtraction cannot wrap around in the
    # narrow integer type.
    return (data.astype(float) - output_quant_offset) * output_quant_scale


class Profiling:
    """Minimal wall-clock stopwatch that can be switched off.

    When constructed with `enabled=False` every method is a no-op, so call
    sites do not need their own guards.
    """

    def __init__(self, enabled: bool):
        # m_start / m_end hold `datetime.datetime` values once a measurement
        # has been taken; they are 0 until then.
        self.m_start = 0
        self.m_end = 0
        self.m_enabled = enabled

    def profiling_start(self):
        """Record the start time of a measurement (if profiling is enabled)."""
        if self.m_enabled:
            self.m_start = datetime.datetime.now()

    def profiling_stop_and_print_us(self, msg):
        """Stop the measurement, print the elapsed time and return it.

        `profiling_start` must have been called beforehand when profiling is
        enabled.

        Args:
            msg: Label printed alongside the measured duration.

        Returns:
            Elapsed time in whole microseconds, or 0 when profiling is
            disabled.
        """
        if not self.m_enabled:
            return 0

        self.m_end = datetime.datetime.now()
        period = self.m_end - self.m_start
        # Dividing by a one-microsecond timedelta takes the `days` component
        # into account, which `seconds`/`microseconds` alone do not.
        period_us = period // datetime.timedelta(microseconds=1)
        print(f'Profiling: {msg} : {period_us:,} microSeconds')
        return period_us