#import <torch/csrc/jit/backends/backend.h>
#import <torch/csrc/jit/backends/coreml/cpp/context.h>
#import <torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.h>
#import <torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.h>
#import <torch/csrc/jit/backends/coreml/objc/PTMCoreMLModelWrapper.h>
#import <torch/csrc/jit/backends/coreml/objc/PTMCoreMLTensorSpec.h>
#import <torch/script.h>
#import <fmt/format.h>

#import <CoreML/CoreML.h>

#if C10_IOS
#import <UIKit/UIKit.h>
#elif TARGET_OS_MAC
#import <Foundation/NSProcessInfo.h>
#endif

// This is a utility macro that can be used to throw an exception when a CoreML
// API function produces a NSError. The exception will contain a message with
// useful info extracted from the NSError.
#define COREML_THROW_IF_ERROR(error, preamble, ...)                              \
  do {                                                                           \
    if C10_LIKELY(error) {                                                       \
      throw c10::Error(                                                          \
          {__func__, __FILE__, static_cast<uint32_t>(__LINE__)},                 \
          c10::str(                                                              \
              preamble,                                                          \
              " Error details: ",                                                \
              " Localized_description: ", error.localizedDescription.UTF8String, \
              " Domain: ", error.domain.UTF8String,                              \
              " Code: ", error.code,                                             \
              " User Info: ", error.userInfo.description.UTF8String,             \
              ##__VA_ARGS__));                                                   \
    }                                                                            \
  } while (false)

namespace torch {
namespace jit {
namespace mobile {
namespace coreml {

using c10::impl::GenericDict;
using c10::impl::GenericList;
using c10::IValue;

struct CoreMLConfig {
  std::string backend = "CPU";
  bool allow_low_precision = true;
};

// Renders the shapes of a list of tensors as a string, e.g. "[[1, 3], [2]]",
// for use in error messages.
std::string tensorListToShapesStr(GenericList tensors) {
  std::string str("[");
  for (const auto featureIdx : c10::irange(tensors.size())) {
    if (featureIdx > 0) {
      str = fmt::format("{}, ", str);
    }
    str = fmt::format("{}[", str);
    auto shape = tensors.get(featureIdx).toTensor().sizes();
    for (const auto shapeIdx : c10::irange(shape.size())) {
      if (shapeIdx > 0) {
        str = fmt::format("{}, ", str);
      }
      str = fmt::format("{}{}", str, shape[shapeIdx]);
    }
    str = fmt::format("{}]", str);
  }
  str = fmt::format("{}]", str);
  return str;
}

// The Core ML backend only supports Float32 tensor specs.
bool type_validity(const std::vector<TensorSpec>& specs) {
  for (const TensorSpec& spec : specs) {
    if (spec.dtype != c10::ScalarType::Float) {
      return false;
    }
  }
  return true;
}

void from_json(const nlohmann::json& j, TensorSpec& spec) {
  j[0].get_to(spec.name);
  std::string type_string;
  j[1].get_to(type_string);
  spec.dtype = scalar_type(type_string);
}

void from_json(const nlohmann::json& j, CoreMLConfig& config) {
  j.at("backend").get_to(config.backend);
  std::string allow_low_precision_string;
  j.at("allow_low_precision").get_to(allow_low_precision_string);
  if (allow_low_precision_string == "True") {
    config.allow_low_precision = true;
  } else {
    config.allow_low_precision = false;
  }
}

// Copies each Core ML output feature into an owning at::Tensor and packs the
// results into a GenericList.
GenericList pack_outputs(const std::vector<TensorSpec>& output_specs, id<MLFeatureProvider> outputProvider) {
  c10::List<torch::Tensor> outputs;
  for (const TensorSpec& spec : output_specs) {
    NSString *name = [NSString stringWithUTF8String:spec.name.c_str()];
    MLFeatureValue *val = [outputProvider featureValueForName:name];
    std::vector<int64_t> output_shape;
    for (int i = 0; i < val.multiArrayValue.shape.count; ++i) {
      output_shape.emplace_back(val.multiArrayValue.shape[i].integerValue);
    }
    TORCH_CHECK(val.multiArrayValue.dataType == MLMultiArrayDataTypeFloat32, "Core ML backend unexpected output data type");
    int64_t count = val.multiArrayValue.count;
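    // Copy the Core ML output buffer into freshly allocated memory so the
    // returned tensor owns its data; the from_blob deleter below frees it.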
    float* temp = static_cast<float*>(std::malloc(count * sizeof(float)));
    if (@available(iOS 15.4, *)) {
      [val.multiArrayValue getBytesWithHandler:^(const void * _Nonnull bytes, NSInteger size) {
        memcpy(temp, (float *)bytes, count * sizeof(float));
      }];
    } else {
      memcpy(temp, (float *)val.multiArrayValue.dataPointer, count * sizeof(float));
    }
    auto tensor = at::from_blob(temp, output_shape, [&](void* ptr) { std::free(ptr); }, TensorOptions().dtype(at::kFloat));
    outputs.push_back(std::move(tensor));
  }
  if (output_specs.size() > 1) {
    c10::List<c10::List<torch::Tensor>> output_res;
    output_res.push_back(std::move(outputs));
    return c10::impl::toList(std::move(output_res));
  }
  return c10::impl::toList(std::move(outputs));
}

class CoreMLBackend: public torch::jit::PyTorchBackendInterface {

 public:
  GenericDict compile(IValue processed, GenericDict method_compile_spec) override {
    const c10::Dict<IValue, IValue> model_dict = processed.toGenericDict();
    const std::string& extra = model_dict.at("extra").toStringRef();
    const std::string& model = model_dict.at("model").toStringRef();
    const std::string modelID = std::string(model_dict.at("hash").toStringRef());

    CoreMLConfig config;
    std::vector<TensorSpec> input_specs;
    std::vector<TensorSpec> output_specs;

    try {
      nlohmann::json extra_json = nlohmann::json::parse(extra);
      config = extra_json["config"].get<CoreMLConfig>();
      input_specs = extra_json["inputs"].get<std::vector<TensorSpec>>();
      output_specs = extra_json["outputs"].get<std::vector<TensorSpec>>();
    } catch (std::exception& exn) {
      TORCH_CHECK(false, "Parsing model dict failed!");
    }

    if (!type_validity(input_specs) || !type_validity(output_specs)) {
      TORCH_CHECK(false, "Compiling model failed, only float type tensors supported");
    }

    if (![PTMCoreMLCompiler compileModel:model modelID:modelID]) {
      TORCH_CHECK(false, "Compiling MLModel failed");
    }

    NSError *error = nil;
    MLModel *cpuModel = [PTMCoreMLCompiler loadModel:modelID backend:"cpu" allowLowPrecision:NO error:&error];

    if (!cpuModel) {
      COREML_THROW_IF_ERROR(error, "Error loading MLModel", " Model spec: ", extra.c_str(), ", Model Hash: ", modelID.c_str());
    }

    NSMutableArray *orderedFeatures = [NSMutableArray array];
    for (TensorSpec& spec : input_specs) {
      NSString *name = [NSString stringWithUTF8String:spec.name.c_str()];
      [orderedFeatures addObject:name];
    }

    PTMCoreMLExecutor *executor = [[PTMCoreMLExecutor alloc] initWithFeatureNames:orderedFeatures];
    executor.model = cpuModel;
    [executor autorelease];

    // Asynchronously load the model with the requested backend and precision;
    // inference runs on the CPU model until the configured one is ready.
    dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
      NSError *error = nil;
      MLModel *configuredModel = [PTMCoreMLCompiler loadModel:modelID backend:config.backend allowLowPrecision:config.allow_low_precision error:&error];
      // If we fail to configure the model, fall back to CPU
      executor.model = configuredModel ?: cpuModel;
    });

    MLModelWrapper model_wrapper = MLModelWrapper(executor);
    model_wrapper.outputs = output_specs;

    auto model_wrapper_ptr = c10::make_intrusive<MLModelWrapper>(model_wrapper);
    auto handle = IValue::make_capsule(model_wrapper_ptr);

    c10::Dict<IValue, IValue> ret(StringType::get(), c10::AnyType::get());
    ret.insert("forward", handle);
    return c10::impl::toGenericDict(ret);
  }

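  // Runs the model through the wrapped executor and repackages the Core ML
  // outputs as PyTorch tensors.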
  GenericList execute(IValue handle, GenericList inputs) override {
    @autoreleasepool {
      const auto model_wrapper = c10::static_intrusive_pointer_cast<MLModelWrapper>(handle.toCapsule());

      PTMCoreMLExecutor *executor = model_wrapper->executor;
      [executor setInputs:inputs];

      NSError *error = nil;
      id<MLFeatureProvider> outputsProvider = [executor forward:&error];
      if (!outputsProvider) {
        COREML_THROW_IF_ERROR(error, "Error running CoreML inference", " Input Shape:", tensorListToShapesStr(inputs));
      }

      return pack_outputs(model_wrapper->outputs, outputsProvider);
    }
  }

  bool is_available() override {
#if TARGET_OS_IPHONE
    return [UIDevice currentDevice].systemVersion.floatValue >= 12.0;
#elif TARGET_OS_MAC
    NSOperatingSystemVersion supportedVer = {10, 13, 0};
    return [[NSProcessInfo processInfo] isOperatingSystemAtLeastVersion:supportedVer];
#endif
    return false;
  }
};

// Register the Core ML delegate under the name "coreml".
static auto cls = torch::jit::backend<CoreMLBackend>("coreml");

struct PTMCoreMLContext : public ContextInterface {
  void setModelCacheDirectory(std::string dir) override {
    [PTMCoreMLCompiler setCacheDirectory:dir];
  }
};

static BackendRegistrar g_coreml_backend(new PTMCoreMLContext());

} // namespace coreml
} // namespace mobile
} // namespace jit
} // namespace torch