//
// coreml_backend_delegate.mm
//
// Copyright © 2024 Apple Inc. All rights reserved.
//
// Please refer to the license found in the LICENSE file in the root directory of the source tree.

#import <ETCoreMLLogging.h>
#import <ETCoreMLModel.h>
#import <ETCoreMLStrings.h>
#import <backend_delegate.h>
#import <coreml_backend/delegate.h>
#import <executorch/runtime/core/evalue.h>
#import <executorch/runtime/platform/log.h>
#import <memory>
#import <model_event_logger.h>
#import <model_logging_options.h>
#import <multiarray.h>
#import <objc_safe_cast.h>
#import <unordered_map>
#import <vector>

#ifdef ET_EVENT_TRACER_ENABLED
#import <model_event_logger_impl.h>
#endif

namespace {
using namespace executorchcoreml;

using executorch::aten::ScalarType;
using executorch::runtime::ArrayRef;
using executorch::runtime::Backend;
using executorch::runtime::BackendExecutionContext;
using executorch::runtime::BackendInitContext;
using executorch::runtime::CompileSpec;
using executorch::runtime::DelegateHandle;
using executorch::runtime::EValue;
using executorch::runtime::Error;
using executorch::runtime::EventTracerDebugLogLevel;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::get_backend_class;
using executorch::runtime::Result;

// Maps an ExecuTorch scalar type to the equivalent `MultiArray` data type,
// or returns `std::nullopt` if the type is not supported by the delegate.
std::optional<MultiArray::DataType> get_data_type(ScalarType scalar_type) {
    switch (scalar_type) {
        case ScalarType::Bool:
            return MultiArray::DataType::Bool;
        case ScalarType::Byte:
            return MultiArray::DataType::Byte;
        case ScalarType::Short:
            return MultiArray::DataType::Short;
        case ScalarType::Int:
            return MultiArray::DataType::Int32;
        case ScalarType::Long:
            return MultiArray::DataType::Int64;
        case ScalarType::Half:
            return MultiArray::DataType::Float16;
        case ScalarType::Float:
            return MultiArray::DataType::Float32;
        case ScalarType::Double:
            return MultiArray::DataType::Float64;
        default:
            return std::nullopt;
    }
}

enum class ArgType: uint8_t {
    Input,
    Output
};

// Wraps a tensor `EValue` in a `MultiArray` view without copying data.
// Inputs use the tensor's const data pointer; outputs use its mutable pointer.
std::optional<MultiArray> get_multi_array(EValue *eValue, ArgType argType) {
    if (!eValue->isTensor()) {
        return std::nullopt;
    }

    auto tensor = eValue->toTensor();
    auto dataType = get_data_type(tensor.scalar_type());
    if (!dataType.has_value()) {
        ET_LOG(Error, "%s: DataType=%d is not supported", ETCoreMLStrings.delegateIdentifier.UTF8String, (int)tensor.scalar_type());
        return std::nullopt;
    }

    std::vector<ssize_t> strides(tensor.strides().begin(), tensor.strides().end());
    std::vector<size_t> shape(tensor.sizes().begin(), tensor.sizes().end());
    MultiArray::MemoryLayout layout(dataType.value(), std::move(shape), std::move(strides));
    switch (argType) {
        case ArgType::Input: {
            return MultiArray(const_cast<void *>(tensor.const_data_ptr()), layout);
        }
        case ArgType::Output: {
            return MultiArray(tensor.mutable_data_ptr(), layout);
        }
    }
}

// Reads the delegate configuration from the plist at `plistURL`, returning
// `std::nullopt` if the plist cannot be loaded. Missing keys keep their defaults.
std::optional<BackendDelegate::Config> parse_config(NSURL *plistURL) {
    NSDictionary<NSString *, id> *dict = [NSDictionary dictionaryWithContentsOfURL:plistURL];
    if (!dict) {
        return std::nullopt;
    }

    BackendDelegate::Config config;
    {
        NSNumber *should_prewarm_model = SAFE_CAST(dict[@"shouldPrewarmModel"], NSNumber);
        if (should_prewarm_model) {
            config.should_prewarm_model = static_cast<bool>(should_prewarm_model.boolValue);
        }
    }

    {
        NSNumber *should_prewarm_asset = SAFE_CAST(dict[@"shouldPrewarmAsset"], NSNumber);
        if (should_prewarm_asset) {
            config.should_prewarm_asset = static_cast<bool>(should_prewarm_asset.boolValue);
        }
    }

    {
        NSNumber *max_models_cache_size_in_bytes = SAFE_CAST(dict[@"maxModelsCacheSizeInBytes"], NSNumber);
        if (max_models_cache_size_in_bytes) {
            config.max_models_cache_size = max_models_cache_size_in_bytes.unsignedLongLongValue;
        }
    }

    return config;
}

// Looks for the named config plist in the main bundle first, then in the bundle
// that contains `ETCoreMLModel`; falls back to the default configuration.
BackendDelegate::Config get_delegate_config(NSString *config_name) {
    NSURL *config_url = [NSBundle.mainBundle URLForResource:config_name withExtension:@"plist"];
    config_url = config_url ?: [[NSBundle bundleForClass:ETCoreMLModel.class] URLForResource:config_name withExtension:@"plist"];
    auto config = parse_config(config_url);
    return config.has_value() ? config.value() : BackendDelegate::Config();
}

// Enables profiling, and intermediate-tensor logging when the debug level allows it,
// whenever an event tracer is attached to the execution context.
ModelLoggingOptions get_logging_options(BackendExecutionContext& context) {
    ModelLoggingOptions options;
    auto event_tracer = context.event_tracer();
    if (event_tracer) {
        options.log_profiling_info = true;
        auto debug_level = event_tracer->event_tracer_debug_level();
        options.log_intermediate_tensors = (debug_level >= EventTracerDebugLogLevel::kIntermediateOutputs);
    }

    return options;
}

} // namespace

