/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

import ExecuTorch
import ImageClassification
import UIKit

import os.log

/// Errors thrown while preparing an image for `MobileNetClassifier`.
///
/// Conforms to `LocalizedError` so the messages below are also surfaced when
/// the value is handled as a plain `Error` (for example via
/// `error.localizedDescription` in a generic `catch`).
public enum MobileNetClassifierError: LocalizedError {
  /// The normalized input buffer exposed no base address.
  case inputPointer
  /// The pixel data could not be extracted from the image.
  case rawData
  /// The image could not be resized and cropped.
  case transform

  /// Message reported through `LocalizedError`.
  public var errorDescription: String? {
    message
  }

  /// Human-readable description of the error.
  var localizedDescription: String {
    message
  }

  /// Single source of truth for the per-case message text.
  private var message: String {
    switch self {
    case .inputPointer:
      return "Cannot get the input pointer base address"
    case .rawData:
      return "Cannot get the pixel data from the image"
    case .transform:
      return "Cannot transform the image"
    }
  }
}

// See https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v3_small.html
// on model input/output spec.

/// Classifies images by resizing/cropping them, normalizing the pixels and
/// running them through an `ETMobileNetClassifier`.
///
/// The instance reuses internal pixel buffers between calls.
public class MobileNetClassifier: ImageClassification {
  /// Length, in points, of the shorter image side after resizing.
  private static let resizeSize: CGFloat = 256
  /// Side length of the centered square crop fed to the model.
  private static let cropSize: CGFloat = 224

  private let mobileNetClassifier: ETMobileNetClassifier
  private let labels: [String]
  /// RGBA bytes of the cropped image (4 bytes per pixel).
  private var rawDataBuffer: [UInt8]
  /// Normalized floats, 3 channel planes of `cropSize * cropSize` values each.
  private var normalizedBuffer: [Float]

  /// Creates a classifier from a model file and a newline-separated labels file.
  ///
  /// The initializer is declared failable for interface compatibility; this
  /// implementation never returns `nil`.
  ///
  /// - Parameters:
  ///   - modelFilePath: Path handed to `ETMobileNetClassifier(filePath:)`.
  ///   - labelsFilePath: Path of a UTF-8 text file with one label per line.
  /// - Throws: Any error raised while reading `labelsFilePath`.
  public init?(modelFilePath: String, labelsFilePath: String) throws {
    let contents = try String(contentsOfFile: labelsFilePath, encoding: .utf8)
    // Split on newline *characters* so that "\r\n" counts as a single
    // separator, and keep interior blank lines so label indices stay aligned
    // with their line numbers.
    var parsedLabels = contents
      .split(omittingEmptySubsequences: false, whereSeparator: { $0.isNewline })
      .map(String.init)
    // A newline-terminated file would otherwise contribute a trailing empty
    // label, inflating `labels.count`, which sizes the output buffer and is
    // passed to the classifier as `outputSize`.
    while parsedLabels.last?.isEmpty == true {
      parsedLabels.removeLast()
    }
    labels = parsedLabels

    mobileNetClassifier = ETMobileNetClassifier(filePath: modelFilePath)

    let pixelCount = Int(Self.cropSize * Self.cropSize)
    rawDataBuffer = [UInt8](repeating: 0, count: pixelCount * 4)
    normalizedBuffer = [Float](repeating: 0, count: pixelCount * 3)

    #if DEBUG
    Log.shared.add(sink: self)
    #endif
  }

  deinit {
    #if DEBUG
    Log.shared.remove(sink: self)
    #endif
  }

  /// Classifies `image` and returns one entry per label, sorted by descending
  /// confidence.
  ///
  /// - Parameter image: The image to classify.
  /// - Returns: Labels paired with their softmax confidence.
  /// - Throws: `MobileNetClassifierError` when the image cannot be prepared,
  ///   or any error thrown by the underlying classifier.
  public func classify(image: UIImage) throws -> [Classification] {
    var input = try normalize(rawData(from: transformed(image)))
    var output = [Float](repeating: 0, count: labels.count)

    try input.withUnsafeMutableBufferPointer { inputPointer in
      guard let inputPointerBaseAddress = inputPointer.baseAddress else {
        throw MobileNetClassifierError.inputPointer
      }
      try mobileNetClassifier.classify(
        withInput: inputPointerBaseAddress,
        output: &output,
        outputSize: labels.count)
    }
    return softmax(output).enumerated().sorted(by: { $0.element > $1.element })
      .compactMap { (index, probability) -> Classification? in
        guard index < labels.count else { return nil }
        return Classification(label: labels[index], confidence: probability)
      }
  }

