xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/swift/Tests/InterpreterTests.swift (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 // Copyright 2018 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 import XCTest
16 
17 @testable import TensorFlowLite
18 
/// Tests for `Interpreter` using the float `add.bin` model (and `add_quantized.bin` where noted).
///
/// Each test starts from a fresh interpreter created in `setUp()` from `AddModel.path`.
class InterpreterTests: XCTestCase {

  var interpreter: Interpreter!

  override func setUp() {
    super.setUp()

    interpreter = try! Interpreter(modelPath: AddModel.path)
  }

  override func tearDown() {
    interpreter = nil

    super.tearDown()
  }

  func testInit_ValidModelPath() {
    XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path))
  }

  func testInit_InvalidModelPath_ThrowsFailedToLoadModel() {
    XCTAssertThrowsError(try Interpreter(modelPath: "/invalid/path")) { error in
      self.assertEqualErrors(actual: error, expected: .failedToLoadModel)
    }
  }

  func testInitWithOptions() throws {
    var options = Interpreter.Options()
    options.threadCount = 2
    let optionsInterpreter = try Interpreter(modelPath: AddQuantizedModel.path, options: options)
    // The interpreter should expose exactly the options it was created with, not merely some.
    XCTAssertEqual(optionsInterpreter.options, options)
    XCTAssertNil(optionsInterpreter.delegates)
  }

  func testInputTensorCount() {
    XCTAssertEqual(interpreter.inputTensorCount, AddModel.inputTensorCount)
  }

  func testOutputTensorCount() {
    XCTAssertEqual(interpreter.outputTensorCount, AddModel.outputTensorCount)
  }

  func testInvoke() throws {
    try interpreter.allocateTensors()
    XCTAssertNoThrow(try interpreter.invoke())
  }

  func testInvoke_ThrowsAllocateTensorsRequired_ModelNotReady() {
    XCTAssertThrowsError(try interpreter.invoke()) { error in
      self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
    }
  }

  func testInputTensorAtIndex() throws {
    try setUpAddModelInputTensor()
    let inputTensor = try interpreter.input(at: AddModel.validIndex)
    XCTAssertEqual(inputTensor, AddModel.inputTensor)
  }

  func testInputTensorAtIndex_QuantizedModel() throws {
    interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
    try setUpAddQuantizedModelInputTensor()
    let inputTensor = try interpreter.input(at: AddQuantizedModel.inputOutputIndex)
    XCTAssertEqual(inputTensor, AddQuantizedModel.inputTensor)
  }

  func testInputTensorAtIndex_ThrowsInvalidIndex() throws {
    try interpreter.allocateTensors()
    XCTAssertThrowsError(try interpreter.input(at: AddModel.invalidIndex)) { error in
      let maxIndex = AddModel.inputTensorCount - 1
      self.assertEqualErrors(
        actual: error,
        expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
      )
    }
  }

  func testInputTensorAtIndex_ThrowsAllocateTensorsRequired() {
    XCTAssertThrowsError(try interpreter.input(at: AddModel.validIndex)) { error in
      self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
    }
  }

  func testOutputTensorAtIndex() throws {
    try setUpAddModelInputTensor()
    try interpreter.invoke()
    let outputTensor = try interpreter.output(at: AddModel.validIndex)
    XCTAssertEqual(outputTensor, AddModel.outputTensor)
    // These are the values the model actually produced; compare them to the expected constants.
    let results = [Float32](unsafeData: outputTensor.data)
    XCTAssertEqual(results, AddModel.results)
  }

  func testOutputTensorAtIndex_QuantizedModel() throws {
    interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
    try setUpAddQuantizedModelInputTensor()
    try interpreter.invoke()
    let outputTensor = try interpreter.output(at: AddQuantizedModel.inputOutputIndex)
    XCTAssertEqual(outputTensor, AddQuantizedModel.outputTensor)
    let results = [UInt8](outputTensor.data)
    XCTAssertEqual(results, AddQuantizedModel.results)
  }

  func testOutputTensorAtIndex_ThrowsInvalidIndex() throws {
    try interpreter.allocateTensors()
    try interpreter.invoke()
    XCTAssertThrowsError(try interpreter.output(at: AddModel.invalidIndex)) { error in
      let maxIndex = AddModel.outputTensorCount - 1
      self.assertEqualErrors(
        actual: error,
        expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
      )
    }
  }

  func testOutputTensorAtIndex_ThrowsInvokeInterpreterRequired() {
    XCTAssertThrowsError(try interpreter.output(at: AddModel.validIndex)) { error in
      self.assertEqualErrors(actual: error, expected: .invokeInterpreterRequired)
    }
  }

  func testResizeInputTensorAtIndexToShape() {
    XCTAssertNoThrow(try interpreter.resizeInput(at: AddModel.validIndex, to: [2, 2, 3]))
    XCTAssertNoThrow(try interpreter.allocateTensors())
  }

  func testResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() {
    XCTAssertThrowsError(
      try interpreter.resizeInput(
        at: AddModel.invalidIndex,
        to: [2, 2, 3]
      )
    ) { error in
      let maxIndex = AddModel.inputTensorCount - 1
      self.assertEqualErrors(
        actual: error,
        expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
      )
    }
  }

  func testCopyDataToInputTensorAtIndex() throws {
    try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
    try interpreter.allocateTensors()
    let inputTensor = try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
    XCTAssertEqual(inputTensor.data, AddModel.inputData)
  }

  func testCopyDataToInputTensorAtIndex_ThrowsInvalidIndex() {
    XCTAssertThrowsError(
      try interpreter.copy(
        AddModel.inputData,
        toInputAt: AddModel.invalidIndex
      )
    ) { error in
      let maxIndex = AddModel.inputTensorCount - 1
      self.assertEqualErrors(
        actual: error,
        expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
      )
    }
  }

  func testCopyDataToInputTensorAtIndex_ThrowsInvalidDataCount() throws {
    try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
    try interpreter.allocateTensors()
    let invalidData = Data(count: AddModel.dataCount - 1)
    XCTAssertThrowsError(
      try interpreter.copy(
        invalidData,
        toInputAt: AddModel.validIndex
      )
    ) { error in
      self.assertEqualErrors(
        actual: error,
        expected: .invalidTensorDataCount(provided: invalidData.count, required: AddModel.dataCount)
      )
    }
  }

  func testAllocateTensors() {
    XCTAssertNoThrow(try interpreter.allocateTensors())
  }

  // MARK: - Private

  /// Resizes, allocates, and fills the float model's input tensor with `AddModel.inputData`.
  private func setUpAddModelInputTensor() throws {
    precondition(interpreter != nil)
    try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
    try interpreter.allocateTensors()
    try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
  }

  /// Resizes, allocates, and fills the quantized model's input tensor with
  /// `AddQuantizedModel.inputData`. The caller must have loaded the quantized model first.
  private func setUpAddQuantizedModelInputTensor() throws {
    precondition(interpreter != nil)
    try interpreter.resizeInput(at: AddQuantizedModel.inputOutputIndex, to: AddQuantizedModel.shape)
    try interpreter.allocateTensors()
    try interpreter.copy(AddQuantizedModel.inputData, toInputAt: AddQuantizedModel.inputOutputIndex)
  }

  /// Asserts that `actual` is an `InterpreterError` equal to `expected`.
  ///
  /// - Parameters:
  ///   - actual: The error that was thrown.
  ///   - expected: The `InterpreterError` the test expects.
  ///   - file: Source file to attribute a failure to; defaults to the caller's file.
  ///   - line: Source line to attribute a failure to; defaults to the caller's line.
  private func assertEqualErrors(
    actual: Error,
    expected: InterpreterError,
    file: StaticString = #file,
    line: UInt = #line
  ) {
    // Forward the caller's location so XCTest reports the failing test, not this helper.
    guard let actual = actual as? InterpreterError else {
      XCTFail("Actual error should be of type InterpreterError.", file: file, line: line)
      return
    }
    XCTAssertEqual(actual, expected, file: file, line: line)
  }
}
226 
/// Tests for `Interpreter.Options`: default values, property mutation, and `Equatable`.
class InterpreterOptionsTests: XCTestCase {

