// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import XCTest

@testable import TensorFlowLite

/// Tests for `Interpreter`, using the `add.bin` model by default and the `add_quantized.bin`
/// model where a test says so.
class InterpreterTests: XCTestCase {

  /// Interpreter for `AddModel`, created in `setUp()` and released in `tearDown()`.
  var interpreter: Interpreter!

  override func setUp() {
    super.setUp()

    // The model is a bundled test resource; failing to load it is a broken test setup.
    interpreter = try! Interpreter(modelPath: AddModel.path)
  }

  override func tearDown() {
    interpreter = nil

    super.tearDown()
  }

  func testInit_ValidModelPath() {
    XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path))
  }

  func testInit_InvalidModelPath_ThrowsFailedToLoadModel() {
    XCTAssertThrowsError(try Interpreter(modelPath: "/invalid/path")) { error in
      self.assertEqualErrors(actual: error, expected: .failedToLoadModel)
    }
  }

  func testInitWithOptions() throws {
    var options = Interpreter.Options()
    options.threadCount = 2
    let interpreter = try Interpreter(modelPath: AddQuantizedModel.path, options: options)
    XCTAssertNotNil(interpreter.options)
    // Check that the interpreter kept the exact options it was given, not merely some options.
    XCTAssertEqual(interpreter.options, options)
    XCTAssertNil(interpreter.delegates)
  }

  func testInputTensorCount() {
    XCTAssertEqual(interpreter.inputTensorCount, AddModel.inputTensorCount)
  }

  func testOutputTensorCount() {
    XCTAssertEqual(interpreter.outputTensorCount, AddModel.outputTensorCount)
  }

  func testInvoke() throws {
    try interpreter.allocateTensors()
    XCTAssertNoThrow(try interpreter.invoke())
  }

  func testInvoke_ThrowsAllocateTensorsRequired_ModelNotReady() {
    XCTAssertThrowsError(try interpreter.invoke()) { error in
      self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
    }
  }

  func testInputTensorAtIndex() throws {
    try setUpAddModelInputTensor()
    let inputTensor = try interpreter.input(at: AddModel.validIndex)
    XCTAssertEqual(inputTensor, AddModel.inputTensor)
  }

  func testInputTensorAtIndex_QuantizedModel() throws {
    interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
    try setUpAddQuantizedModelInputTensor()
    let inputTensor = try interpreter.input(at: AddQuantizedModel.inputOutputIndex)
    XCTAssertEqual(inputTensor, AddQuantizedModel.inputTensor)
  }

  func testInputTensorAtIndex_ThrowsInvalidIndex() throws {
    try interpreter.allocateTensors()
    XCTAssertThrowsError(try interpreter.input(at: AddModel.invalidIndex)) { error in
      let maxIndex = AddModel.inputTensorCount - 1
      self.assertEqualErrors(
        actual: error,
        expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
      )
    }
  }

  func testInputTensorAtIndex_ThrowsAllocateTensorsRequired() {
    XCTAssertThrowsError(try interpreter.input(at: AddModel.validIndex)) { error in
      self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
    }
  }

  func testOutputTensorAtIndex() throws {
    try setUpAddModelInputTensor()
    try interpreter.invoke()
    let outputTensor = try interpreter.output(at: AddModel.validIndex)
    XCTAssertEqual(outputTensor, AddModel.outputTensor)
    // The values the interpreter produced, decoded from the output tensor's raw bytes.
    let actualResults = [Float32](unsafeData: outputTensor.data)
    XCTAssertEqual(actualResults, AddModel.results)
  }

  func testOutputTensorAtIndex_QuantizedModel() throws {
    interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
    try setUpAddQuantizedModelInputTensor()
    try interpreter.invoke()
    let outputTensor = try interpreter.output(at: AddQuantizedModel.inputOutputIndex)
    XCTAssertEqual(outputTensor, AddQuantizedModel.outputTensor)
    // The values the interpreter produced, as the output tensor's raw bytes.
    let actualResults = [UInt8](outputTensor.data)
    XCTAssertEqual(actualResults, AddQuantizedModel.results)
  }

  func testOutputTensorAtIndex_ThrowsInvalidIndex() throws {
    try interpreter.allocateTensors()
    try interpreter.invoke()
    XCTAssertThrowsError(try interpreter.output(at: AddModel.invalidIndex)) { error in
      let maxIndex = AddModel.outputTensorCount - 1
      self.assertEqualErrors(
        actual: error,
        expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
      )
    }
  }

  func testOutputTensorAtIndex_ThrowsInvokeInterpreterRequired() {
    XCTAssertThrowsError(try interpreter.output(at: AddModel.validIndex)) { error in
      self.assertEqualErrors(actual: error, expected: .invokeInterpreterRequired)
    }
  }

  func testResizeInputTensorAtIndexToShape() {
    XCTAssertNoThrow(try interpreter.resizeInput(at: AddModel.validIndex, to: [2, 2, 3]))
    XCTAssertNoThrow(try interpreter.allocateTensors())
  }

  func testResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() {
    XCTAssertThrowsError(
      try interpreter.resizeInput(
        at: AddModel.invalidIndex,
        to: [2, 2, 3]
      )
    ) { error in
      let maxIndex = AddModel.inputTensorCount - 1
      self.assertEqualErrors(
        actual: error,
        expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
      )
    }
  }

  func testCopyDataToInputTensorAtIndex() throws {
    try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
    try interpreter.allocateTensors()
    let inputTensor = try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
    XCTAssertEqual(inputTensor.data, AddModel.inputData)
  }

  func testCopyDataToInputTensorAtIndex_ThrowsInvalidIndex() {
    XCTAssertThrowsError(
      try interpreter.copy(
        AddModel.inputData,
        toInputAt: AddModel.invalidIndex
      )
    ) { error in
      let maxIndex = AddModel.inputTensorCount - 1
      self.assertEqualErrors(
        actual: error,
        expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
      )
    }
  }

  func testCopyDataToInputTensorAtIndex_ThrowsInvalidDataCount() throws {
    try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
    try interpreter.allocateTensors()
    // One byte short of what the resized input tensor requires.
    let invalidData = Data(count: AddModel.dataCount - 1)
    XCTAssertThrowsError(
      try interpreter.copy(
        invalidData,
        toInputAt: AddModel.validIndex
      )
    ) { error in
      self.assertEqualErrors(
        actual: error,
        expected: .invalidTensorDataCount(provided: invalidData.count, required: AddModel.dataCount)
      )
    }
  }

  func testAllocateTensors() {
    XCTAssertNoThrow(try interpreter.allocateTensors())
  }

  // MARK: - Private

  /// Resizes, allocates and fills the `AddModel` input tensor of `interpreter`.
  private func setUpAddModelInputTensor() throws {
    precondition(interpreter != nil)
    try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
    try interpreter.allocateTensors()
    try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
  }

  /// Resizes, allocates and fills the `AddQuantizedModel` input tensor of `interpreter`.
  private func setUpAddQuantizedModelInputTensor() throws {
    precondition(interpreter != nil)
    try interpreter.resizeInput(at: AddQuantizedModel.inputOutputIndex, to: AddQuantizedModel.shape)
    try interpreter.allocateTensors()
    try interpreter.copy(AddQuantizedModel.inputData, toInputAt: AddQuantizedModel.inputOutputIndex)
  }

  /// Asserts that `actual` is an `InterpreterError` equal to `expected`.
  ///
  /// - Parameters:
  ///   - actual: The error that was thrown.
  ///   - expected: The `InterpreterError` the test expects.
  ///   - file: Source file to report failures against; defaults to the caller's file.
  ///   - line: Source line to report failures against; defaults to the caller's line.
  private func assertEqualErrors(
    actual: Error,
    expected: InterpreterError,
    file: StaticString = #file,
    line: UInt = #line
  ) {
    guard let actual = actual as? InterpreterError else {
      XCTFail("Actual error should be of type InterpreterError.", file: file, line: line)
      return
    }
    XCTAssertEqual(actual, expected, file: file, line: line)
  }
}

