xref: /aosp_15_r20/external/armnn/samples/ImageClassification/run_classifier.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1#
2# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3# SPDX-License-Identifier: MIT
4#
5import argparse
6from pathlib import Path
7from typing import Union
8
9import tflite_runtime.interpreter as tflite
10from PIL import Image
11import numpy as np
12
13
def _check_file(path: Path, allowed_suffixes: tuple, format_message: str):
    """Raise if *path* has a disallowed extension or does not exist.

    args:
      - path: pathlib.Path to validate
      - allowed_suffixes: tuple of lower-case suffixes, e.g. (".txt",)
      - format_message: message used for the IOError on a bad extension

    raises:
      - IOError: if the (case-insensitive) suffix is not allowed.
      - FileNotFoundError: if the file does not exist.
    """
    # Compare case-insensitively so e.g. "photo.JPG" is accepted.
    if path.suffix.lower() not in allowed_suffixes:
        raise IOError(format_message)
    if not path.exists():
        # Build one formatted message: passing two positional arguments to an
        # OSError subclass makes them (errno, strerror), which garbles the text
        # into "[Errno Cannot find ] name".
        raise FileNotFoundError(f"Cannot find {path.name}")


def check_args(args: argparse.Namespace):
    """Check the values used in the command-line have acceptable values.

    Validates, in order, the input image, the model file and the label file
    (extension first, then existence). Then checks that every requested
    backend is one this sample supports.

    args:
      - args: argparse.Namespace with ``input_image``, ``model_file`` and
        ``label_file`` as pathlib.Path objects and ``preferred_backends``
        as a list of strings.

    returns:
      - None

    raises:
      - IOError: if files are of incorrect format.
      - FileNotFoundError: if passed files do not exist.
      - ValueError: if an unsupported backend is requested.
    """
    _check_file(
        args.input_image,
        (".png", ".jpg", ".jpeg"),
        "--input_image option should point to an image file of the "
        "format .jpg, .jpeg, .png",
    )
    _check_file(
        args.model_file,
        (".tflite",),
        "--model_file should point to a tflite file.",
    )
    _check_file(
        args.label_file,
        (".txt",),
        "--label_file expects a .txt file.",
    )

    # check all args given in preferred backends make sense
    supported_backends = ["GpuAcc", "CpuAcc", "CpuRef"]
    invalid = [b for b in args.preferred_backends if b not in supported_backends]
    if invalid:
        raise ValueError(
            f"Incorrect backends given: {invalid}. Please choose from "
            "'GpuAcc', 'CpuAcc', 'CpuRef'."
        )

    return None
53
54
def load_image(image_path: Path, model_input_dims: Union[tuple, list], grayscale: bool):
    """Load an image and put it into the layout expected by the TFLite model.

    The image is converted to a single-channel ("L") image when *grayscale*
    is true and to 3-channel "RGB" otherwise. It is then resized to the
    model's spatial size and given a leading batch dimension.

    args:
      - image_path: pathlib.Path
      - model_input_dims: tuple (or array-like). (height,width)
      - grayscale: bool -> produce a single-channel image instead of RGB

    returns:
      - image: np.array of shape (1, height, width, channels), with
        channels == 1 when grayscale and 3 otherwise. The dtype is whatever
        PIL yields for the mode (uint8 for "L"/"RGB").
    """
    height, width = model_input_dims
    mode = "L" if grayscale else "RGB"
    # The context manager releases the underlying file handle once the
    # converted copy has been produced.
    with Image.open(image_path) as source:
        # Normalise the colour mode before resizing. PNGs may be RGBA or
        # palette ("P") images, which would otherwise give the wrong channel
        # count. "LA" would keep an alpha channel (2 channels), so a
        # grayscale request uses "L".
        prepared = source.convert(mode).resize((width, height))

    image = np.asarray(prepared)
    if grayscale:
        # "L" images come back as (height, width); add the channel axis.
        image = image[..., np.newaxis]

    # Add the batch dimension.
    return np.expand_dims(image, axis=0)
75
76
def load_delegate(delegate_path: Path, backends: list, logging_severity: str = "info"):
    """Load the Arm NN delegate.

    args:
      - delegate_path: pathlib.Path (or str) -> location of your libarmnnDelegate.so
      - backends: list -> backends you want to use, as strings, in order of
        preference
      - logging_severity: str -> value passed as the delegate's
        "logging-severity" option (defaults to "info", the previous
        hard-coded value)

    returns:
      - armnn_delegate: tflite.Delegate
    """
    # The delegate option takes the backends as one comma-separated string.
    backend_string = ",".join(backends)
    # The library location is handed to ctypes, which only accepts path-like
    # objects on newer Python versions, so pass a plain string.
    armnn_delegate = tflite.load_delegate(
        library=str(delegate_path),
        options={"backends": backend_string, "logging-severity": logging_severity},
    )

    return armnn_delegate