  func testInitWithDefaultValues() {
    let defaults = Interpreter.Options()
    XCTAssertNil(defaults.threadCount)
    XCTAssertFalse(defaults.isXNNPackEnabled)
  }

  func testInitWithCustomValues() {
    var custom = Interpreter.Options()

    custom.threadCount = 2
    XCTAssertEqual(custom.threadCount, 2)

    // Toggle the XNNPack flag off, then on, checking each assignment sticks.
    for flag in [false, true] {
      custom.isXNNPackEnabled = flag
      XCTAssertEqual(custom.isXNNPackEnabled, flag)
    }
  }

  func testEquatable() {
    var lhs = Interpreter.Options()
    var rhs = Interpreter.Options()
    XCTAssertEqual(lhs, rhs)

    // Thread count participates in equality.
    lhs.threadCount = 2
    rhs.threadCount = 2
    XCTAssertEqual(lhs, rhs)

    rhs.threadCount = 3
    XCTAssertNotEqual(lhs, rhs)

    rhs.threadCount = 2
    XCTAssertEqual(lhs, rhs)

    // So does the XNNPack flag.
    rhs.isXNNPackEnabled = true
    XCTAssertNotEqual(lhs, rhs)

    lhs.isXNNPackEnabled = true
    XCTAssertEqual(lhs, rhs)
  }
}
270 
271 // MARK: - Constants
272 
/// Values for the `add.bin` model.
enum AddModel {
  static let info = (name: "add", extension: "bin")
  static let inputTensorCount = 1
  static let outputTensorCount = 1
  static let invalidIndex = 1
  static let validIndex = 0
  static let shape: Tensor.Shape = [2]
  /// Size in bytes of `inputData`.
  static let dataCount = inputData.count
  static let inputData = Data(copyingBufferOf: [Float32(1.0), Float32(3.0)])
  /// Output values the tests expect from running the model on `inputData`.
  static let results = [Float32(3.0), Float32(9.0)]
  /// Raw bytes of `results`. Derived from it so the two cannot drift apart.
  static let outputData = Data(copyingBufferOf: results)