/// Tests for `Interpreter.Options`.
class InterpreterOptionsTests: XCTestCase {

  func testInitWithDefaultValues() {
    let options = Interpreter.Options()
    XCTAssertNil(options.threadCount)
    XCTAssertFalse(options.isXNNPackEnabled)
  }

  func testInitWithCustomValues() {
    var options = Interpreter.Options()

    options.threadCount = 2
    XCTAssertEqual(options.threadCount, 2)

    options.isXNNPackEnabled = false
    XCTAssertFalse(options.isXNNPackEnabled)

    options.isXNNPackEnabled = true
    XCTAssertTrue(options.isXNNPackEnabled)
  }

  func testEquatable() {
    var options1 = Interpreter.Options()
    var options2 = Interpreter.Options()
    XCTAssertEqual(options1, options2)

    options1.threadCount = 2
    options2.threadCount = 2
    XCTAssertEqual(options1, options2)

    options2.threadCount = 3
    XCTAssertNotEqual(options1, options2)

    options2.threadCount = 2
    XCTAssertEqual(options1, options2)

    options2.isXNNPackEnabled = true
    XCTAssertNotEqual(options1, options2)

    options1.isXNNPackEnabled = true
    XCTAssertEqual(options1, options2)
  }
}

// MARK: - Constants

