/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

import ImageClassification
import XCTest

@testable import MobileNetClassifier

/// Checks `MobileNetClassifier` against a fixed set of sample images, once per
/// `.pte` model resource found in the test bundle.
final class MobileNetClassifierTest: XCTestCase {

  /// Runs the shared classification checks with the `mv3` model resource.
  func testV3WithPortableBackend() throws {
    try run(model: "mv3")
  }

  /// Runs the shared classification checks with the `mv3_coreml_all` model resource.
  func testV3WithCoreMLBackend() throws {
    try run(model: "mv3_coreml_all")
  }

  /// Runs the shared classification checks with the `mv3_mps_float16` model resource.
  func testV3WithMPSBackend() throws {
    try run(model: "mv3_mps_float16")
  }

  /// Runs the shared classification checks with the `mv3_xnnpack_fp32` model resource.
  func testV3WithXNNPACKBackend() throws {
    try run(model: "mv3_xnnpack_fp32")
  }

  /// Builds a classifier from the named model and verifies the first classification
  /// of every sample image.
  ///
  /// For each expected entry, the image resource `<label>.jpg` is loaded from the
  /// test bundle and classified; the first result must carry the same label and a
  /// confidence strictly greater than the listed threshold.
  ///
  /// - Parameter modelName: Resource name, without extension, of the `.pte` model
  ///   file in the test bundle.
  /// - Throws: Any error raised while creating the classifier or classifying an
  ///   image, and the `XCTUnwrap` error when a resource, the classifier, or a
  ///   classification result is missing.
  private func run(model modelName: String) throws {
    // Every resource lives in the bundle that contains this test class.
    let bundle = Bundle(for: type(of: self))

    let modelFilePath = try XCTUnwrap(
      bundle.path(forResource: modelName, ofType: "pte"),
      "Failed to get model path for \(modelName).pte")
    let labelsFilePath = try XCTUnwrap(
      bundle.path(forResource: "imagenet_classes", ofType: "txt"),
      "Failed to get labels path for imagenet_classes.txt")

    // The initializer can yield nil as well as throw. Unwrap right away so a load
    // failure is reported as such instead of surfacing later as an inference failure.
    let classifier = try XCTUnwrap(
      MobileNetClassifier(
        modelFilePath: modelFilePath,
        labelsFilePath: labelsFilePath),
      "Failed to create classifier for model \(modelName)")

    // `label` doubles as the image resource name; `confidence` is the threshold the
    // actual confidence has to exceed.
    let expectedClassifications = [
      Classification(label: "Arctic fox", confidence: 0.9),
      Classification(label: "Samoyed", confidence: 0.7),
      Classification(label: "hot pot", confidence: 0.8),
    ]

    for expected in expectedClassifications {
      let imagePath = try XCTUnwrap(
        bundle.path(forResource: expected.label, ofType: "jpg"),
        "Failed to get image path for \(expected.label).jpg")
      let image = try XCTUnwrap(
        UIImage(contentsOfFile: imagePath),
        "Failed to load image at \(imagePath)")

      // Only the first entry of the returned classifications is checked.
      let classification = try XCTUnwrap(
        classifier.classify(image: image).first,
        "Failed to run the model \(modelName) on \(expected.label).jpg")

      XCTAssertEqual(
        classification.label,
        expected.label,
        "Unexpected label from model \(modelName) for \(expected.label).jpg")
      XCTAssertGreaterThan(
        classification.confidence,
        expected.confidence,
        "Confidence too low from model \(modelName) for \(expected.label).jpg")
    }
  }
}