  static let inputTensor = Tensor(
    name: "input",
    dataType: .float32,
    shape: shape,
    data: inputData
  )
  static let outputTensor = Tensor(
    name: "output",
    dataType: .float32,
    shape: shape,
    data: outputData
  )

  /// Bundle path of the model file, or `""` if the resource is not bundled with the tests.
  /// NOTE(review): an empty path presumably makes `Interpreter(modelPath:)` fail to load.
  static let path: String = {
    let bundle = Bundle(for: InterpreterTests.self)
    return bundle.path(forResource: info.name, ofType: info.extension) ?? ""
  }()
}
305 
/// Values for the `add_quantized.bin` model.
enum AddQuantizedModel {
  static let info = (name: "add_quantized", extension: "bin")
  static let inputOutputIndex = 0
  static let shape: Tensor.Shape = [2]
  static let inputData = Data([1, 3])
  static let quantizationParameters = QuantizationParameters(scale: 0.003922, zeroPoint: 0)
  /// Output bytes the tests expect from running the model on `inputData`.
  static let results: [UInt8] = [3, 9]
  /// Raw bytes of `results`. Derived from it so the two cannot drift apart.
  static let outputData = Data(results)

  static let inputTensor = Tensor(
    name: "input",
    dataType: .uInt8,
    shape: shape,
    data: inputData,
    quantizationParameters: quantizationParameters
  )
  static let outputTensor = Tensor(
    name: "output",
    dataType: .uInt8,
    shape: shape,
    data: outputData,
    quantizationParameters: quantizationParameters
  )

  /// Bundle path of the model file, or `""` if the resource is not bundled with the tests.
  /// NOTE(review): an empty path presumably makes `Interpreter(modelPath:)` fail to load.
  static let path: String = {
    let bundle = Bundle(for: InterpreterTests.self)
    return bundle.path(forResource: info.name, ofType: info.extension) ?? ""
  }()
}
337 
338 // MARK: - Extensions
339 
extension Array {
  /// Creates a new array from the bytes of the given unsafe data.
  ///
  /// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
  ///     with no indirection or reference-counting operations; otherwise, copying the raw bytes in
  ///     the `unsafeData`'s buffer to a new array returns an unsafe copy.
  /// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
  ///     `MemoryLayout<Element>.stride`.
  /// - Parameter unsafeData: The data containing the bytes to turn into an array.
  init?(unsafeData: Data) {
    let elementStride = MemoryLayout<Element>.stride
    guard unsafeData.count % elementStride == 0 else { return nil }
    let elementCount = unsafeData.count / elementStride
    guard elementCount > 0 else {
      self = []
      return
    }
    // Copy into freshly allocated storage, which is correctly aligned for `Element`, rather than
    // rebinding `Data`'s own bytes. `Data`'s buffer (a slice, or small inline data) is not
    // guaranteed to be aligned for `Element`, and rebinding memory that `Data` still treats as
    // `UInt8` breaks Swift's typed-memory rules.
    let scratch = UnsafeMutableBufferPointer<Element>.allocate(capacity: elementCount)
    // Only a raw byte copy lands in `scratch`; no deinitialization is needed for trivial types.
    defer { scratch.deallocate() }
    let copiedByteCount = unsafeData.copyBytes(to: scratch)
    assert(copiedByteCount == unsafeData.count)
    self = Array(scratch)
  }
}
364 
extension Data {
  /// Creates a buffer holding a byte-for-byte copy of the array's contiguous element storage.
  ///
  /// - Warning: The element type `T` must be trivial, meaning it can be copied bit for bit with no
  ///     indirection or reference counting. Otherwise, reinterpreting the bytes of the resulting
  ///     buffer has undefined behavior.
  /// - Parameter array: The array whose elements' raw bytes are copied.
  init<T>(copyingBufferOf array: [T]) {
    self = array.withUnsafeBufferPointer { elements in
      Data(buffer: elements)
    }
  }
}
376