/* Copyright 2018 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.ovic;

import static java.nio.charset.StandardCharsets.UTF_8;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.TestHelper;

/** Class for running ImageNet classification with a TfLite model. */
public class OvicClassifier {

  /** Tag for the {@link Log}. */
  private static final String TAG = "OvicClassifier";

  /** Number of results to show (i.e. the "K" in top-K predictions). */
  private static final int RESULTS_TO_SHOW = 5;

  /** An instance of the driver class to run model inference with TensorFlow Lite. */
  private Interpreter tflite;

  /** Labels corresponding to the output of the vision model. */
  private final List<String> labelList;

  /** An array to hold inference results, to be fed into TensorFlow Lite as outputs. */
  private byte[][] inferenceOutputArray = null;
  /** An array to hold the final prediction probabilities. */
  private float[][] labelProbArray = null;

  /** Input resolution. */
  private int[] inputDims = null;
  /** Whether the model runs as float or quantized. */
  private Boolean outputIsFloat = null;

  private final PriorityQueue<Map.Entry<Integer, Float>> sortedLabels =
      new PriorityQueue<>(
          RESULTS_TO_SHOW,
          new Comparator<Map.Entry<Integer, Float>>() {
            @Override
            public int compare(Map.Entry<Integer, Float> o1, Map.Entry<Integer, Float> o2) {
              return o1.getValue().compareTo(o2.getValue());
            }
          });

  /** Initializes an {@code OvicClassifier}. */
  public OvicClassifier(InputStream labelInputStream, MappedByteBuffer model) throws IOException {
    if (model == null) {
      throw new RuntimeException("Input model is empty.");
    }
    labelList = loadLabelList(labelInputStream);
    // OVIC uses one thread for CPU inference.
    tflite = new Interpreter(model, new Interpreter.Options().setNumThreads(1));
    inputDims = TestHelper.getInputDims(tflite, 0);
    if (inputDims.length != 4) {
      throw new RuntimeException("The model's input dimensions must be 4 (BWHC).");
    }
    if (inputDims[0] != 1) {
      throw new IllegalStateException(
          "The model must have a batch size of 1, got " + inputDims[0] + " instead.");
    }
    if (inputDims[3] != 3) {
      throw new IllegalStateException(
          "The model must have three color channels, got " + inputDims[3] + " instead.");
    }
    int minSide = Math.min(inputDims[1], inputDims[2]);
    int maxSide = Math.max(inputDims[1], inputDims[2]);
    if (minSide <= 0 || maxSide > 1000) {
      throw new RuntimeException("The model's resolution must be between (0, 1000].");
    }
    String outputDataType = TestHelper.getOutputDataType(tflite, 0);
    switch (outputDataType) {
      case "float":
        outputIsFloat = true;
        break;
      case "byte":
        outputIsFloat = false;
        break;
      default:
        throw new IllegalStateException("Cannot process output type: " + outputDataType);
    }
    inferenceOutputArray = new byte[1][labelList.size()];
    labelProbArray = new float[1][labelList.size()];
  }

  /**
   * Classifies a {@link ByteBuffer} image.
   *
   * @throws RuntimeException if the model is uninitialized.
   */
  public OvicClassificationResult classifyByteBuffer(ByteBuffer imgData) {
    if (tflite == null) {
      throw new RuntimeException(TAG + ": ImageNet classifier has not been initialized; Failed.");
    }
    if (outputIsFloat == null) {
      throw new RuntimeException(TAG + ": Classifier output type has not been resolved.");
    }
    if (outputIsFloat) {
      tflite.run(imgData, labelProbArray);
    } else {
      tflite.run(imgData, inferenceOutputArray);
      // Convert the quantized (uint8) scores to floats in [0, 1].
      for (int i = 0; i < inferenceOutputArray[0].length; i++) {
        labelProbArray[0][i] = (inferenceOutputArray[0][i] & 0xff) / 255.0f;
      }
    }
    OvicClassificationResult iterResult = computeTopKLabels();
    iterResult.latencyMilli = getLastNativeInferenceLatencyMilliseconds();
    iterResult.latencyNano = getLastNativeInferenceLatencyNanoseconds();
    return iterResult;
  }

  /** Returns the probability array of all classes. */
  public float[][] getlabelProbArray() {
    return labelProbArray;
  }

  /** Returns the number of top labels predicted by the classifier. */
  public int getNumPredictions() {
    return RESULTS_TO_SHOW;
  }

  /** Returns the four dimensions of the input image. */
  public int[] getInputDims() {
    return inputDims;
  }

  /**
   * Gets the native inference latency, in milliseconds, of the last image classification run.
   *
   * @throws RuntimeException if the model is uninitialized.
   */
  public Long getLastNativeInferenceLatencyMilliseconds() {
    if (tflite == null) {
      throw new RuntimeException(TAG + ": ImageNet classifier has not been initialized; Failed.");
    }
    Long latency = tflite.getLastNativeInferenceDurationNanoseconds();
    return (latency == null) ? null : (Long) (latency / 1000000);
  }

  /**
   * Gets the native inference latency, in nanoseconds, of the last image classification run.
   *
   * @throws RuntimeException if the model is uninitialized.
   */
  public Long getLastNativeInferenceLatencyNanoseconds() {
    if (tflite == null) {
      throw new IllegalStateException(
          TAG + ": ImageNet classifier has not been initialized; Failed.");
    }
    return tflite.getLastNativeInferenceDurationNanoseconds();
  }

  /** Closes tflite to release resources. */
  public void close() {
    tflite.close();
    tflite = null;
  }

  /** Reads the label list from Assets. */
  private static List<String> loadLabelList(InputStream labelInputStream) throws IOException {
    List<String> labelList = new ArrayList<>();
    try (BufferedReader reader =
        new BufferedReader(new InputStreamReader(labelInputStream, UTF_8))) {
      String line;
      while ((line = reader.readLine()) != null) {
        labelList.add(line);
      }
    }
    return labelList;
  }

  /** Computes the top-K labels from the most recent prediction probabilities. */
  private OvicClassificationResult computeTopKLabels() {
    if (labelList == null) {
      throw new RuntimeException("Label file has not been loaded.");
    }
    // Keep a bounded min-heap of the K highest-probability entries seen so far.
    for (int i = 0; i < labelList.size(); ++i) {
      sortedLabels.add(new AbstractMap.SimpleEntry<>(i, labelProbArray[0][i]));
      if (sortedLabels.size() > RESULTS_TO_SHOW) {
        sortedLabels.poll();
      }
    }
    OvicClassificationResult singleImageResult = new OvicClassificationResult();
    if (sortedLabels.size() != RESULTS_TO_SHOW) {
      throw new RuntimeException(
          "Number of returned labels does not match requirement: "
              + sortedLabels.size()
              + " returned, but "
              + RESULTS_TO_SHOW
              + " required.");
    }
    for (int i = 0; i < RESULTS_TO_SHOW; ++i) {
      Map.Entry<Integer, Float> label = sortedLabels.poll();
      // ImageNet model prediction indices are 0-based.
      singleImageResult.topKIndices.add(label.getKey());
      singleImageResult.topKClasses.add(labelList.get(label.getKey()));
      singleImageResult.topKProbs.add(label.getValue());
    }
    // Labels with the lowest probability are polled first, so reverse the lists to put the most
    // likely label at index 0.
    Collections.reverse(singleImageResult.topKIndices);
    Collections.reverse(singleImageResult.topKClasses);
    Collections.reverse(singleImageResult.topKProbs);
    return singleImageResult;
  }
}
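
// The class below is a minimal usage sketch added for illustration only; it is not part of the
// OVIC benchmark harness. The class name, file paths, buffer sizing, and the assumption of a
// quantized (uint8) input model are all hypothetical; real callers must preprocess the image to
// match the model's actual input layout and data type.
/** Illustrative sketch of driving {@link OvicClassifier} outside the benchmark harness. */
class OvicClassifierExample {
  public static void main(String[] args) throws IOException {
    try (InputStream labelStream = new java.io.FileInputStream("labels.txt"); // assumed path
        java.io.FileInputStream modelStream = new java.io.FileInputStream("model.tflite");
        java.nio.channels.FileChannel modelChannel = modelStream.getChannel()) {
      // The classifier expects the model as a MappedByteBuffer, so memory-map the model file.
      MappedByteBuffer model =
          modelChannel.map(
              java.nio.channels.FileChannel.MapMode.READ_ONLY, 0, modelChannel.size());
      OvicClassifier classifier = new OvicClassifier(labelStream, model);
      int[] dims = classifier.getInputDims(); // {batch, spatial, spatial, channels}
      // Assumes a quantized model (one byte per channel value); the caller must fill the buffer
      // with a preprocessed image matching the model's input resolution.
      ByteBuffer imgData = ByteBuffer.allocateDirect(dims[0] * dims[1] * dims[2] * dims[3]);
      OvicClassificationResult result = classifier.classifyByteBuffer(imgData);
      System.out.println("Top classes: " + result.topKClasses);
      classifier.close();
    }
  }
}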