1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     https://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 package org.tensorflow.ovic;
16 
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;

import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import javax.imageio.ImageIO;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
34 
35 /** Unit tests for {@link org.tensorflow.ovic.OvicClassifier}. */
36 @RunWith(JUnit4.class)
37 public final class OvicClassifierTest {
38 
39   private OvicClassifier classifier;
40   private InputStream labelsInputStream = null;
41   private MappedByteBuffer quantizedModel = null;
42   private MappedByteBuffer floatModel = null;
43   private MappedByteBuffer lowResModel = null;
44   private ByteBuffer testImage = null;
45   private ByteBuffer lowResTestImage = null;
46   private OvicClassificationResult testResult = null;
47   private static final String LABELS_PATH =
48       "tensorflow/lite/java/ovic/src/testdata/labels.txt";
49   private static final String QUANTIZED_MODEL_PATH =
50       "external/tflite_ovic_testdata/quantized_model.lite";
51   private static final String LOW_RES_MODEL_PATH =
52       "external/tflite_ovic_testdata/low_res_model.lite";
53   private static final String FLOAT_MODEL_PATH =
54       "external/tflite_ovic_testdata/float_model.lite";
55   private static final String TEST_IMAGE_PATH =
56       "external/tflite_ovic_testdata/test_image_224.jpg";
57   private static final String TEST_LOW_RES_IMAGE_PATH =
58       "external/tflite_ovic_testdata/test_image_128.jpg";
59   private static final int TEST_IMAGE_GROUNDTRUTH = 653; // "military uniform"
60 
61   @Before
setUp()62   public void setUp() {
63     try {
64       File labelsfile = new File(LABELS_PATH);
65       labelsInputStream = new FileInputStream(labelsfile);
66       quantizedModel = loadModelFile(QUANTIZED_MODEL_PATH);
67       floatModel = loadModelFile(FLOAT_MODEL_PATH);
68       lowResModel = loadModelFile(LOW_RES_MODEL_PATH);
69       File imageFile = new File(TEST_IMAGE_PATH);
70       BufferedImage img = ImageIO.read(imageFile);
71       testImage = toByteBuffer(img);
72       // Low res image and models.
73       imageFile = new File(TEST_LOW_RES_IMAGE_PATH);
74       img = ImageIO.read(imageFile);
75       lowResTestImage = toByteBuffer(img);
76     } catch (IOException e) {
77       System.out.print(e.getMessage());
78     }
79     System.out.println("Successful setup");
80   }
81 
82   @Test
ovicClassifier_quantizedModelCreateSuccess()83   public void ovicClassifier_quantizedModelCreateSuccess() throws Exception {
84     classifier = new OvicClassifier(labelsInputStream, quantizedModel);
85     assertThat(classifier).isNotNull();
86   }
87 
88   @Test
ovicClassifier_floatModelCreateSuccess()89   public void ovicClassifier_floatModelCreateSuccess() throws Exception {
90     classifier = new OvicClassifier(labelsInputStream, floatModel);
91     assertThat(classifier).isNotNull();
92   }
93 
94   @Test
ovicClassifier_quantizedModelClassifySuccess()95   public void ovicClassifier_quantizedModelClassifySuccess() throws Exception {
96     classifier = new OvicClassifier(labelsInputStream, quantizedModel);
97     testResult = classifier.classifyByteBuffer(testImage);
98     assertCorrectTopK(testResult);
99   }
100 
101   @Test
ovicClassifier_floatModelClassifySuccess()102   public void ovicClassifier_floatModelClassifySuccess() throws Exception {
103     classifier = new OvicClassifier(labelsInputStream, floatModel);
104     testResult = classifier.classifyByteBuffer(testImage);
105     assertCorrectTopK(testResult);
106   }
107 
108   @Test
ovicClassifier_lowResModelClassifySuccess()109   public void ovicClassifier_lowResModelClassifySuccess() throws Exception {
110     classifier = new OvicClassifier(labelsInputStream, lowResModel);
111     testResult = classifier.classifyByteBuffer(lowResTestImage);
112     assertCorrectTopK(testResult);
113   }
114 
115   @Test
ovicClassifier_latencyNotNull()116   public void ovicClassifier_latencyNotNull() throws Exception {
117     classifier = new OvicClassifier(labelsInputStream, floatModel);
118     testResult = classifier.classifyByteBuffer(testImage);
119     assertThat(testResult.latencyNano).isNotNull();
120   }
121 
122   @Test
ovicClassifier_mismatchedInputResolutionFails()123   public void ovicClassifier_mismatchedInputResolutionFails() throws Exception {
124     classifier = new OvicClassifier(labelsInputStream, lowResModel);
125     int[] inputDims = classifier.getInputDims();
126     assertThat(inputDims[1]).isEqualTo(128);
127     assertThat(inputDims[2]).isEqualTo(128);
128     try {
129       testResult = classifier.classifyByteBuffer(testImage);
130       fail();
131     } catch (IllegalArgumentException e) {
132       // Success.
133     }
134   }
135 
toByteBuffer(BufferedImage image)136   private static ByteBuffer toByteBuffer(BufferedImage image) {
137     ByteBuffer imgData = ByteBuffer.allocateDirect(
138         image.getHeight() * image.getWidth() * 3);
139     imgData.order(ByteOrder.nativeOrder());
140     for (int y = 0; y < image.getHeight(); y++) {
141       for (int x = 0; x < image.getWidth(); x++) {
142         int val = image.getRGB(x, y);
143         imgData.put((byte) ((val >> 16) & 0xFF));
144         imgData.put((byte) ((val >> 8) & 0xFF));
145         imgData.put((byte) (val & 0xFF));
146       }
147     }
148     return imgData;
149   }
150 
assertCorrectTopK(OvicClassificationResult testResult)151   private static void assertCorrectTopK(OvicClassificationResult testResult) {
152     assertThat(testResult.topKClasses.size()).isGreaterThan(0);
153     Boolean topKAccurate = false;
154     // Assert that the correct class is in the top K.
155     for (int i = 0; i < testResult.topKIndices.size(); i++) {
156       if (testResult.topKIndices.get(i) == TEST_IMAGE_GROUNDTRUTH) {
157         topKAccurate = true;
158         break;
159       }
160     }
161     System.out.println(testResult.toString());
162     System.out.flush();
163     assertThat(topKAccurate).isTrue();
164   }
165 
loadModelFile(String modelFilePath)166   private static MappedByteBuffer loadModelFile(String modelFilePath) throws IOException {
167     File modelfile = new File(modelFilePath);
168     FileInputStream inputStream = new FileInputStream(modelfile);
169     FileChannel fileChannel = inputStream.getChannel();
170     long startOffset = 0L;
171     long declaredLength = fileChannel.size();
172     return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
173   }
174 }
175