/// Values for the `add.bin` model.
enum AddModel {
  /// Resource name and extension used to locate the model in the test bundle.
  static let info = (name: "add", extension: "bin")
  static let inputTensorCount = 1
  static let outputTensorCount = 1
  static let invalidIndex = 1
  static let validIndex = 0
  static let shape: Tensor.Shape = [2]
  /// Byte count of `inputData`.
  static let dataCount = inputData.count
  static let inputData = Data(copyingBufferOf: [Float32(1.0), Float32(3.0)])
  /// Expected output values for `inputData`.
  static let results = [Float32(3.0), Float32(9.0)]
  // Derived from `results` so the raw output bytes and the decoded values cannot drift apart.
  static let outputData = Data(copyingBufferOf: results)

  static let inputTensor = Tensor(
    name: "input",
    dataType: .float32,
    shape: shape,
    data: inputData
  )
  static let outputTensor = Tensor(
    name: "output",
    dataType: .float32,
    shape: shape,
    data: outputData
  )

  /// Path to the model in the test bundle, or an empty string if the resource is missing.
  ///
  /// Declared `let` because it is never reassigned; a `static var` would be mutable global state.
  static let path: String = {
    let bundle = Bundle(for: InterpreterTests.self)
    return bundle.path(forResource: info.name, ofType: info.extension) ?? ""
  }()
}

/// Values for the `add_quantized.bin` model.
enum AddQuantizedModel {
  /// Resource name and extension used to locate the model in the test bundle.
  static let info = (name: "add_quantized", extension: "bin")
  static let inputOutputIndex = 0
  static let shape: Tensor.Shape = [2]
  static let inputData = Data([1, 3])
  /// Expected raw `UInt8` output values for `inputData`.
  static let results: [UInt8] = [3, 9]
  // Derived from `results` so the raw output bytes and the expected values cannot drift apart.
  static let outputData = Data(results)
  // The scale 0.003922 is approximately 1/255.
  static let quantizationParameters = QuantizationParameters(scale: 0.003922, zeroPoint: 0)

  static let inputTensor = Tensor(
    name: "input",
    dataType: .uInt8,
    shape: shape,
    data: inputData,
    quantizationParameters: quantizationParameters
  )
  static let outputTensor = Tensor(
    name: "output",
    dataType: .uInt8,
    shape: shape,
    data: outputData,
    quantizationParameters: quantizationParameters
  )

  /// Path to the model in the test bundle, or an empty string if the resource is missing.
  ///
  /// Declared `let` because it is never reassigned; a `static var` would be mutable global state.
  static let path: String = {
    let bundle = Bundle(for: InterpreterTests.self)
    return bundle.path(forResource: info.name, ofType: info.extension) ?? ""
  }()
}

// MARK: - Extensions

extension Array {
  /// Creates a new array from the bytes of the given unsafe data.
  ///
  /// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
  ///   with no indirection or reference-counting operations; otherwise, copying the raw bytes in
  ///   the `unsafeData`'s buffer to a new array returns an unsafe copy.
  /// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
  ///   `MemoryLayout<Element>.stride`.
  /// - Parameter unsafeData: The data containing the bytes to turn into an array.
  init?(unsafeData: Data) {
    let elementStride = MemoryLayout<Element>.stride
    guard unsafeData.count % elementStride == 0 else { return nil }
    let elementCount = unsafeData.count / elementStride
    #if swift(>=5.1)
      // Copy the bytes into the array's own (correctly aligned) storage instead of rebinding
      // `Data`'s buffer. That buffer may be misaligned for `Element` (e.g. a `Data` slice), and
      // binding or reading it as `Element` would be undefined behavior.
      self.init(unsafeUninitializedCapacity: elementCount) { buffer, initializedCount in
        let copiedByteCount = unsafeData.copyBytes(to: buffer)
        initializedCount = copiedByteCount / elementStride
      }
    #elseif swift(>=5.0)
      self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
    #else
      self = unsafeData.withUnsafeBytes {
        .init(
          UnsafeBufferPointer<Element>(
            start: $0,
            count: elementCount
          ))
      }
    #endif  // swift(>=5.1)
  }
}

extension Data {
  /// Creates a new buffer by copying the buffer pointer of the given array.
  ///
  /// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
  ///   for bit with no indirection or reference-counting operations; otherwise, reinterpreting
  ///   data from the resulting buffer has undefined behavior.
  /// - Parameter array: An array with elements of type `T`.
  init<T>(copyingBufferOf array: [T]) {
    self = array.withUnsafeBufferPointer(Data.init)
  }
}