/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

import ExecuTorch
import ImageClassification
import UIKit

import os.log

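/// Errors that can occur while preparing an image for classification.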
public enum MobileNetClassifierError: Error {
  case inputPointer
  case rawData
  case transform

  var localizedDescription: String {
    switch self {
    case .inputPointer:
      return "Cannot get the input pointer base address"
    case .rawData:
      return "Cannot get the pixel data from the image"
    case .transform:
      return "Cannot transform the image"
    }
  }
}

// See https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v3_small.html
// for the model input/output spec.
public class MobileNetClassifier: ImageClassification {
  private static let resizeSize: CGFloat = 256
  private static let cropSize: CGFloat = 224

  private var mobileNetClassifier: ETMobileNetClassifier
  private var labels: [String] = []
  private var rawDataBuffer: [UInt8]
  private var normalizedBuffer: [Float]

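  /// Creates a classifier backed by an exported MobileNet v3 model.
  ///
  /// Loads the class labels (one per line) from `labelsFilePath`, wraps the
  /// model at `modelFilePath` in `ETMobileNetClassifier`, and preallocates the
  /// RGBA and normalized-float buffers for a `cropSize` x `cropSize` input.
  /// In debug builds it also registers itself as an ExecuTorch log sink.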
  public init?(modelFilePath: String, labelsFilePath: String) throws {
    labels = try String(contentsOfFile: labelsFilePath, encoding: .utf8)
      .components(separatedBy: .newlines)
    mobileNetClassifier = ETMobileNetClassifier(filePath: modelFilePath)
    rawDataBuffer = [UInt8](repeating: 0, count: Int(Self.cropSize * Self.cropSize) * 4)
    normalizedBuffer = [Float](repeating: 0, count: rawDataBuffer.count / 4 * 3)

    #if DEBUG
    Log.shared.add(sink: self)
    #endif
  }

  deinit {
    #if DEBUG
    Log.shared.remove(sink: self)
    #endif
  }

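  /// Classifies an image with the MobileNet v3 model.
  ///
  /// The image is resized and center-cropped to 224x224, converted to raw
  /// RGBA bytes, and normalized into a planar float buffer before being fed
  /// to the model. The raw scores are passed through softmax and returned as
  /// classifications sorted by descending confidence.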
  public func classify(image: UIImage) throws -> [Classification] {
    var input = try normalize(rawData(from: transformed(image)))
    var output = [Float](repeating: 0, count: labels.count)

    try input.withUnsafeMutableBufferPointer { inputPointer in
      guard let inputPointerBaseAddress = inputPointer.baseAddress else {
        throw MobileNetClassifierError.inputPointer
      }
      try mobileNetClassifier.classify(
        withInput: inputPointerBaseAddress,
        output: &output,
        outputSize: labels.count)
    }
    return softmax(output).enumerated().sorted(by: { $0.element > $1.element })
      .compactMap { (index, probability) -> Classification? in
        guard index < labels.count else { return nil }
        return Classification(label: labels[index], confidence: probability)
      }
  }

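  /// Resizes and center-crops an image to the model's expected input size.
  ///
  /// The image is scaled so that its shorter side equals `resizeSize` (256)
  /// while preserving the aspect ratio, then a `cropSize` (224) square is
  /// cropped from the center, matching the torchvision preprocessing for
  /// MobileNet v3.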
  private func transformed(_ image: UIImage) throws -> UIImage {
    let aspectRatio = image.size.width / image.size.height
    let targetSize =
      aspectRatio > 1
      ? CGSize(width: Self.resizeSize * aspectRatio, height: Self.resizeSize)
      : CGSize(width: Self.resizeSize, height: Self.resizeSize / aspectRatio)
    let cropRect = CGRect(
      x: (targetSize.width - Self.cropSize) / 2,
      y: (targetSize.height - Self.cropSize) / 2,
      width: Self.cropSize,
      height: Self.cropSize)

    UIGraphicsBeginImageContextWithOptions(cropRect.size, false, 1)
    defer { UIGraphicsEndImageContext() }
    image.draw(
      in: CGRect(
        x: -cropRect.origin.x,
        y: -cropRect.origin.y,
        width: targetSize.width,
        height: targetSize.height))
    guard let resizedAndCroppedImage = UIGraphicsGetImageFromCurrentImageContext()
    else {
      throw MobileNetClassifierError.transform
    }
    return resizedAndCroppedImage
  }

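  /// Extracts the raw RGBA pixel data from an image.
  ///
  /// The image is drawn into the preallocated `rawDataBuffer` as 8-bit RGBA
  /// (4 bytes per pixel); the buffer is sized for a `cropSize` x `cropSize`
  /// image, so the input is expected to come from `transformed(_:)`.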
  private func rawData(from image: UIImage) throws -> [UInt8] {
    guard let cgImage = image.cgImage else {
      throw MobileNetClassifierError.rawData
    }
    // Keep the buffer pointer valid for both context creation and drawing.
    rawDataBuffer.withUnsafeMutableBytes { buffer in
      let context = CGContext(
        data: buffer.baseAddress,
        width: cgImage.width,
        height: cgImage.height,
        bitsPerComponent: 8,
        bytesPerRow: cgImage.width * 4,
        space: CGColorSpaceCreateDeviceRGB(),
        bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue | CGBitmapInfo.byteOrder32Big.rawValue
      )
      context?.draw(
        cgImage,
        in: CGRect(
          origin: CGPoint.zero,
          size: CGSize(width: cgImage.width, height: cgImage.height)))
    }
    return rawDataBuffer
  }

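  /// Converts interleaved RGBA bytes into a planar, normalized float buffer.
  ///
  /// Pixel values are scaled to [0, 1] and normalized per channel with the
  /// ImageNet mean and standard deviation; the alpha channel is dropped and
  /// the output is laid out channels-first (all R values, then G, then B).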
  private func normalize(_ rawData: [UInt8]) -> [Float] {
    let mean: [Float] = [0.485, 0.456, 0.406]
    let std: [Float] = [0.229, 0.224, 0.225]
    let pixelCount = rawData.count / 4

    for i in 0..<pixelCount {
      normalizedBuffer[i] = (Float(rawData[i * 4 + 0]) / 255 - mean[0]) / std[0]
      normalizedBuffer[i + pixelCount] = (Float(rawData[i * 4 + 1]) / 255 - mean[1]) / std[1]
      normalizedBuffer[i + pixelCount * 2] = (Float(rawData[i * 4 + 2]) / 255 - mean[2]) / std[2]
    }
    return normalizedBuffer
  }

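  /// Applies a numerically stable softmax to the model's raw scores.
  ///
  /// The maximum score is subtracted before exponentiation to avoid overflow;
  /// the result is a probability distribution that sums to 1.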
  private func softmax(_ input: [Float]) -> [Float] {
    let maxInput = input.max() ?? 0
    let expInput = input.map { exp($0 - maxInput) }
    let sumExpInput = expInput.reduce(0, +)
    return expInput.map { $0 / sumExpInput }
  }
}

#if DEBUG
extension MobileNetClassifier: LogSink {
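  /// Forwards an ExecuTorch runtime log message to the unified logging system
  /// at a matching `os_log` level.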
  public func log(level: LogLevel, timestamp: TimeInterval, filename: String, line: UInt, message: String) {
    let logMessage = "executorch:\(filename):\(line) \(message)"

    switch level {
    case .debug:
      os_log(.debug, "%{public}@", logMessage)
    case .info:
      os_log(.info, "%{public}@", logMessage)
    case .error:
      os_log(.error, "%{public}@", logMessage)
    case .fatal:
      os_log(.fault, "%{public}@", logMessage)
    default:
      os_log("%{public}@", logMessage)
    }
  }
}
#endif