//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <android-base/logging.h>
#include <nnapi/IBuffer.h>
#include <nnapi/IDevice.h>
#include <nnapi/IPreparedModel.h>
#include <nnapi/OperandTypes.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>

#include "ArmnnDevice.hpp"
#include "ArmnnDriverImpl.hpp"
#include "Converter.hpp"

#include "ModelToINetworkTransformer.hpp"

#include <armnn/Version.hpp>
#include <log/log.h>

#include <fstream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace armnn_driver
{

//using namespace android::nn;

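// Driver entry point for the Arm NN NNAPI Support Library.
// Implements the canonical android::nn::IDevice interface and delegates
// capability queries and model preparation to ArmnnDevice / ArmnnDriverImpl.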
class ArmnnDriver : public IDevice
{
private:
    std::unique_ptr<ArmnnDevice> m_Device;
public:
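    // Construction never throws: if ArmnnDevice initialisation fails the
    // exception is logged, m_Device stays null and subsequent calls that
    // check the device report DEVICE_UNAVAILABLE.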
    ArmnnDriver(DriverOptions options)
    {
        try
        {
            VLOG(DRIVER) << "ArmnnDriver::ArmnnDriver()";
            m_Device = std::unique_ptr<ArmnnDevice>(new ArmnnDevice(std::move(options)));
        }
        catch (armnn::InvalidArgumentException& ex)
        {
            VLOG(DRIVER) << "ArmnnDevice failed to initialise: " << ex.what();
        }
        catch (...)
        {
            VLOG(DRIVER) << "ArmnnDevice failed to initialise with an unknown error";
        }
    }

public:

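    // Returns the device name reported to the NNAPI runtime.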
    const std::string& getName() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getName()";
        static const std::string name = "arm-armnn-sl";
        return name;
    }

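    // Returns the Arm NN version string baked in at build time.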
    const std::string& getVersionString() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getVersionString()";
        static const std::string versionString = ARMNN_VERSION;
        return versionString;
    }

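    // Reports the NNAPI feature level supported by this driver (feature level 6).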
    Version getFeatureLevel() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getFeatureLevel()";
        return kVersionFeatureLevel6;
    }

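    // The driver advertises itself to the NNAPI runtime as a CPU device.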
    DeviceType getType() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getType()";
        return DeviceType::CPU;
    }

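    // No vendor extensions are exposed.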
    const std::vector<Extension>& getSupportedExtensions() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getSupportedExtensions()";
        static const std::vector<Extension> extensions = {};
        return extensions;
    }

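    // Capabilities are queried from the Arm NN runtime via ArmnnDriverImpl.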
    const Capabilities& getCapabilities() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getCapabilities()";
        return ArmnnDriverImpl::GetCapabilities(m_Device->m_Runtime);
    }

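    // Sums the model cache files required across all configured backends;
    // the driver always requests one data cache file.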
    std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded()";
        unsigned int numberOfCachedModelFiles = 0;
        for (auto& backend : m_Device->m_Options.GetBackends())
        {
            numberOfCachedModelFiles += GetNumberOfCacheFiles(backend);
            VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded() = "
                         << std::to_string(numberOfCachedModelFiles);
        }
        return std::make_pair(numberOfCachedModelFiles, 1ul);
    }

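    // Nothing to wait for: initialisation happens in the constructor.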
    GeneralResult<void> wait() const override
    {
        VLOG(DRIVER) << "ArmnnDriver::wait()";
        return {};
    }

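    // Validates the model, attempts to convert it to an Arm NN INetwork and
    // returns one flag per operation indicating whether it is supported.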
    GeneralResult<std::vector<bool>> getSupportedOperations(const Model& model) const override
    {
        VLOG(DRIVER) << "ArmnnDriver::getSupportedOperations()";
        if (m_Device.get() == nullptr)
        {
            return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!";
        }

        std::stringstream ss;
        ss << "ArmnnDriverImpl::getSupportedOperations()";
        std::string fileName;
        std::string timestamp;
        if (!m_Device->m_Options.GetRequestInputsAndOutputsDumpDir().empty())
        {
            // Build the marker file path in the requested dump directory.
            fileName = m_Device->m_Options.GetRequestInputsAndOutputsDumpDir()
                       + "/"
                       // + GetFileTimestamp()
                       + "_getSupportedOperations.txt";
            ss << " : " << fileName;
        }
        VLOG(DRIVER) << ss.str().c_str();

        if (!fileName.empty())
        {
            // Dump the marker file.
            std::ofstream fileStream;
            fileStream.open(fileName, std::ofstream::out | std::ofstream::trunc);
            if (fileStream.good())
            {
                fileStream << timestamp << std::endl;
            }
            fileStream.close();
        }

        std::vector<bool> result;
        if (!m_Device->m_Runtime)
        {
            return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!";
        }

        // Run general model validation; if this doesn't pass we shouldn't analyse the model anyway.
        if (const auto validationResult = validate(model); !validationResult.ok())
        {
            return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model!";
        }

        // Attempt to convert the model to an Arm NN input network (INetwork).
        ModelToINetworkTransformer modelConverter(m_Device->m_Options.GetBackends(),
                                                  model,
                                                  m_Device->m_Options.GetForcedUnsupportedOperations());

        if (modelConverter.GetConversionResult() != ConversionResult::Success
            && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature)
        {
            return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "Conversion Error!";
        }

        // Check whether each operation was converted successfully and copy the flags
        // into the result (vector<bool>) that we need to return to Android.
        result.reserve(model.main.operations.size());
        for (uint32_t operationIdx = 0; operationIdx < model.main.operations.size(); ++operationIdx)
        {
            bool operationSupported = modelConverter.IsOperationSupported(operationIdx);
            result.push_back(operationSupported);
        }

        return result;
    }

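    // Validates the request, then prepares (optimises and loads) the model
    // through ArmnnDriverImpl. FP16 relaxation is applied only when both the
    // model requests it and the driver option is enabled.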
    GeneralResult<SharedPreparedModel> prepareModel(const Model& model,
        ExecutionPreference preference,
        Priority priority,
        OptionalTimePoint deadline,
        const std::vector<SharedHandle>& modelCache,
        const std::vector<SharedHandle>& dataCache,
        const CacheToken& token,
        const std::vector<android::nn::TokenValuePair>& hints,
        const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override
    {
        VLOG(DRIVER) << "ArmnnDriver::prepareModel()";

        if (m_Device.get() == nullptr)
        {
            return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!";
        }
        // Validate arguments.
        if (const auto result = validate(model); !result.ok()) {
            return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model: " << result.error();
        }
        if (const auto result = validate(preference); !result.ok()) {
            return NN_ERROR(ErrorStatus::INVALID_ARGUMENT)
                << "Invalid ExecutionPreference: " << result.error();
        }
        if (const auto result = validate(priority); !result.ok()) {
            return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Priority: " << result.error();
        }

        // Check if deadline has passed.
        if (hasDeadlinePassed(deadline)) {
            return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
        }

        return ArmnnDriverImpl::PrepareArmnnModel(m_Device->m_Runtime,
                                                  m_Device->m_ClTunedParameters,
                                                  m_Device->m_Options,
                                                  model,
                                                  modelCache,
                                                  dataCache,
                                                  token,
                                                  model.relaxComputationFloat32toFloat16 && m_Device->m_Options.GetFp16Enabled(),
                                                  priority);
    }

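    // Recreates a prepared model from previously written cache files;
    // argument validation is limited to the deadline check.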
    GeneralResult<SharedPreparedModel> prepareModelFromCache(OptionalTimePoint deadline,
                                                             const std::vector<SharedHandle>& modelCache,
                                                             const std::vector<SharedHandle>& dataCache,
                                                             const CacheToken& token) const override
    {
        VLOG(DRIVER) << "ArmnnDriver::prepareModelFromCache()";
        if (m_Device.get() == nullptr)
        {
            return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!";
        }
        // Check if deadline has passed.
        if (hasDeadlinePassed(deadline)) {
            return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
        }

        return ArmnnDriverImpl::PrepareArmnnModelFromCache(
                     m_Device->m_Runtime,
                     m_Device->m_ClTunedParameters,
                     m_Device->m_Options,
                     modelCache,
                     dataCache,
                     token,
                     m_Device->m_Options.GetFp16Enabled());
    }

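    // Driver-managed device memory allocation is not supported.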
    GeneralResult<SharedBuffer> allocate(const BufferDesc&,
                                         const std::vector<SharedPreparedModel>&,
                                         const std::vector<BufferRole>&,
                                         const std::vector<BufferRole>&) const override
    {
        VLOG(DRIVER) << "ArmnnDriver::allocate()";
        return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "ArmnnDriver::allocate -- does not support allocate.";
    }
};

} // namespace armnn_driver