  /// Scales `image` so its shorter side equals `resizeSize`, then takes the
  /// centered `cropSize` x `cropSize` square.
  ///
  /// - Throws: `MobileNetClassifierError.transform` if the image has an empty
  ///   size or the graphics context yields no image.
  private func transformed(_ image: UIImage) throws -> UIImage {
    // A zero width or height would turn the aspect ratio into NaN or infinity
    // and produce a meaningless target rectangle.
    guard image.size.width > 0, image.size.height > 0 else {
      throw MobileNetClassifierError.transform
    }
    let aspectRatio = image.size.width / image.size.height
    let targetSize =
      aspectRatio > 1
      ? CGSize(width: Self.resizeSize * aspectRatio, height: Self.resizeSize)
      : CGSize(width: Self.resizeSize, height: Self.resizeSize / aspectRatio)
    let cropRect = CGRect(
      x: (targetSize.width - Self.cropSize) / 2,
      y: (targetSize.height - Self.cropSize) / 2,
      width: Self.cropSize,
      height: Self.cropSize)

    // Scale 1 keeps the backing bitmap at exactly cropSize x cropSize pixels,
    // which is what the fixed-size pixel buffers expect.
    UIGraphicsBeginImageContextWithOptions(cropRect.size, false, 1)
    defer { UIGraphicsEndImageContext() }
    // Drawing at a negative offset moves the crop window to the image center.
    image.draw(
      in: CGRect(
        x: -cropRect.origin.x,
        y: -cropRect.origin.y,
        width: targetSize.width,
        height: targetSize.height))
    guard let resizedAndCroppedImage = UIGraphicsGetImageFromCurrentImageContext()
    else {
      throw MobileNetClassifierError.transform
    }
    return resizedAndCroppedImage
  }

  /// Renders `image` into the reusable RGBA byte buffer and returns its contents.
  ///
  /// - Throws: `MobileNetClassifierError.rawData` if the image has no
  ///   `CGImage`, its dimensions do not match the buffer, or the bitmap
  ///   context cannot be created.
  private func rawData(from image: UIImage) throws -> [UInt8] {
    guard let cgImage = image.cgImage else {
      throw MobileNetClassifierError.rawData
    }
    let width = cgImage.width
    let height = cgImage.height
    let bytesPerRow = width * 4
    // The buffer has a fixed size; drawing a larger image would write past its
    // end and a smaller one would leave stale pixels from a previous call.
    guard bytesPerRow * height == rawDataBuffer.count else {
      throw MobileNetClassifierError.rawData
    }
    let bitmapInfo =
      CGImageAlphaInfo.premultipliedLast.rawValue | CGBitmapInfo.byteOrder32Big.rawValue

    // The context must only be used while the buffer pointer is valid, so it
    // is created and drawn into entirely inside `withUnsafeMutableBytes`.
    let didDraw = rawDataBuffer.withUnsafeMutableBytes { buffer -> Bool in
      guard
        let baseAddress = buffer.baseAddress,
        let context = CGContext(
          data: baseAddress,
          width: width,
          height: height,
          bitsPerComponent: 8,
          bytesPerRow: bytesPerRow,
          space: CGColorSpaceCreateDeviceRGB(),
          bitmapInfo: bitmapInfo)
      else {
        return false
      }
      context.draw(
        cgImage,
        in: CGRect(
          origin: CGPoint.zero,
          size: CGSize(width: width, height: height)))
      return true
    }
    guard didDraw else {
      throw MobileNetClassifierError.rawData
    }
    return rawDataBuffer
  }

  /// Converts interleaved 4-byte pixels into three consecutive channel planes
  /// of normalized floats.
  ///
  /// Each of the first three bytes of a pixel is scaled to `0...1`, then the
  /// per-channel mean is subtracted and the result divided by the per-channel
  /// standard deviation. The fourth byte of every pixel is ignored.
  private func normalize(_ rawData: [UInt8]) -> [Float] {
    // Per-channel constants; presumably those of the model input spec linked
    // above the class declaration (confirm against that page).
    let mean: [Float] = [0.485, 0.456, 0.406]
    let std: [Float] = [0.229, 0.224, 0.225]
    let pixelCount = rawData.count / 4

    for i in 0..<pixelCount {
      normalizedBuffer[i] = (Float(rawData[i * 4 + 0]) / 255 - mean[0]) / std[0]
      normalizedBuffer[i + pixelCount] = (Float(rawData[i * 4 + 1]) / 255 - mean[1]) / std[1]
      normalizedBuffer[i + pixelCount * 2] = (Float(rawData[i * 4 + 2]) / 255 - mean[2]) / std[2]
    }
    return normalizedBuffer
  }

  /// Returns the softmax of `input`; an empty input yields an empty result.
  private func softmax(_ input: [Float]) -> [Float] {
    // Shifting by the maximum keeps every exponent at or below zero.
    let maxInput = input.max() ?? 0
    let expInput = input.map { exp($0 - maxInput) }
    let sumExpInput = expInput.reduce(0, +)
    return expInput.map { $0 / sumExpInput }
  }
}

#if DEBUG
extension MobileNetClassifier: LogSink {
  /// Forwards a log entry to `os_log`, mapping the level to the matching log
  /// type and falling back to the default type for any other level.
  public func log(level: LogLevel, timestamp: TimeInterval, filename: String, line: UInt, message: String) {
    let logMessage = "executorch:\(filename):\(line) \(message)"

    switch level {
    case .debug:
      os_log(.debug, "%{public}@", logMessage)
    case .info:
      os_log(.info, "%{public}@", logMessage)
    case .error:
      os_log(.error, "%{public}@", logMessage)
    case .fatal:
      os_log(.fault, "%{public}@", logMessage)
    default:
      os_log("%{public}@", logMessage)
    }
  }
}
#endif