namespace executorch {
namespace backends {
namespace coreml {

using namespace executorchcoreml;

CoreMLBackendDelegate::CoreMLBackendDelegate() noexcept
    : impl_(BackendDelegate::make(get_delegate_config(ETCoreMLStrings.configPlistName))) {}

Result<DelegateHandle *>
CoreMLBackendDelegate::init(BackendInitContext& context,
                            FreeableBuffer* processed,
                            ArrayRef<CompileSpec> specs) const {
    ET_LOG(Debug, "%s: init called.", ETCoreMLStrings.delegateIdentifier.UTF8String);
    std::unordered_map<std::string, Buffer> specs_map;
    specs_map.reserve(specs.size());
    for (auto it = specs.cbegin(); it != specs.cend(); ++it) {
        auto& spec = *(it);
        auto buffer = Buffer(spec.value.buffer, spec.value.nbytes);
        specs_map.emplace(spec.key, std::move(buffer));
    }

    auto buffer = Buffer(processed->data(), processed->size());
    auto handle = impl_->init(std::move(buffer), specs_map);
    ET_CHECK_OR_RETURN_ERROR(handle != nullptr,
                             InvalidProgram,
                             "%s: Failed to init the model.", ETCoreMLStrings.delegateIdentifier.UTF8String);
    processed->Free();
    return handle;
}

Error CoreMLBackendDelegate::execute(BackendExecutionContext& context,
                                     DelegateHandle* handle,
                                     EValue** args) const {
    const auto& nArgs = impl_->get_num_arguments(handle);
    std::vector<MultiArray> delegate_args;
    size_t nInputs = nArgs.first;
    size_t nOutputs = nArgs.second;
    delegate_args.reserve(nInputs + nOutputs);

    // inputs
    for (size_t i = 0; i < nInputs; i++) {
        auto multi_array = get_multi_array(args[i], ArgType::Input);
        ET_CHECK_OR_RETURN_ERROR(multi_array.has_value(),
                                 Internal,
                                 "%s: Failed to create multiarray from input at args[%zu]", ETCoreMLStrings.delegateIdentifier.UTF8String, i);
        delegate_args.emplace_back(std::move(multi_array.value()));
    }

    // outputs
    for (size_t i = nInputs; i < nInputs + nOutputs; i++) {
        auto multi_array = get_multi_array(args[i], ArgType::Output);
        ET_CHECK_OR_RETURN_ERROR(multi_array.has_value(),
                                 Internal,
                                 "%s: Failed to create multiarray from output at args[%zu]", ETCoreMLStrings.delegateIdentifier.UTF8String, i);
        delegate_args.emplace_back(std::move(multi_array.value()));
    }

    auto logging_options = get_logging_options(context);
    std::error_code ec;
#ifdef ET_EVENT_TRACER_ENABLED
    auto event_logger = ModelEventLoggerImpl(context.event_tracer());
    ET_CHECK_OR_RETURN_ERROR(impl_->execute(handle, delegate_args, logging_options, &event_logger, ec),
                             DelegateInvalidHandle,
                             "%s: Failed to run the model.",
                             ETCoreMLStrings.delegateIdentifier.UTF8String);
#else
    ET_CHECK_OR_RETURN_ERROR(impl_->execute(handle, delegate_args, logging_options, nullptr, ec),
                             DelegateInvalidHandle,
                             "%s: Failed to run the model.",
                             ETCoreMLStrings.delegateIdentifier.UTF8String);
#endif

    return Error::Ok;
}

bool CoreMLBackendDelegate::is_available() const {
    ET_LOG(Debug, "%s: is_available called.", ETCoreMLStrings.delegateIdentifier.UTF8String);
    return impl_->is_available();
}

void CoreMLBackendDelegate::destroy(DelegateHandle* handle) const {
    ET_LOG(Debug, "%s: destroy called.", ETCoreMLStrings.delegateIdentifier.UTF8String);
    impl_->destroy(handle);
}

bool CoreMLBackendDelegate::purge_models_cache() const noexcept {
    ET_LOG(Debug, "%s: purge_models_cache called.", ETCoreMLStrings.delegateIdentifier.UTF8String);
    return impl_->purge_models_cache();
}

CoreMLBackendDelegate *CoreMLBackendDelegate::get_registered_delegate() noexcept {
    return static_cast<CoreMLBackendDelegate *>(get_backend_class(ETCoreMLStrings.delegateIdentifier.UTF8String));
}

namespace {
// Registers the Core ML delegate with the ExecuTorch runtime at static-initialization time.
auto cls = CoreMLBackendDelegate();
Backend backend{ETCoreMLStrings.delegateIdentifier.UTF8String, &cls};
static auto success_with_compiler = register_backend(backend);
} // namespace

} // namespace coreml
} // namespace backends
} // namespace executorch