xref: /aosp_15_r20/external/armnn/python/pyarmnn/examples/common/utils.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

"""Contains helper functions that can be used across the example apps."""
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Workerimport os
7*89c4ff92SAndroid Build Coastguard Workerimport errno
8*89c4ff92SAndroid Build Coastguard Workerfrom pathlib import Path
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Workerimport numpy as np
11*89c4ff92SAndroid Build Coastguard Workerimport datetime
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker
def dict_labels(labels_file_path: str, include_rgb=False) -> dict:
    """Builds an index-to-label lookup table from a labels file.

    Every line of the file is treated as one label and is keyed by its
    zero-based line number.

    Args:
        labels_file_path: Path to the text file holding one label per line.
        include_rgb: When true, each value is a ``(label, rgb)`` pair where
            ``rgb`` is a tuple of three randomly drawn values in [0, 255).
            Used for plotting bounding boxes of different colours.

    Returns:
        Dictionary mapping the line index to the label, or to a
        ``(label, rgb)`` tuple when `include_rgb` is set.

    Raises:
        FileNotFoundError:
            `labels_file_path` does not point to an existing regular file.
    """
    source = Path(labels_file_path)
    if not source.is_file():
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), labels_file_path
        )

    mapping = {}
    with open(source, "r") as handle:
        for index, raw_line in enumerate(handle):
            name = raw_line.strip("\n")
            # One random colour is drawn per label, in file order.
            mapping[index] = (
                (name, tuple(np.random.random(size=3) * 255)) if include_rgb else name
            )
    return mapping
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker
def prepare_input_data(audio_data, input_data_type, input_quant_scale, input_quant_offset, mfcc_preprocessor):
    """
    Extracts features from a block of audio and, for non-float32 models,
    quantizes them to the model's input data type.

    Args:
        audio_data: The audio data to process
        input_data_type: The model's input data type
        input_quant_scale: The model's quantization scale
        input_quant_offset: The model's quantization offset
        mfcc_preprocessor: Object exposing `extract_features(audio_data)`
    Returns:
        The extracted features, passed through `quantize_input` unless
        `input_data_type` is `np.float32`.
    """
    features = mfcc_preprocessor.extract_features(audio_data)

    # Only non-float32 models need their input quantized.
    if input_data_type != np.float32:
        return quantize_input(features, input_data_type, input_quant_scale, input_quant_offset)

    return features
65*89c4ff92SAndroid Build Coastguard Worker
66*89c4ff92SAndroid Build Coastguard Worker
def quantize_input(data, input_data_type, input_quant_scale, input_quant_offset):
    """Quantize the float input to (u)int8 ready for inputting to model.

    Each element is mapped to ``value / input_quant_scale + input_quant_offset``,
    clamped to the representable range of `input_data_type`, and then cast
    to that type (the cast truncates towards zero; no rounding is applied).

    The input array is left unmodified; a new array is returned.

    Args:
        data: 2-D NumPy array of values to quantize.
        input_data_type: Target type; must be `np.int8` or `np.uint8`.
        input_quant_scale: Quantization scale to divide by.
        input_quant_offset: Quantization offset (zero point) to add.

    Returns:
        New array with the same shape as `data` and dtype `input_data_type`.

    Raises:
        RuntimeError: `data` does not have exactly 2 dimensions.
        ValueError: `input_data_type` is neither `np.int8` nor `np.uint8`.
    """
    if data.ndim != 2:
        raise RuntimeError("Audio data must have 2 dimensions for quantization")

    if (input_data_type != np.int8) and (input_data_type != np.uint8):
        raise ValueError("Could not quantize data to required data type")

    limits = np.iinfo(input_data_type)

    # Vectorized on a fresh array so the caller's buffer is not overwritten.
    scaled = data / input_quant_scale + input_quant_offset
    # Clamp before casting so out-of-range values saturate instead of wrapping.
    return np.clip(scaled, limits.min, limits.max).astype(input_data_type)
84*89c4ff92SAndroid Build Coastguard Worker
85*89c4ff92SAndroid Build Coastguard Worker
def dequantize_output(data, is_output_quantized, output_quant_scale, output_quant_offset):
    """Dequantize the (u)int8 output to float.

    Args:
        data: NumPy array holding the model output.
        is_output_quantized: When false, `data` is returned untouched.
        output_quant_scale: Quantization scale to multiply by.
        output_quant_offset: Quantization offset (zero point) to subtract.

    Returns:
        `data` itself when `is_output_quantized` is false; otherwise a new
        float array holding ``(value - output_quant_offset) * output_quant_scale``
        for every element.

    Raises:
        RuntimeError: The output is quantized but `data` is not 2-D.
    """
    if not is_output_quantized:
        return data

    if data.ndim != 2:
        raise RuntimeError("Data must have 2 dimensions for quantization")

    # astype() yields a fresh float copy, so the in-place arithmetic below
    # never touches the caller's array and keeps the float dtype.
    result = data.astype(float)
    result -= output_quant_offset
    result *= output_quant_scale
    return result
98*89c4ff92SAndroid Build Coastguard Worker
99*89c4ff92SAndroid Build Coastguard Worker
class Profiling:
    """Simple stopwatch that prints elapsed wall-clock time in microseconds.

    All methods are no-ops (and `profiling_stop_and_print_us` returns 0)
    when the instance is created with ``enabled=False``.
    """

    def __init__(self, enabled: bool):
        # m_start / m_end hold 0 until a measurement is taken, after which
        # they hold `datetime.datetime` timestamps.
        self.m_start = 0
        self.m_end = 0
        self.m_enabled = enabled

    def profiling_start(self):
        """Record the current time as the start of the measured interval."""
        if self.m_enabled:
            self.m_start = datetime.datetime.now()

    def profiling_stop_and_print_us(self, msg):
        """Stop the measurement, print it and return it.

        Args:
            msg: Label included in the printed line.

        Returns:
            Elapsed time since `profiling_start` in whole microseconds, or 0
            when profiling is disabled.

        Note:
            When enabled, `profiling_start` must have been called first;
            otherwise `m_start` is still the integer 0 and the subtraction
            raises TypeError.
        """
        if not self.m_enabled:
            return 0

        self.m_end = datetime.datetime.now()
        period = self.m_end - self.m_start
        # Floor-divide by one microsecond to get the full duration; summing
        # only .seconds and .microseconds would drop the .days component.
        period_us = period // datetime.timedelta(microseconds=1)
        print(f'Profiling: {msg} : {period_us:,} microSeconds')
        return period_us
118