1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 import ImageClassification
10 import XCTest
11 
12 @testable import MobileNetClassifier
13 
/// End-to-end tests that load each exported MobileNet v3 `.pte` model from the
/// test bundle and check its top-1 prediction on a few bundled sample images.
final class MobileNetClassifierTest: XCTestCase {

  func testV3WithPortableBackend() throws {
    try run(model: "mv3")
  }

  func testV3WithCoreMLBackend() throws {
    try run(model: "mv3_coreml_all")
  }

  func testV3WithMPSBackend() throws {
    try run(model: "mv3_mps_float16")
  }

  func testV3WithXNNPACKBackend() throws {
    try run(model: "mv3_xnnpack_fp32")
  }

  /// Loads `modelName`.pte and `imagenet_classes.txt` from the test bundle, runs
  /// the classifier on each fixture image, and checks the top-1 result.
  ///
  /// Each expected `label` doubles as the resource name of the fixture `.jpg`,
  /// and its `confidence` is a strict lower bound on the reported confidence.
  ///
  /// - Parameter modelName: Resource name (without the `.pte` extension) of the
  ///   exported model under test.
  /// - Throws: If the model or labels resource is missing, if the classifier
  ///   cannot be created (throws or returns nil), or if classification throws.
  private func run(model modelName: String) throws {
    let bundle = Bundle(for: type(of: self))
    let modelFilePath = try XCTUnwrap(
      bundle.path(forResource: modelName, ofType: "pte"),
      "Failed to get model path for \(modelName)")
    let labelsFilePath = try XCTUnwrap(
      bundle.path(forResource: "imagenet_classes", ofType: "txt"),
      "Failed to get labels path")
    // The initializer is optional-returning (the original relied on `classifier?`),
    // so unwrap once here: a nil classifier is reported as a creation failure
    // instead of as a misleading per-image "failed to run" message.
    let classifier = try XCTUnwrap(
      MobileNetClassifier(
        modelFilePath: modelFilePath,
        labelsFilePath: labelsFilePath),
      "Failed to create classifier for \(modelName)")
    let expectedClassifications = [
      Classification(label: "Arctic fox", confidence: 0.9),
      Classification(label: "Samoyed", confidence: 0.7),
      Classification(label: "hot pot", confidence: 0.8),
    ]
    for expected in expectedClassifications {
      // Use `continue` (not `return`) on per-image failures so one bad fixture
      // does not hide the results for the remaining images.
      guard
        let imagePath = bundle.path(forResource: expected.label, ofType: "jpg"),
        let image = UIImage(contentsOfFile: imagePath)
      else {
        XCTFail("Failed to get image path or image for \(expected.label)")
        continue
      }
      guard let classification = try classifier.classify(image: image).first else {
        XCTFail("Failed to run the model \(modelName) on \(expected.label)")
        continue
      }
      XCTAssertEqual(
        classification.label, expected.label,
        "Unexpected top-1 label from \(modelName)")
      XCTAssertGreaterThan(
        classification.confidence, expected.confidence,
        "Confidence too low from \(modelName) for \(expected.label)")
    }
  }
}
73