1 // 2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <android-base/logging.h> 9 #include <nnapi/IBuffer.h> 10 #include <nnapi/IDevice.h> 11 #include <nnapi/IPreparedModel.h> 12 #include <nnapi/OperandTypes.h> 13 #include <nnapi/Result.h> 14 #include <nnapi/Types.h> 15 #include <nnapi/Validation.h> 16 17 #include "ArmnnDevice.hpp" 18 #include "ArmnnDriverImpl.hpp" 19 #include "Converter.hpp" 20 21 #include "ArmnnDriverImpl.hpp" 22 #include "ModelToINetworkTransformer.hpp" 23 24 #include <armnn/Version.hpp> 25 #include <log/log.h> 26 namespace armnn_driver 27 { 28 29 //using namespace android::nn; 30 31 class ArmnnDriver : public IDevice 32 { 33 private: 34 std::unique_ptr<ArmnnDevice> m_Device; 35 public: ArmnnDriver(DriverOptions options)36 ArmnnDriver(DriverOptions options) 37 { 38 try 39 { 40 VLOG(DRIVER) << "ArmnnDriver::ArmnnDriver()"; 41 m_Device = std::unique_ptr<ArmnnDevice>(new ArmnnDevice(std::move(options))); 42 } 43 catch (armnn::InvalidArgumentException& ex) 44 { 45 VLOG(DRIVER) << "ArmnnDevice failed to initialise: " << ex.what(); 46 } 47 catch (...) 48 { 49 VLOG(DRIVER) << "ArmnnDevice failed to initialise with an unknown error"; 50 } 51 } 52 53 public: 54 getName() const55 const std::string& getName() const override 56 { 57 VLOG(DRIVER) << "ArmnnDriver::getName()"; 58 static const std::string name = "arm-armnn-sl"; 59 return name; 60 } 61 getVersionString() const62 const std::string& getVersionString() const override 63 { 64 VLOG(DRIVER) << "ArmnnDriver::getVersionString()"; 65 static const std::string versionString = ARMNN_VERSION; 66 return versionString; 67 } 68 getFeatureLevel() const69 Version getFeatureLevel() const override 70 { 71 VLOG(DRIVER) << "ArmnnDriver::getFeatureLevel()"; 72 return kVersionFeatureLevel6; 73 } 74 getType() const75 DeviceType getType() const override 76 { 77 VLOG(DRIVER) << "ArmnnDriver::getType()"; 78 return DeviceType::CPU; 79 } 80 getSupportedExtensions() const81 const std::vector<Extension>& getSupportedExtensions() const override 82 { 83 VLOG(DRIVER) << "ArmnnDriver::getSupportedExtensions()"; 84 static const std::vector<Extension> extensions = {}; 85 return extensions; 86 } 87 getCapabilities() const88 const Capabilities& getCapabilities() const override 89 { 90 VLOG(DRIVER) << "ArmnnDriver::GetCapabilities()"; 91 return ArmnnDriverImpl::GetCapabilities(m_Device->m_Runtime); 92 } 93 getNumberOfCacheFilesNeeded() const94 std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const override 95 { 96 VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded()"; 97 unsigned int numberOfCachedModelFiles = 0; 98 for (auto& backend : m_Device->m_Options.GetBackends()) 99 { 100 numberOfCachedModelFiles += GetNumberOfCacheFiles(backend); 101 VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded() = " 102 << std::to_string(numberOfCachedModelFiles); 103 } 104 return std::make_pair(numberOfCachedModelFiles, 1ul); 105 } 106 wait() const107 GeneralResult<void> wait() const override 108 { 109 VLOG(DRIVER) << "ArmnnDriver::wait()"; 110 return {}; 111 } 112 getSupportedOperations(const Model & model) const113 GeneralResult<std::vector<bool>> getSupportedOperations(const Model& model) const override 114 { 115 VLOG(DRIVER) << "ArmnnDriver::getSupportedOperations()"; 116 if (m_Device.get() == nullptr) 117 { 118 return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!"; 119 } 120 121 std::stringstream ss; 122 ss << "ArmnnDriverImpl::getSupportedOperations()"; 123 std::string fileName; 124 std::string timestamp; 125 if (!m_Device->m_Options.GetRequestInputsAndOutputsDumpDir().empty()) 126 { 127 ss << " : " 128 << m_Device->m_Options.GetRequestInputsAndOutputsDumpDir() 129 << "/" 130 // << GetFileTimestamp() 131 << "_getSupportedOperations.txt"; 132 } 133 VLOG(DRIVER) << ss.str().c_str(); 134 135 if (!m_Device->m_Options.GetRequestInputsAndOutputsDumpDir().empty()) 136 { 137 //dump the marker file 138 std::ofstream fileStream; 139 fileStream.open(fileName, std::ofstream::out | std::ofstream::trunc); 140 if (fileStream.good()) 141 { 142 fileStream << timestamp << std::endl; 143 fileStream << timestamp << std::endl; 144 } 145 fileStream.close(); 146 } 147 148 std::vector<bool> result; 149 if (!m_Device->m_Runtime) 150 { 151 return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!"; 152 } 153 154 // Run general model validation, if this doesn't pass we shouldn't analyse the model anyway. 155 if (const auto result = validate(model); !result.ok()) 156 { 157 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model!"; 158 } 159 160 // Attempt to convert the model to an ArmNN input network (INetwork). 161 ModelToINetworkTransformer modelConverter(m_Device->m_Options.GetBackends(), 162 model, 163 m_Device->m_Options.GetForcedUnsupportedOperations()); 164 165 if (modelConverter.GetConversionResult() != ConversionResult::Success 166 && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature) 167 { 168 return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "Conversion Error!"; 169 } 170 171 // Check each operation if it was converted successfully and copy the flags 172 // into the result (vector<bool>) that we need to return to Android. 173 result.reserve(model.main.operations.size()); 174 for (uint32_t operationIdx = 0; operationIdx < model.main.operations.size(); ++operationIdx) 175 { 176 bool operationSupported = modelConverter.IsOperationSupported(operationIdx); 177 result.push_back(operationSupported); 178 } 179 180 return result; 181 } 182 prepareModel(const Model & model,ExecutionPreference preference,Priority priority,OptionalTimePoint deadline,const std::vector<SharedHandle> & modelCache,const std::vector<SharedHandle> & dataCache,const CacheToken & token,const std::vector<android::nn::TokenValuePair> & hints,const std::vector<android::nn::ExtensionNameAndPrefix> & extensionNameToPrefix) const183 GeneralResult<SharedPreparedModel> prepareModel(const Model& model, 184 ExecutionPreference preference, 185 Priority priority, 186 OptionalTimePoint deadline, 187 const std::vector<SharedHandle>& modelCache, 188 const std::vector<SharedHandle>& dataCache, 189 const CacheToken& token, 190 const std::vector<android::nn::TokenValuePair>& hints, 191 const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override 192 { 193 VLOG(DRIVER) << "ArmnnDriver::prepareModel()"; 194 195 if (m_Device.get() == nullptr) 196 { 197 return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!"; 198 } 199 // Validate arguments. 200 if (const auto result = validate(model); !result.ok()) { 201 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model: " << result.error(); 202 } 203 if (const auto result = validate(preference); !result.ok()) { 204 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) 205 << "Invalid ExecutionPreference: " << result.error(); 206 } 207 if (const auto result = validate(priority); !result.ok()) { 208 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Priority: " << result.error(); 209 } 210 211 // Check if deadline has passed. 212 if (hasDeadlinePassed(deadline)) { 213 return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT); 214 } 215 216 return ArmnnDriverImpl::PrepareArmnnModel(m_Device->m_Runtime, 217 m_Device->m_ClTunedParameters, 218 m_Device->m_Options, 219 model, 220 modelCache, 221 dataCache, 222 token, 223 model.relaxComputationFloat32toFloat16 && m_Device->m_Options.GetFp16Enabled(), 224 priority); 225 } 226 prepareModelFromCache(OptionalTimePoint deadline,const std::vector<SharedHandle> & modelCache,const std::vector<SharedHandle> & dataCache,const CacheToken & token) const227 GeneralResult<SharedPreparedModel> prepareModelFromCache(OptionalTimePoint deadline, 228 const std::vector<SharedHandle>& modelCache, 229 const std::vector<SharedHandle>& dataCache, 230 const CacheToken& token) const override 231 { 232 VLOG(DRIVER) << "ArmnnDriver::prepareModelFromCache()"; 233 if (m_Device.get() == nullptr) 234 { 235 return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!"; 236 } 237 // Check if deadline has passed. 238 if (hasDeadlinePassed(deadline)) { 239 return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT); 240 } 241 242 return ArmnnDriverImpl::PrepareArmnnModelFromCache( 243 m_Device->m_Runtime, 244 m_Device->m_ClTunedParameters, 245 m_Device->m_Options, 246 modelCache, 247 dataCache, 248 token, 249 m_Device->m_Options.GetFp16Enabled()); 250 } 251 allocate(const BufferDesc &,const std::vector<SharedPreparedModel> &,const std::vector<BufferRole> &,const std::vector<BufferRole> &) const252 GeneralResult<SharedBuffer> allocate(const BufferDesc&, 253 const std::vector<SharedPreparedModel>&, 254 const std::vector<BufferRole>&, 255 const std::vector<BufferRole>&) const override 256 { 257 VLOG(DRIVER) << "ArmnnDriver::allocate()"; 258 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "ArmnnDriver::allocate -- does not support allocate."; 259 } 260 }; 261 262 } // namespace armnn_driver 263