1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 https://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 package org.tensorflow.ovic; 16 17 import static com.google.common.truth.Truth.assertThat; 18 import static org.junit.Assert.fail; 19 20 import java.awt.image.BufferedImage; 21 import java.io.File; 22 import java.io.FileInputStream; 23 import java.io.IOException; 24 import java.io.InputStream; 25 import java.nio.ByteBuffer; 26 import java.nio.ByteOrder; 27 import java.nio.MappedByteBuffer; 28 import java.nio.channels.FileChannel; 29 import javax.imageio.ImageIO; 30 import org.junit.Before; 31 import org.junit.Test; 32 import org.junit.runner.RunWith; 33 import org.junit.runners.JUnit4; 34 35 /** Unit tests for {@link org.tensorflow.ovic.OvicClassifier}. */ 36 @RunWith(JUnit4.class) 37 public final class OvicClassifierTest { 38 39 private OvicClassifier classifier; 40 private InputStream labelsInputStream = null; 41 private MappedByteBuffer quantizedModel = null; 42 private MappedByteBuffer floatModel = null; 43 private MappedByteBuffer lowResModel = null; 44 private ByteBuffer testImage = null; 45 private ByteBuffer lowResTestImage = null; 46 private OvicClassificationResult testResult = null; 47 private static final String LABELS_PATH = 48 "tensorflow/lite/java/ovic/src/testdata/labels.txt"; 49 private static final String QUANTIZED_MODEL_PATH = 50 "external/tflite_ovic_testdata/quantized_model.lite"; 51 private static final String LOW_RES_MODEL_PATH = 52 "external/tflite_ovic_testdata/low_res_model.lite"; 53 private static final String FLOAT_MODEL_PATH = 54 "external/tflite_ovic_testdata/float_model.lite"; 55 private static final String TEST_IMAGE_PATH = 56 "external/tflite_ovic_testdata/test_image_224.jpg"; 57 private static final String TEST_LOW_RES_IMAGE_PATH = 58 "external/tflite_ovic_testdata/test_image_128.jpg"; 59 private static final int TEST_IMAGE_GROUNDTRUTH = 653; // "military uniform" 60 61 @Before setUp()62 public void setUp() { 63 try { 64 File labelsfile = new File(LABELS_PATH); 65 labelsInputStream = new FileInputStream(labelsfile); 66 quantizedModel = loadModelFile(QUANTIZED_MODEL_PATH); 67 floatModel = loadModelFile(FLOAT_MODEL_PATH); 68 lowResModel = loadModelFile(LOW_RES_MODEL_PATH); 69 File imageFile = new File(TEST_IMAGE_PATH); 70 BufferedImage img = ImageIO.read(imageFile); 71 testImage = toByteBuffer(img); 72 // Low res image and models. 73 imageFile = new File(TEST_LOW_RES_IMAGE_PATH); 74 img = ImageIO.read(imageFile); 75 lowResTestImage = toByteBuffer(img); 76 } catch (IOException e) { 77 System.out.print(e.getMessage()); 78 } 79 System.out.println("Successful setup"); 80 } 81 82 @Test ovicClassifier_quantizedModelCreateSuccess()83 public void ovicClassifier_quantizedModelCreateSuccess() throws Exception { 84 classifier = new OvicClassifier(labelsInputStream, quantizedModel); 85 assertThat(classifier).isNotNull(); 86 } 87 88 @Test ovicClassifier_floatModelCreateSuccess()89 public void ovicClassifier_floatModelCreateSuccess() throws Exception { 90 classifier = new OvicClassifier(labelsInputStream, floatModel); 91 assertThat(classifier).isNotNull(); 92 } 93 94 @Test ovicClassifier_quantizedModelClassifySuccess()95 public void ovicClassifier_quantizedModelClassifySuccess() throws Exception { 96 classifier = new OvicClassifier(labelsInputStream, quantizedModel); 97 testResult = classifier.classifyByteBuffer(testImage); 98 assertCorrectTopK(testResult); 99 } 100 101 @Test ovicClassifier_floatModelClassifySuccess()102 public void ovicClassifier_floatModelClassifySuccess() throws Exception { 103 classifier = new OvicClassifier(labelsInputStream, floatModel); 104 testResult = classifier.classifyByteBuffer(testImage); 105 assertCorrectTopK(testResult); 106 } 107 108 @Test ovicClassifier_lowResModelClassifySuccess()109 public void ovicClassifier_lowResModelClassifySuccess() throws Exception { 110 classifier = new OvicClassifier(labelsInputStream, lowResModel); 111 testResult = classifier.classifyByteBuffer(lowResTestImage); 112 assertCorrectTopK(testResult); 113 } 114 115 @Test ovicClassifier_latencyNotNull()116 public void ovicClassifier_latencyNotNull() throws Exception { 117 classifier = new OvicClassifier(labelsInputStream, floatModel); 118 testResult = classifier.classifyByteBuffer(testImage); 119 assertThat(testResult.latencyNano).isNotNull(); 120 } 121 122 @Test ovicClassifier_mismatchedInputResolutionFails()123 public void ovicClassifier_mismatchedInputResolutionFails() throws Exception { 124 classifier = new OvicClassifier(labelsInputStream, lowResModel); 125 int[] inputDims = classifier.getInputDims(); 126 assertThat(inputDims[1]).isEqualTo(128); 127 assertThat(inputDims[2]).isEqualTo(128); 128 try { 129 testResult = classifier.classifyByteBuffer(testImage); 130 fail(); 131 } catch (IllegalArgumentException e) { 132 // Success. 133 } 134 } 135 toByteBuffer(BufferedImage image)136 private static ByteBuffer toByteBuffer(BufferedImage image) { 137 ByteBuffer imgData = ByteBuffer.allocateDirect( 138 image.getHeight() * image.getWidth() * 3); 139 imgData.order(ByteOrder.nativeOrder()); 140 for (int y = 0; y < image.getHeight(); y++) { 141 for (int x = 0; x < image.getWidth(); x++) { 142 int val = image.getRGB(x, y); 143 imgData.put((byte) ((val >> 16) & 0xFF)); 144 imgData.put((byte) ((val >> 8) & 0xFF)); 145 imgData.put((byte) (val & 0xFF)); 146 } 147 } 148 return imgData; 149 } 150 assertCorrectTopK(OvicClassificationResult testResult)151 private static void assertCorrectTopK(OvicClassificationResult testResult) { 152 assertThat(testResult.topKClasses.size()).isGreaterThan(0); 153 Boolean topKAccurate = false; 154 // Assert that the correct class is in the top K. 155 for (int i = 0; i < testResult.topKIndices.size(); i++) { 156 if (testResult.topKIndices.get(i) == TEST_IMAGE_GROUNDTRUTH) { 157 topKAccurate = true; 158 break; 159 } 160 } 161 System.out.println(testResult.toString()); 162 System.out.flush(); 163 assertThat(topKAccurate).isTrue(); 164 } 165 loadModelFile(String modelFilePath)166 private static MappedByteBuffer loadModelFile(String modelFilePath) throws IOException { 167 File modelfile = new File(modelFilePath); 168 FileInputStream inputStream = new FileInputStream(modelfile); 169 FileChannel fileChannel = inputStream.getChannel(); 170 long startOffset = 0L; 171 long declaredLength = fileChannel.size(); 172 return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); 173 } 174 } 175