#import <torch/csrc/jit/backends/backend.h>
#import <torch/csrc/jit/backends/coreml/cpp/context.h>
#import <torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.h>
#import <torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.h>
#import <torch/csrc/jit/backends/coreml/objc/PTMCoreMLModelWrapper.h>
#import <torch/csrc/jit/backends/coreml/objc/PTMCoreMLTensorSpec.h>
#import <torch/script.h>
#import <fmt/format.h>

#import <CoreML/CoreML.h>

#if C10_IOS
#import <UIKit/UIKit.h>
#elif TARGET_OS_MAC
#import <Foundation/NSProcessInfo.h>
#endif

// Utility macro that throws an exception when a Core ML API call produces an
// NSError. The exception message includes useful details extracted from the
// NSError.
#define COREML_THROW_IF_ERROR(error, preamble, ...)     \
  do {                                                                           \
    if C10_LIKELY(error) {                                                       \
      throw c10::Error(                                                          \
          {__func__, __FILE__, static_cast<uint32_t>(__LINE__)},                 \
          c10::str(                                                              \
              preamble,                                                          \
              " Error details: ",                                                \
              " Localized_description: ", error.localizedDescription.UTF8String, \
              " Domain: ", error.domain.UTF8String,                              \
              " Code: ", error.code,                                             \
              " User Info: ", error.userInfo.description.UTF8String,             \
              ##__VA_ARGS__));                                                   \
    }                                                                            \
  } while (false)

namespace torch {
namespace jit {
namespace mobile {
namespace coreml {

using c10::impl::GenericDict;
using c10::impl::GenericList;
using c10::IValue;

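// Execution options parsed from the model's "extra" JSON blob: the compute
// backend string requested from PTMCoreMLCompiler and whether low-precision
// execution is allowed. Defaults to the CPU backend with low precision allowed.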
struct CoreMLConfig {
  std::string backend = "CPU";
  bool allow_low_precision = true;
};

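// Renders the shapes of a list of tensors as a string such as
// "[[1, 3, 224, 224], [1, 1000]]"; used to enrich inference error messages.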
std::string tensorListToShapesStr(GenericList tensors) {
  std::string str("[");
  for (const auto featureIdx : c10::irange(tensors.size())) {
    if (featureIdx > 0) {
      str = fmt::format("{}, ", str);
    }
    str = fmt::format("{}[", str);
    auto shape = tensors.get(featureIdx).toTensor().sizes();
    for (const auto shapeIdx : c10::irange(shape.size())) {
      if (shapeIdx > 0) {
        str = fmt::format("{}, ", str);
      }
      str = fmt::format("{}{}", str, shape[shapeIdx]);
    }
    str = fmt::format("{}]", str);
  }
  str = fmt::format("{}]", str);
  return str;
}

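// The Core ML backend only supports float32 tensors; any other dtype in the
// input/output specs is rejected at compile time.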
bool type_validity(const std::vector<TensorSpec>& specs) {
  for (const TensorSpec& spec : specs) {
    if (spec.dtype != c10::ScalarType::Float) {
      return false;
    }
  }
  return true;
}

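// nlohmann::json deserializer for TensorSpec. Each spec is serialized as a
// [name, dtype-string] pair in the "inputs"/"outputs" arrays of the extra JSON.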
void from_json(const nlohmann::json& j, TensorSpec& spec) {
  j[0].get_to(spec.name);
  std::string type_string;
  j[1].get_to(type_string);
  spec.dtype = scalar_type(type_string);
}

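// nlohmann::json deserializer for CoreMLConfig. Note that "allow_low_precision"
// arrives as the Python-style string "True"/"False" rather than a JSON bool.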
void from_json(const nlohmann::json& j, CoreMLConfig& config) {
  j.at("backend").get_to(config.backend);
  std::string allow_low_precision_string;
  j.at("allow_low_precision").get_to(allow_low_precision_string);
  if (allow_low_precision_string == "True") {
    config.allow_low_precision = true;
  } else {
    config.allow_low_precision = false;
  }
}

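// Copies each Core ML output MLMultiArray into a freshly malloc'd float buffer
// and wraps it in an at::Tensor that frees the buffer when it goes away. On
// iOS 15.4+ the block-based getBytesWithHandler: accessor is used instead of
// the dataPointer property. A single output is returned as a list of tensors;
// multiple outputs get an extra level of list nesting.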
GenericList pack_outputs(const std::vector<TensorSpec>& output_specs, id<MLFeatureProvider> outputProvider) {
  c10::List<torch::Tensor> outputs;
  for (const TensorSpec& spec : output_specs) {
    NSString *name = [NSString stringWithUTF8String:spec.name.c_str()];
    MLFeatureValue *val = [outputProvider featureValueForName:name];
    std::vector<int64_t> output_shape;
    for (int i = 0; i < val.multiArrayValue.shape.count; ++i) {
      output_shape.emplace_back(val.multiArrayValue.shape[i].integerValue);
    }
    TORCH_CHECK(val.multiArrayValue.dataType == MLMultiArrayDataTypeFloat32, "Core ML backend unexpected output data type");
    int64_t count = val.multiArrayValue.count;
    float* temp = static_cast<float*>(std::malloc(count * sizeof(float)));
    if (@available(iOS 15.4, *)) {
      [val.multiArrayValue getBytesWithHandler:^(const void * _Nonnull bytes, NSInteger size) {
        memcpy(temp, (float *)bytes, count * sizeof(float));
      }];
    } else {
      memcpy(temp, (float *)val.multiArrayValue.dataPointer, count * sizeof(float));
    }
    auto tensor = at::from_blob(temp, output_shape, [&](void* ptr) { std::free(ptr); }, TensorOptions().dtype(at::kFloat));
    outputs.push_back(std::move(tensor));
  }
  if (output_specs.size() > 1) {
    c10::List<c10::List<torch::Tensor>> output_res;
    output_res.push_back(std::move(outputs));
    return c10::impl::toList(std::move(output_res));
  }
  return c10::impl::toList(std::move(outputs));
}

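// PyTorchBackendInterface implementation that delegates lowered modules to
// Core ML: compile() turns the preprocessed payload into a live MLModel wrapped
// in a PTMCoreMLExecutor, and execute() runs forward passes through it.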
class CoreMLBackend: public torch::jit::PyTorchBackendInterface {

 public:
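  // compile() is invoked when the lowered module is loaded. `processed` holds
  // the payload produced at lowering time: the serialized Core ML model
  // ("model"), a JSON blob with the config and input/output specs ("extra"),
  // and a hash used to identify the compiled model ("hash"). It compiles the
  // model, loads a CPU-only copy synchronously, and returns a dict mapping the
  // method name "forward" to an opaque handle around the executor.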
  GenericDict compile(IValue processed, GenericDict method_compile_spec) override {
    const c10::Dict<IValue, IValue> model_dict = processed.toGenericDict();
    const std::string& extra = model_dict.at("extra").toStringRef();
    const std::string& model = model_dict.at("model").toStringRef();
    const std::string modelID = std::string(model_dict.at("hash").toStringRef());

    CoreMLConfig config;
    std::vector<TensorSpec> input_specs;
    std::vector<TensorSpec> output_specs;

    try {
      nlohmann::json extra_json = nlohmann::json::parse(extra);
      config = extra_json["config"].get<CoreMLConfig>();
      input_specs = extra_json["inputs"].get<std::vector<TensorSpec>>();
      output_specs = extra_json["outputs"].get<std::vector<TensorSpec>>();
    } catch (std::exception& exn) {
      TORCH_CHECK(false, "Parsing model dict failed!");
    }

    if (!type_validity(input_specs) || !type_validity(output_specs)) {
      TORCH_CHECK(false, "Compiling model failed, only float type tensors supported");
    }

    if (![PTMCoreMLCompiler compileModel:model modelID:modelID]) {
      TORCH_CHECK(false, "Compiling MLModel failed");
    }

    NSError *error = nil;
    MLModel *cpuModel = [PTMCoreMLCompiler loadModel:modelID backend:"cpu" allowLowPrecision:NO error:&error];

    if (!cpuModel) {
      COREML_THROW_IF_ERROR(error, "Error loading MLModel", " Model spec: ", extra.c_str(), ", Model Hash: ", modelID.c_str());
    }

    NSMutableArray *orderedFeatures = [NSMutableArray array];
    for (TensorSpec& spec : input_specs) {
      NSString *name = [NSString stringWithUTF8String:spec.name.c_str()];
      [orderedFeatures addObject:name];
    }

    PTMCoreMLExecutor *executor = [[PTMCoreMLExecutor alloc] initWithFeatureNames:orderedFeatures];
    executor.model = cpuModel;
    [executor autorelease];

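    // The CPU-only model loaded above serves requests immediately; the variant
    // requested by the user config (backend / low precision) is loaded on a
    // background queue and swapped in once it is ready.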
    dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
      NSError *error = nil;
      MLModel *configuredModel = [PTMCoreMLCompiler loadModel:modelID backend:config.backend allowLowPrecision:config.allow_low_precision error:&error];
      // If we fail to configure the model, fall back to CPU
      executor.model = configuredModel ?: cpuModel;
    });

    MLModelWrapper model_wrapper = MLModelWrapper(executor);
    model_wrapper.outputs = output_specs;

    auto model_wrapper_ptr = c10::make_intrusive<MLModelWrapper>(model_wrapper);
    auto handle = IValue::make_capsule(model_wrapper_ptr);

    c10::Dict<IValue, IValue> ret(StringType::get(), c10::AnyType::get());
    ret.insert("forward", handle);
    return c10::impl::toGenericDict(ret);
  }

  GenericList execute(IValue handle, GenericList inputs) override {
    @autoreleasepool {
      const auto model_wrapper = c10::static_intrusive_pointer_cast<MLModelWrapper>(handle.toCapsule());

      PTMCoreMLExecutor *executor = model_wrapper->executor;
      [executor setInputs:inputs];

      NSError *error = nil;
      id<MLFeatureProvider> outputsProvider = [executor forward:&error];
      if (!outputsProvider) {
        COREML_THROW_IF_ERROR(error, "Error running CoreML inference", " Input Shape:", tensorListToShapesStr(inputs));
      }

      return pack_outputs(model_wrapper->outputs, outputsProvider);
    }
  }

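  // Core ML delegation requires at least iOS 12 on device or macOS 10.13 on Mac.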
  bool is_available() override {
#if TARGET_OS_IPHONE
    return [UIDevice currentDevice].systemVersion.floatValue >= 12.0;
#elif TARGET_OS_MAC
    NSOperatingSystemVersion supportedVer = {10, 13, 0};
    return [[NSProcessInfo processInfo] isOperatingSystemAtLeastVersion:supportedVer];
#endif
    return false;
  }
};

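// Register this backend under the name "coreml" so modules lowered to the
// "coreml" backend can resolve it at runtime.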
static auto cls = torch::jit::backend<CoreMLBackend>("coreml");

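// Bridges the platform-independent ContextInterface to the Objective-C
// compiler so callers can control where compiled models are cached on disk.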
struct PTMCoreMLContext : public ContextInterface {
  void setModelCacheDirectory(std::string dir) override {
    [PTMCoreMLCompiler setCacheDirectory:dir];
  }
};

static BackendRegistrar g_coreml_backend(new PTMCoreMLContext());

} // namespace coreml
} // namespace mobile
} // namespace jit
} // namespace torch