xref: /aosp_15_r20/external/executorch/backends/apple/coreml/runtime/delegate/ETCoreMLAsset.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1//
2// ETCoreMLAsset.mm
3//
4// Copyright © 2024 Apple Inc. All rights reserved.
5//
6// Please refer to the license found in the LICENSE file in the root directory of the source tree.
7
8#import <ETCoreMLAsset.h>
9
10#import <fcntl.h>
11#import <os/lock.h>
12#import <stdio.h>
13#import <system_error>
14
15#import <objc_safe_cast.h>
16
namespace {
using namespace executorchcoreml;

/// Reads the content modification date of the item at `url`.
///
/// @param url   The file URL to query.
/// @param error Forwarded to `-[NSURL getResourceValue:forKey:error:]`; may be NULL.
/// @return The modification date, or nil if the resource value could not be read.
///         The value is passed through `SAFE_CAST`, which presumably yields nil when
///         it is not an `NSDate` (NOTE(review): confirm against objc_safe_cast.h).
NSDate * _Nullable get_content_modification_date(NSURL *url, NSError * __autoreleasing *error) {
    NSDate *result = nil;
    if (![url getResourceValue:&result forKey:NSURLContentModificationDateKey error:error]) {
        return nil;
    }

    return SAFE_CAST(result, NSDate);
}

/// Checks that the files recorded in `asset` are unchanged on disk.
///
/// For every entry in `asset.package_info.file_infos`, the file is resolved relative to
/// `asset.path`. Its content modification date (milliseconds since 1970, truncated to an
/// integer) must equal the recorded `last_modification_time_interval`.
///
/// @return false if any file's modification date cannot be read or differs from the
///         recorded value; true otherwise (including when there are no file infos).
bool is_asset_valid(const Asset& asset) {
    NSURL *asset_url = [NSURL fileURLWithPath:@(asset.path.c_str())];
    for (const auto& file_info : asset.package_info.file_infos) {
        const std::string& relative_path = file_info.relative_path;
        NSURL *file_url = [asset_url URLByAppendingPathComponent:@(relative_path.c_str())];

        // The failure reason is not needed here; any unreadable date invalidates the asset.
        NSDate *last_modification_date = get_content_modification_date(file_url, nil);
        if (!last_modification_date) {
            return false;
        }

        // Must match the millisecond truncation used when the interval was recorded.
        int64_t last_modification_time_interval = static_cast<int64_t>(last_modification_date.timeIntervalSince1970 * 1000);
        if (last_modification_time_interval != file_info.last_modification_time_interval) {
            return false;
        }
    }

    return true;
}

/// Converts a C++ `std::error_code` into an `NSError` and stores it in `*error`.
///
/// The resulting error uses the category name as its domain and the code's value as its
/// code. When the category message is valid UTF-8, it becomes the error's localized
/// description.
/// Does nothing when `error` is NULL or `cppError` represents success.
void set_error_from_error_code(const std::error_code& cppError, NSError * __autoreleasing *error) {
    if (!error || !cppError) {
        return;
    }

    const std::string message_string = cppError.message();
    // Boxing a C string yields nil if it is not valid UTF-8; a nil value in a dictionary
    // literal would throw, so the description is only attached when conversion succeeds.
    NSString *message = @(message_string.c_str());
    NSDictionary<NSString *, id> *userInfo = message ? @{NSLocalizedDescriptionKey : message} : nil;

    NSString *domain = @(cppError.category().name());
    NSInteger code = cppError.value();
    *error = [NSError errorWithDomain:domain code:code userInfo:userInfo];
}
} //namespace
62
@implementation ETCoreMLAsset {
    /// The C++ asset description wrapped by this object.
    executorchcoreml::Asset _backingAsset;
    /// One open read handle per entry in `_backingAsset.package_info.file_infos`, in the
    /// same order. Populated by a successful keep-alive; cleared by `-close`.
    std::vector<std::unique_ptr<FILE, decltype(&fclose)>> _openFiles;
    /// Guards `_openFiles`.
    os_unfair_lock _lock;
}

/// Creates an asset wrapper, validating the on-disk files against the recorded metadata.
- (instancetype)initWithBackingAsset:(executorchcoreml::Asset)backingAsset {
    self = [super init];
    if (self) {
        // os_unfair_lock must start out in the unlocked state.
        _lock = OS_UNFAIR_LOCK_INIT;
        _isValid = static_cast<BOOL>(is_asset_valid(backingAsset));
        _identifier = @(backingAsset.identifier.c_str());
        _contentURL = [NSURL fileURLWithPath:@(backingAsset.path.c_str())];
        _totalSizeInBytes = backingAsset.total_size_in_bytes();
        _backingAsset = std::move(backingAsset);
    }

    return self;
}

