xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /*Copyright 2018 Google LLC
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     https://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 package org.tensorflow.ovic;
16 
17 import static java.nio.charset.StandardCharsets.UTF_8;
18 
19 import java.io.BufferedReader;
20 import java.io.IOException;
21 import java.io.InputStream;
22 import java.io.InputStreamReader;
23 import java.nio.ByteBuffer;
24 import java.nio.MappedByteBuffer;
25 import java.util.AbstractMap;
26 import java.util.ArrayList;
27 import java.util.Collections;
28 import java.util.Comparator;
29 import java.util.List;
30 import java.util.Map;
31 import java.util.PriorityQueue;
32 import org.tensorflow.lite.Interpreter;
33 import org.tensorflow.lite.TestHelper;
34 
35 /** Class for running ImageNet classification with a TfLite model. */
36 public class OvicClassifier {
37 
38   /** Tag for the {@link Log}. */
39   private static final String TAG = "OvicClassifier";
40 
41   /** Number of results to show (i.e. the "K" in top-K predictions). */
42   private static final int RESULTS_TO_SHOW = 5;
43 
44   /** An instance of the driver class to run model inference with Tensorflow Lite. */
45   private Interpreter tflite;
46 
47   /** Labels corresponding to the output of the vision model. */
48   private final List<String> labelList;
49 
50   /** An array to hold inference results, to be feed into Tensorflow Lite as outputs. */
51   private byte[][] inferenceOutputArray = null;
52   /** An array to hold final prediction probabilities. */
53   private float[][] labelProbArray = null;
54 
55   /** Input resultion. */
56   private int[] inputDims = null;
57   /** Whether the model runs as float or quantized. */
58   private Boolean outputIsFloat = null;
59 
60   private final PriorityQueue<Map.Entry<Integer, Float>> sortedLabels =
61       new PriorityQueue<>(
62           RESULTS_TO_SHOW,
63           new Comparator<Map.Entry<Integer, Float>>() {
64             @Override
65             public int compare(Map.Entry<Integer, Float> o1, Map.Entry<Integer, Float> o2) {
66               return o1.getValue().compareTo(o2.getValue());
67             }
68           });
69 
70   /** Initializes an {@code OvicClassifier}. */
OvicClassifier(InputStream labelInputStream, MappedByteBuffer model)71   public OvicClassifier(InputStream labelInputStream, MappedByteBuffer model) throws IOException {
72     if (model == null) {
73       throw new RuntimeException("Input model is empty.");
74     }
75     labelList = loadLabelList(labelInputStream);
76     // OVIC uses one thread for CPU inference.
77     tflite = new Interpreter(model, new Interpreter.Options().setNumThreads(1));
78     inputDims = TestHelper.getInputDims(tflite, 0);
79     if (inputDims.length != 4) {
80       throw new RuntimeException("The model's input dimensions must be 4 (BWHC).");
81     }
82     if (inputDims[0] != 1) {
83       throw new IllegalStateException(
84           "The model must have a batch size of 1, got " + inputDims[0] + " instead.");
85     }
86     if (inputDims[3] != 3) {
87       throw new IllegalStateException(
88           "The model must have three color channels, got " + inputDims[3] + " instead.");
89     }
90     int minSide = Math.min(inputDims[1], inputDims[2]);
91     int maxSide = Math.max(inputDims[1], inputDims[2]);
92     if (minSide <= 0 || maxSide > 1000) {
93       throw new RuntimeException("The model's resolution must be between (0, 1000].");
94     }
95     String outputDataType = TestHelper.getOutputDataType(tflite, 0);
96     switch (outputDataType) {
97       case "float":
98         outputIsFloat = true;
99         break;
100       case "byte":
101         outputIsFloat = false;
102         break;
103       default:
104         throw new IllegalStateException("Cannot process output type: " + outputDataType);
105     }
106     inferenceOutputArray = new byte[1][labelList.size()];
107     labelProbArray = new float[1][labelList.size()];
108   }
109 
110   /** Classifies a {@link ByteBuffer} image. */
111   // @throws RuntimeException if model is uninitialized.
classifyByteBuffer(ByteBuffer imgData)112   public OvicClassificationResult classifyByteBuffer(ByteBuffer imgData) {
113     if (tflite == null) {
114       throw new RuntimeException(TAG + ": ImageNet classifier has not been initialized; Failed.");
115     }
116     if (outputIsFloat == null) {
117       throw new RuntimeException(TAG + ": Classifier output type has not been resolved.");
118     }
119     if (outputIsFloat) {
120       tflite.run(imgData, labelProbArray);
121     } else {
122       tflite.run(imgData, inferenceOutputArray);
123       /** Convert results to float */
124       for (int i = 0; i < inferenceOutputArray[0].length; i++) {
125         labelProbArray[0][i] = (inferenceOutputArray[0][i] & 0xff) / 255.0f;
126       }
127     }
128     OvicClassificationResult iterResult = computeTopKLabels();
129     iterResult.latencyMilli = getLastNativeInferenceLatencyMilliseconds();
130     iterResult.latencyNano = getLastNativeInferenceLatencyNanoseconds();
131     return iterResult;
132   }
133 
134   /** Return the probability array of all classes. */
getlabelProbArray()135   public float[][] getlabelProbArray() {
136     return labelProbArray;
137   }
138 
139   /** Return the number of top labels predicted by the classifier. */
getNumPredictions()140   public int getNumPredictions() {
141     return RESULTS_TO_SHOW;
142   }
143 
144   /** Return the four dimensions of the input image. */
getInputDims()145   public int[] getInputDims() {
146     return inputDims;
147   }
148 
149   /*
150    * Get native inference latency of last image classification run.
151    *  @throws RuntimeException if model is uninitialized.
152    */
getLastNativeInferenceLatencyMilliseconds()153   public Long getLastNativeInferenceLatencyMilliseconds() {
154     if (tflite == null) {
155       throw new RuntimeException(TAG + ": ImageNet classifier has not been initialized; Failed.");
156     }
157     Long latency = tflite.getLastNativeInferenceDurationNanoseconds();
158     return (latency == null) ? null : (Long) (latency / 1000000);
159   }
160 
161   /*
162    * Get native inference latency of last image classification run.
163    *  @throws RuntimeException if model is uninitialized.
164    */
getLastNativeInferenceLatencyNanoseconds()165   public Long getLastNativeInferenceLatencyNanoseconds() {
166     if (tflite == null) {
167       throw new IllegalStateException(
168           TAG + ": ImageNet classifier has not been initialized; Failed.");
169     }
170     return tflite.getLastNativeInferenceDurationNanoseconds();
171   }
172 
173   /** Closes tflite to release resources. */
close()174   public void close() {
175     tflite.close();
176     tflite = null;
177   }
178 
179   /** Reads label list from Assets. */
loadLabelList(InputStream labelInputStream)180   private static List<String> loadLabelList(InputStream labelInputStream) throws IOException {
181     List<String> labelList = new ArrayList<>();
182     try (BufferedReader reader =
183         new BufferedReader(new InputStreamReader(labelInputStream, UTF_8))) {
184       String line;
185       while ((line = reader.readLine()) != null) {
186         labelList.add(line);
187       }
188     }
189     return labelList;
190   }
191 
192   /** Computes top-K labels. */
computeTopKLabels()193   private OvicClassificationResult computeTopKLabels() {
194     if (labelList == null) {
195       throw new RuntimeException("Label file has not been loaded.");
196     }
197     for (int i = 0; i < labelList.size(); ++i) {
198       sortedLabels.add(new AbstractMap.SimpleEntry<>(i, labelProbArray[0][i]));
199       if (sortedLabels.size() > RESULTS_TO_SHOW) {
200         sortedLabels.poll();
201       }
202     }
203     OvicClassificationResult singleImageResult = new OvicClassificationResult();
204     if (sortedLabels.size() != RESULTS_TO_SHOW) {
205       throw new RuntimeException(
206           "Number of returned labels does not match requirement: "
207               + sortedLabels.size()
208               + " returned, but "
209               + RESULTS_TO_SHOW
210               + " required.");
211     }
212     for (int i = 0; i < RESULTS_TO_SHOW; ++i) {
213       Map.Entry<Integer, Float> label = sortedLabels.poll();
214       // ImageNet model prediction indices are 0-based.
215       singleImageResult.topKIndices.add(label.getKey());
216       singleImageResult.topKClasses.add(labelList.get(label.getKey()));
217       singleImageResult.topKProbs.add(label.getValue());
218     }
219     // Labels with lowest probability are returned first, hence need to reverse them.
220     Collections.reverse(singleImageResult.topKIndices);
221     Collections.reverse(singleImageResult.topKClasses);
222     Collections.reverse(singleImageResult.topKProbs);
223     return singleImageResult;
224   }
225 }
226