xref: /aosp_15_r20/external/android-nn-driver/1.2/ArmnnDriver.hpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <HalInterfaces.h>
9 
10 #include "../ArmnnDevice.hpp"
11 #include "ArmnnDriverImpl.hpp"
12 #include "HalPolicy.hpp"
13 
14 #include "../ArmnnDriverImpl.hpp"
15 #include "../1.2/ArmnnDriverImpl.hpp"
16 #include "../1.2/HalPolicy.hpp"
17 #include "../1.1/ArmnnDriverImpl.hpp"
18 #include "../1.1/HalPolicy.hpp"
19 #include "../1.0/ArmnnDriverImpl.hpp"
20 #include "../1.0/HalPolicy.hpp"
21 
22 #include <armnn/BackendHelper.hpp>
23 
24 #include <log/log.h>
25 
26 namespace armnn_driver
27 {
28 namespace hal_1_2
29 {
30 
31 class ArmnnDriver : public ArmnnDevice, public V1_2::IDevice
32 {
33 public:
34 
ArmnnDriver(DriverOptions options)35     ArmnnDriver(DriverOptions options)
36         : ArmnnDevice(std::move(options))
37     {
38         ALOGV("hal_1_2::ArmnnDriver::ArmnnDriver()");
39     }
~ArmnnDriver()40     ~ArmnnDriver() {}
41 
42     using HidlToken = android::hardware::hidl_array<uint8_t, ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN>;
43 
44 public:
getCapabilities(V1_0::IDevice::getCapabilities_cb cb)45     Return<void> getCapabilities(V1_0::IDevice::getCapabilities_cb cb) override
46     {
47         ALOGV("hal_1_2::ArmnnDriver::getCapabilities()");
48 
49         return hal_1_0::ArmnnDriverImpl::getCapabilities(m_Runtime, cb);
50     }
51 
getSupportedOperations(const V1_0::Model & model,V1_0::IDevice::getSupportedOperations_cb cb)52     Return<void> getSupportedOperations(const V1_0::Model& model,
53                                         V1_0::IDevice::getSupportedOperations_cb cb) override
54     {
55         ALOGV("hal_1_2::ArmnnDriver::getSupportedOperations()");
56 
57         return armnn_driver::ArmnnDriverImpl<hal_1_0::HalPolicy>::getSupportedOperations(m_Runtime,
58                                                                                          m_Options,
59                                                                                          model,
60                                                                                          cb);
61     }
62 
prepareModel(const V1_0::Model & model,const android::sp<V1_0::IPreparedModelCallback> & cb)63     Return<V1_0::ErrorStatus> prepareModel(const V1_0::Model& model,
64                                            const android::sp<V1_0::IPreparedModelCallback>& cb) override
65     {
66         ALOGV("hal_1_2::ArmnnDriver::prepareModel()");
67 
68         return armnn_driver::ArmnnDriverImpl<hal_1_0::HalPolicy>::prepareModel(m_Runtime,
69                                                                                m_ClTunedParameters,
70                                                                                m_Options,
71                                                                                model,
72                                                                                cb);
73     }
74 
getCapabilities_1_1(V1_1::IDevice::getCapabilities_1_1_cb cb)75     Return<void> getCapabilities_1_1(V1_1::IDevice::getCapabilities_1_1_cb cb) override
76     {
77         ALOGV("hal_1_2::ArmnnDriver::getCapabilities_1_1()");
78 
79         return hal_1_1::ArmnnDriverImpl::getCapabilities_1_1(m_Runtime, cb);
80     }
81 
getSupportedOperations_1_1(const V1_1::Model & model,V1_1::IDevice::getSupportedOperations_1_1_cb cb)82     Return<void> getSupportedOperations_1_1(const V1_1::Model& model,
83                                             V1_1::IDevice::getSupportedOperations_1_1_cb cb) override
84     {
85         ALOGV("hal_1_2::ArmnnDriver::getSupportedOperations_1_1()");
86         return armnn_driver::ArmnnDriverImpl<hal_1_1::HalPolicy>::getSupportedOperations(m_Runtime,
87                                                                                          m_Options,
88                                                                                          model,
89                                                                                          cb);
90     }
91 
prepareModel_1_1(const V1_1::Model & model,V1_1::ExecutionPreference preference,const android::sp<V1_0::IPreparedModelCallback> & cb)92     Return<V1_0::ErrorStatus> prepareModel_1_1(const V1_1::Model& model,
93                                                V1_1::ExecutionPreference preference,
94                                                const android::sp<V1_0::IPreparedModelCallback>& cb) override
95     {
96         ALOGV("hal_1_2::ArmnnDriver::prepareModel_1_1()");
97 
98         if (!(preference == V1_1::ExecutionPreference::LOW_POWER ||
99               preference == V1_1::ExecutionPreference::FAST_SINGLE_ANSWER ||
100               preference == V1_1::ExecutionPreference::SUSTAINED_SPEED))
101         {
102             ALOGV("hal_1_2::ArmnnDriver::prepareModel_1_1: Invalid execution preference");
103             cb->notify(V1_0::ErrorStatus::INVALID_ARGUMENT, nullptr);
104             return V1_0::ErrorStatus::INVALID_ARGUMENT;
105         }
106 
107         return armnn_driver::ArmnnDriverImpl<hal_1_1::HalPolicy>::prepareModel(m_Runtime,
108                                                                                m_ClTunedParameters,
109                                                                                m_Options,
110                                                                                model,
111                                                                                cb,
112                                                                                model.relaxComputationFloat32toFloat16
113                                                                                && m_Options.GetFp16Enabled());
114     }
115 
getStatus()116     Return<V1_0::DeviceStatus> getStatus() override
117     {
118         ALOGV("hal_1_2::ArmnnDriver::getStatus()");
119 
120         return armnn_driver::ArmnnDriverImpl<hal_1_2::HalPolicy>::getStatus();
121     }
122 
getVersionString(getVersionString_cb cb)123     Return<void> getVersionString(getVersionString_cb cb)
124     {
125         ALOGV("hal_1_2::ArmnnDriver::getVersionString()");
126 
127         cb(V1_0::ErrorStatus::NONE, "ArmNN");
128         return Void();
129     }
130 
getType(getType_cb cb)131     Return<void> getType(getType_cb cb)
132     {
133         ALOGV("hal_1_2::ArmnnDriver::getType()");
134         const auto device_type = hal_1_2::HalPolicy::GetDeviceTypeFromOptions(this->m_Options);
135         cb(V1_0::ErrorStatus::NONE, device_type);
136         return Void();
137     }
138 
prepareModelFromCache(const android::hardware::hidl_vec<android::hardware::hidl_handle> & modelCacheHandle,const android::hardware::hidl_vec<android::hardware::hidl_handle> & dataCacheHandle,const HidlToken & token,const android::sp<V1_2::IPreparedModelCallback> & cb)139     Return<V1_0::ErrorStatus> prepareModelFromCache(
140         const android::hardware::hidl_vec<android::hardware::hidl_handle>& modelCacheHandle,
141         const android::hardware::hidl_vec<android::hardware::hidl_handle>& dataCacheHandle,
142         const HidlToken& token,
143         const android::sp<V1_2::IPreparedModelCallback>& cb)
144     {
145         ALOGV("hal_1_2::ArmnnDriver::prepareModelFromCache()");
146         return ArmnnDriverImpl::prepareModelFromCache(m_Runtime,
147                                                       m_Options,
148                                                       modelCacheHandle,
149                                                       dataCacheHandle,
150                                                       token,
151                                                       cb);
152     }
153 
prepareModel_1_2(const V1_2::Model & model,V1_1::ExecutionPreference preference,const android::hardware::hidl_vec<android::hardware::hidl_handle> & modelCacheHandle,const android::hardware::hidl_vec<android::hardware::hidl_handle> & dataCacheHandle,const HidlToken & token,const android::sp<V1_2::IPreparedModelCallback> & cb)154     Return<V1_0::ErrorStatus> prepareModel_1_2(
155         const V1_2::Model& model, V1_1::ExecutionPreference preference,
156         const android::hardware::hidl_vec<android::hardware::hidl_handle>& modelCacheHandle,
157         const android::hardware::hidl_vec<android::hardware::hidl_handle>& dataCacheHandle,
158         const HidlToken& token,
159         const android::sp<V1_2::IPreparedModelCallback>& cb)
160     {
161         ALOGV("hal_1_2::ArmnnDriver::prepareModel_1_2()");
162 
163         if (!(preference == V1_1::ExecutionPreference::LOW_POWER ||
164               preference == V1_1::ExecutionPreference::FAST_SINGLE_ANSWER ||
165               preference == V1_1::ExecutionPreference::SUSTAINED_SPEED))
166         {
167             ALOGV("hal_1_2::ArmnnDriver::prepareModel_1_2: Invalid execution preference");
168             cb->notify(V1_0::ErrorStatus::INVALID_ARGUMENT, nullptr);
169             return V1_0::ErrorStatus::INVALID_ARGUMENT;
170         }
171 
172         return ArmnnDriverImpl::prepareArmnnModel_1_2(m_Runtime,
173                                                       m_ClTunedParameters,
174                                                       m_Options,
175                                                       model,
176                                                       modelCacheHandle,
177                                                       dataCacheHandle,
178                                                       token,
179                                                       cb,
180                                                       model.relaxComputationFloat32toFloat16
181                                                       && m_Options.GetFp16Enabled());
182     }
183 
getSupportedExtensions(getSupportedExtensions_cb cb)184     Return<void> getSupportedExtensions(getSupportedExtensions_cb cb)
185     {
186         ALOGV("hal_1_2::ArmnnDriver::getSupportedExtensions()");
187         cb(V1_0::ErrorStatus::NONE, {/* No extensions. */});
188         return Void();
189     }
190 
getCapabilities_1_2(getCapabilities_1_2_cb cb)191     Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb)
192     {
193         ALOGV("hal_1_2::ArmnnDriver::getCapabilities()");
194 
195         return hal_1_2::ArmnnDriverImpl::getCapabilities_1_2(m_Runtime, cb);
196     }
197 
getSupportedOperations_1_2(const V1_2::Model & model,getSupportedOperations_1_2_cb cb)198     Return<void> getSupportedOperations_1_2(const V1_2::Model& model,
199                                             getSupportedOperations_1_2_cb cb)
200     {
201         ALOGV("hal_1_2::ArmnnDriver::getSupportedOperations()");
202 
203         return armnn_driver::ArmnnDriverImpl<hal_1_2::HalPolicy>::getSupportedOperations(m_Runtime,
204                                                                                          m_Options,
205                                                                                          model,
206                                                                                          cb);
207     }
208 
getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb)209     Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb)
210     {
211         ALOGV("hal_1_2::ArmnnDriver::getSupportedExtensions()");
212         unsigned int numberOfCachedModelFiles = 0;
213         for (auto& backend : m_Options.GetBackends())
214         {
215             numberOfCachedModelFiles += GetNumberOfCacheFiles(backend);
216         }
217         cb(V1_0::ErrorStatus::NONE, numberOfCachedModelFiles,   1ul);
218         return Void();
219     }
220 };
221 
222 } // namespace hal_1_2
223 } // namespace armnn_driver
224