- (void)dealloc {
    [self close];
}

/// Opens every package file for reading, unless they are already open.
///
/// Must be called with `_lock` held. Returns NO without populating `error` if the asset
/// is not valid. If any file fails to open, `error` receives the `errno` from `fopen`,
/// the files opened so far are closed again, and `_openFiles` is left unchanged.
- (BOOL)_keepAliveAndReturnError:(NSError * __autoreleasing *)error {
    if (!_isValid) {
        return NO;
    }

    const auto& fileInfos = _backingAsset.package_info.file_infos;
    if (_openFiles.size() == fileInfos.size()) {
        return YES;
    }

    std::vector<std::unique_ptr<FILE, decltype(&fclose)>> openFiles;
    for (const auto& fileInfo : fileInfos) {
        NSURL *fileURL = [NSURL fileURLWithPath:@(fileInfo.relative_path.c_str()) relativeToURL:self.contentURL];
        std::unique_ptr<FILE, decltype(&fclose)> file(fopen(fileURL.path.UTF8String, "rb"), fclose);
        if (file == nullptr) {
            ::set_error_from_error_code(std::error_code(errno, std::generic_category()), error);
            break;
        }
        openFiles.emplace_back(std::move(file));
    }

    BOOL success = (openFiles.size() == fileInfos.size());
    if (success) {
        _openFiles = std::move(openFiles);
    }

    return success;
}

/// Issues an `F_RDADVISE` read-ahead advisory covering each open package file.
///
/// Must be called with `_lock` held and with `_openFiles` fully populated.
/// `radvisory.ra_count` is an `int`, so files are advised in chunks small enough to fit.
/// The previous single `(int)sizeInBytes` call overflowed for files above `INT_MAX` bytes.
/// Zero-length files need no advisory and are skipped.
- (BOOL)_adviseReadAheadForOpenFilesAndReturnError:(NSError * __autoreleasing *)error {
    constexpr size_t kMaxAdvisoryLength = size_t{1} << 30;
    const auto& fileInfos = _backingAsset.package_info.file_infos;
    for (size_t i = 0; i < fileInfos.size(); i++) {
        int fd = fileno(_openFiles[i].get());
        size_t remaining = fileInfos[i].size_in_bytes;
        off_t offset = 0;
        while (remaining > 0) {
            size_t chunk = (remaining < kMaxAdvisoryLength) ? remaining : kMaxAdvisoryLength;
            struct radvisory advisory = { .ra_offset = offset, .ra_count = static_cast<int>(chunk) };
            if (fcntl(fd, F_RDADVISE, &advisory) == -1) {
                ::set_error_from_error_code(std::error_code(errno, std::system_category()), error);
                return NO;
            }
            offset += static_cast<off_t>(chunk);
            remaining -= chunk;
        }
    }

    return YES;
}

/// Returns YES when every package file currently has an open handle.
- (BOOL)isAlive {
    BOOL result = NO;
    {
        os_unfair_lock_lock(&_lock);
        const auto& fileInfos = _backingAsset.package_info.file_infos;
        result = (_openFiles.size() == fileInfos.size());
        os_unfair_lock_unlock(&_lock);
    }

    return result;
}

/// Thread-safe wrapper around `-_keepAliveAndReturnError:`.
- (BOOL)keepAliveAndReturnError:(NSError * __autoreleasing *)error {
    BOOL result = NO;
    {
        os_unfair_lock_lock(&_lock);
        result = [self _keepAliveAndReturnError:error];
        os_unfair_lock_unlock(&_lock);
    }

    return result;
}

/// Keeps the package files open and asks the kernel to read them ahead.
///
/// The advisories run while `_lock` is held. Otherwise a concurrent `-close` could
/// `fclose` the descriptors between reading them and calling `fcntl`, and the
/// descriptor numbers could then be reused by unrelated files.
- (BOOL)prewarmAndReturnError:(NSError * __autoreleasing *)error {
    BOOL result = NO;
    os_unfair_lock_lock(&_lock);
    if ([self _keepAliveAndReturnError:error]) {
        result = [self _adviseReadAheadForOpenFilesAndReturnError:error];
    }
    os_unfair_lock_unlock(&_lock);

    return result;
}

/// Closes all open package file handles.
- (void)close {
    os_unfair_lock_lock(&_lock);
    _openFiles.clear();
    os_unfair_lock_unlock(&_lock);
}

@end
177