xref: /aosp_15_r20/external/executorch/backends/apple/coreml/runtime/delegate/coreml_backend_delegate.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1//
2// coreml_backend_delegate.mm
3//
4// Copyright © 2024 Apple Inc. All rights reserved.
5//
6// Please refer to the license found in the LICENSE file in the root directory of the source tree.
7
8#import <ETCoreMLLogging.h>
9#import <ETCoreMLModel.h>
10#import <ETCoreMLStrings.h>
11#import <backend_delegate.h>
12#import <coreml_backend/delegate.h>
13#import <executorch/runtime/core/evalue.h>
14#import <executorch/runtime/platform/log.h>
15#import <memory>
16#import <model_event_logger.h>
17#import <model_logging_options.h>
18#import <multiarray.h>
19#import <objc_safe_cast.h>
20#import <unordered_map>
21#import <vector>
22
23#ifdef ET_EVENT_TRACER_ENABLED
24#import <model_event_logger_impl.h>
25#endif
26
27namespace {
28using namespace executorchcoreml;
29
30using executorch::aten::ScalarType;
31using executorch::runtime::ArrayRef;
32using executorch::runtime::Backend;
33using executorch::runtime::BackendExecutionContext;
34using executorch::runtime::BackendInitContext;
35using executorch::runtime::CompileSpec;
36using executorch::runtime::DelegateHandle;
37using executorch::runtime::EValue;
38using executorch::runtime::Error;
39using executorch::runtime::EventTracerDebugLogLevel;
40using executorch::runtime::FreeableBuffer;
41using executorch::runtime::get_backend_class;
42using executorch::runtime::Result;
43
/// Maps an ExecuTorch scalar type to the equivalent CoreML MultiArray data
/// type, or std::nullopt when the scalar type has no supported equivalent.
std::optional<MultiArray::DataType> get_data_type(ScalarType scalar_type) {
    using DataType = MultiArray::DataType;
    switch (scalar_type) {
        case ScalarType::Bool:   return DataType::Bool;
        case ScalarType::Byte:   return DataType::Byte;
        case ScalarType::Short:  return DataType::Short;
        case ScalarType::Int:    return DataType::Int32;
        case ScalarType::Long:   return DataType::Int64;
        case ScalarType::Half:   return DataType::Float16;
        case ScalarType::Float:  return DataType::Float32;
        case ScalarType::Double: return DataType::Float64;
        default:                 return std::nullopt;
    }
}
66
// Distinguishes whether an EValue argument is a model input (read-only from
// the delegate's point of view) or a model output (written by the delegate).
enum class ArgType: uint8_t {
    Input,
    Output
};
71
/// Wraps the tensor payload of `eValue` in a MultiArray that aliases the
/// tensor's storage (no copy). Returns std::nullopt when the value is not a
/// tensor or its scalar type has no CoreML equivalent.
std::optional<MultiArray> get_multi_array(EValue *eValue, ArgType argType) {
    if (!eValue->isTensor()) {
        return std::nullopt;
    }

    auto tensor = eValue->toTensor();
    auto dataType = get_data_type(tensor.scalar_type());
    if (!dataType.has_value()) {
        ET_LOG(Error, "%s: DataType=%d is not supported", ETCoreMLStrings.delegateIdentifier.UTF8String, (int)tensor.scalar_type());
        return std::nullopt;
    }

    std::vector<size_t> dims(tensor.sizes().begin(), tensor.sizes().end());
    std::vector<ssize_t> dimStrides(tensor.strides().begin(), tensor.strides().end());
    MultiArray::MemoryLayout layout(dataType.value(), std::move(dims), std::move(dimStrides));
    // Inputs are only read, so the const_cast is safe; outputs need the
    // mutable pointer because the delegate writes results into them.
    void *data = (argType == ArgType::Input)
        ? const_cast<void *>(tensor.const_data_ptr())
        : tensor.mutable_data_ptr();
    return MultiArray(data, layout);
}
96
/// Reads the delegate configuration from the plist at `plistURL`.
/// Returns std::nullopt when the plist cannot be loaded; a key that is absent
/// or not an NSNumber leaves the corresponding config default untouched.
std::optional<BackendDelegate::Config> parse_config(NSURL *plistURL) {
    NSDictionary<NSString *, id> *dict = [NSDictionary dictionaryWithContentsOfURL:plistURL];
    if (!dict) {
        return std::nullopt;
    }

    BackendDelegate::Config config;
    if (NSNumber *prewarmModel = SAFE_CAST(dict[@"shouldPrewarmModel"], NSNumber)) {
        config.should_prewarm_model = static_cast<bool>(prewarmModel.boolValue);
    }
    if (NSNumber *prewarmAsset = SAFE_CAST(dict[@"shouldPrewarmAsset"], NSNumber)) {
        config.should_prewarm_asset = static_cast<bool>(prewarmAsset.boolValue);
    }
    if (NSNumber *cacheSizeInBytes = SAFE_CAST(dict[@"maxModelsCacheSizeInBytes"], NSNumber)) {
        config.max_models_cache_size = cacheSizeInBytes.unsignedLongLongValue;
    }

    return config;
}
127
/// Loads the delegate configuration plist named `config_name`, looking first
/// in the main bundle and then in the bundle containing ETCoreMLModel.
/// Falls back to a default-constructed config when no plist can be parsed.
BackendDelegate::Config get_delegate_config(NSString *config_name) {
    NSURL *config_url = [NSBundle.mainBundle URLForResource:config_name withExtension:@"plist"];
    if (config_url == nil) {
        config_url = [[NSBundle bundleForClass:ETCoreMLModel.class] URLForResource:config_name withExtension:@"plist"];
    }
    return parse_config(config_url).value_or(BackendDelegate::Config());
}
134
/// Derives logging options from the execution context: profiling is turned on
/// whenever an event tracer is attached, and intermediate tensors are logged
/// only when the tracer's debug level is kIntermediateOutputs or higher.
ModelLoggingOptions get_logging_options(BackendExecutionContext& context) {
    ModelLoggingOptions options;
    auto tracer = context.event_tracer();
    if (tracer) {
        options.log_profiling_info = true;
        options.log_intermediate_tensors =
            tracer->event_tracer_debug_level() >= EventTracerDebugLogLevel::kIntermediateOutputs;
    }
    return options;
}
146
147} //namespace
148
149namespace executorch {
150namespace backends {
151namespace coreml {
152
153using namespace executorchcoreml;
154
// Constructs the delegate with the configuration loaded from the bundled
// plist (see get_delegate_config); defaults are used when no plist is found.
CoreMLBackendDelegate::CoreMLBackendDelegate() noexcept
:impl_(BackendDelegate::make(get_delegate_config(ETCoreMLStrings.configPlistName)))
{}
158
/// Initializes a model from the `processed` program buffer and the compile
/// specs, returning an opaque handle for later execution.
///
/// On success the processed buffer is freed (the delegate keeps whatever it
/// needs); on failure returns Error::InvalidProgram.
Result<DelegateHandle *>
CoreMLBackendDelegate::init(BackendInitContext& context,
                            FreeableBuffer* processed,
                            ArrayRef<CompileSpec> specs) const {
    ET_LOG(Debug, "%s: init called.", ETCoreMLStrings.delegateIdentifier.UTF8String);
    // Compile specs arrive as key/value pairs; hand them to the delegate as a map.
    std::unordered_map<std::string, Buffer> specs_map;
    specs_map.reserve(specs.size());
    for (const auto& spec : specs) {
        specs_map.emplace(spec.key, Buffer(spec.value.buffer, spec.value.nbytes));
    }

    auto handle = impl_->init(Buffer(processed->data(), processed->size()), specs_map);
    ET_CHECK_OR_RETURN_ERROR(handle != nullptr,
                             InvalidProgram,
                             "%s: Failed to init the model.", ETCoreMLStrings.delegateIdentifier.UTF8String);
    // The backend owns its own copy/reference of the program now.
    processed->Free();
    return handle;
}
181
/// Executes the model behind `handle`. `args` is laid out inputs-first:
/// args[0..nInputs) are inputs, args[nInputs..nInputs+nOutputs) are outputs.
/// Each EValue must be a tensor with a CoreML-supported scalar type.
///
/// Returns Error::Internal when an argument cannot be wrapped, and
/// Error::DelegateInvalidHandle when the underlying execution fails.
Error CoreMLBackendDelegate::execute(BackendExecutionContext& context,
                                     DelegateHandle* handle,
                                     EValue** args) const {
    const auto& nArgs = impl_->get_num_arguments(handle);
    const size_t nInputs = nArgs.first;
    const size_t nOutputs = nArgs.second;

    std::vector<MultiArray> delegate_args;
    delegate_args.reserve(nInputs + nOutputs);

    // Wrap every argument tensor in a MultiArray aliasing its storage.
    for (size_t i = 0; i < nInputs + nOutputs; i++) {
        const bool is_input = (i < nInputs);
        auto multi_array = get_multi_array(args[i], is_input ? ArgType::Input : ArgType::Output);
        ET_CHECK_OR_RETURN_ERROR(multi_array.has_value(),
                                 Internal,
                                 "%s: Failed to create multiarray from %s at args[%zu]",
                                 ETCoreMLStrings.delegateIdentifier.UTF8String,
                                 is_input ? "input" : "output",
                                 i);
        delegate_args.emplace_back(std::move(multi_array.value()));
    }

    auto logging_options = get_logging_options(context);
    std::error_code ec;
    // Profiling is only wired up when the event tracer is compiled in;
    // otherwise the delegate runs without an event logger.
#ifdef ET_EVENT_TRACER_ENABLED
    auto event_logger = ModelEventLoggerImpl(context.event_tracer());
    auto *event_logger_ptr = &event_logger;
#else
    ModelEventLogger *event_logger_ptr = nullptr;
#endif
    ET_CHECK_OR_RETURN_ERROR(impl_->execute(handle, delegate_args, logging_options, event_logger_ptr, ec),
                             DelegateInvalidHandle,
                             "%s: Failed to run the model.",
                             ETCoreMLStrings.delegateIdentifier.UTF8String);

    return Error::Ok;
}
226
// Reports whether the underlying CoreML delegate can run on this device/OS.
bool CoreMLBackendDelegate::is_available() const {
    const char *tag = ETCoreMLStrings.delegateIdentifier.UTF8String;
    ET_LOG(Debug, "%s: is_available called.", tag);
    return impl_->is_available();
}
231
// Releases the resources associated with `handle`; the handle must not be
// used after this call.
void CoreMLBackendDelegate::destroy(DelegateHandle* handle) const {
    const char *tag = ETCoreMLStrings.delegateIdentifier.UTF8String;
    ET_LOG(Debug, "%s: destroy called.", tag);
    impl_->destroy(handle);
}
236
// Evicts compiled models from the on-disk cache; returns the result reported
// by the underlying delegate.
bool CoreMLBackendDelegate::purge_models_cache() const noexcept {
    const char *tag = ETCoreMLStrings.delegateIdentifier.UTF8String;
    ET_LOG(Debug, "%s: purge_models_cache called.", tag);
    return impl_->purge_models_cache();
}
241
// Returns the delegate instance registered under the CoreML backend
// identifier via the runtime's backend registry.
CoreMLBackendDelegate *CoreMLBackendDelegate::get_registered_delegate() noexcept {
    auto *registered = get_backend_class(ETCoreMLStrings.delegateIdentifier.UTF8String);
    return static_cast<CoreMLBackendDelegate *>(registered);
}
245
namespace {
// Registers the CoreML delegate with the ExecuTorch runtime during static
// initialization, under the identifier published by ETCoreMLStrings.
// NOTE(review): this invokes Objective-C (ETCoreMLStrings) from a C++ static
// initializer; assumes the ObjC runtime is ready at that point — confirm.
auto cls = CoreMLBackendDelegate();
Backend backend{ETCoreMLStrings.delegateIdentifier.UTF8String, &cls};
static auto success_with_compiler = register_backend(backend);
}
251
252} // namespace coreml
253} // namespace backends
254} // namespace executorch
255