96
97
def load_tf_model(model_path: Path, armnn_delegate: tflite.Delegate):
    """Build a TFLite interpreter for *model_path* that uses the Arm NN delegate.

    The interpreter's tensors are allocated before it is returned.

    args:
      - model_path: pathlib.Path
      - armnn_delegate: tflite.Delegate

    returns:
      - interpreter: tflite.Interpreter
    """
    model_file = model_path.as_posix()
    delegates = [armnn_delegate]
    tflite_interpreter = tflite.Interpreter(
        model_path=model_file,
        experimental_delegates=delegates,
    )
    # Allocate tensor buffers so the interpreter is ready for inference.
    tflite_interpreter.allocate_tensors()
    return tflite_interpreter
114
115
def run_inference(interpreter, input_image):
    """Feed a pre-processed image to the interpreter and return its output.

    Only the model's first input tensor and first output tensor are used.

    args:
      - interpreter: tflite_runtime.interpreter.Interpreter with tensors
        already allocated (e.g. as returned by load_tf_model)
      - input_image: np.array written into the first input tensor

    returns:
      - output_data: np.array read back from the first output tensor
    """
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]
    # Copy the caller's image into the input tensor, then run the graph.
    interpreter.set_tensor(input_index, input_image)
    interpreter.invoke()
    return interpreter.get_tensor(output_index)
136
137
def create_mapping(label_mapping_p):
    """Creates a Python dictionary mapping an index to a label.

    label_mapping[idx] = label, where idx is the 0-based line number in the
    label file. Line terminators are stripped so the label prints cleanly.
    Blank lines are kept so that later indices are not shifted.

    args:
      - label_mapping_p: pathlib.Path (or str) to a text file with one label
        per line

    returns:
      - label_mapping: dict[int, str]
    """
    # Explicit encoding: the platform default is locale-dependent.
    with open(label_mapping_p, encoding="utf-8") as label_mapping_raw:
        return {
            idx: line.rstrip("\r\n")
            for idx, line in enumerate(label_mapping_raw)
        }
157
158
def process_output(output_data, label_mapping):
    """Turn the model's output tensor into a label.

    Takes the first row of *output_data* (presumably the single batch entry)
    and returns the label whose index holds the highest value. Ties resolve
    to the lowest index, as np.argmax does.

    args:
      - output_data: np.array; only output_data[0] is inspected
      - label_mapping: dict mapping index -> label

    returns:
      - the label stored for the arg-max index
    """
    first_row = output_data[0]
    best_index = np.argmax(first_row)
    return label_mapping[best_index]
173
174
def main(args):
    """Run the classification pipeline for the parsed command-line options.

    The steps are:
      1. Validate the options.
      2. Load the Arm NN delegate and the model.
      3. Resize the input image to the model's spatial size.
      4. Load the label mapping.
      5. Run one inference and print the predicted label.

    args:
      - args: argparse.Namespace

    returns:
      - None
    """
    check_args(args)

    delegate = load_delegate(args.delegate_path, args.preferred_backends)
    interpreter = load_tf_model(args.model_file, delegate)

    # Indices 1 and 2 of the input shape are taken as height and width
    # (an NHWC layout is assumed).
    model_shape = interpreter.get_input_details()[0]["shape"]
    target_size = (model_shape[1], model_shape[2])
    image = load_image(args.input_image, target_size, grayscale=False)

    labels = create_mapping(args.label_file)
    prediction = process_output(run_inference(interpreter, image), labels)

    print("Prediction: ", prediction)
204
205
if __name__ == "__main__":
    # Command-line interface; default values are shown in the --help output.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # Required file-path options, each parsed into a pathlib.Path.
    for flag, description in (
        ("--input_image", "File path of image file"),
        ("--model_file", "File path of the model tflite file"),
        ("--label_file", "File path of model labelmapping file"),
        ("--delegate_path", "File path of ArmNN delegate file"),
    ):
        parser.add_argument(flag, help=description, type=Path, required=True)
    # Optional space-separated list of backends, most preferred first.
    parser.add_argument(
        "--preferred_backends",
        help="list of backends in order of preference",
        type=str,
        nargs="+",
        required=False,
        default=["CpuAcc", "CpuRef"],
    )
    args = parser.parse_args()

    main(args)
242