1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #pragma once 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <android-base/logging.h> 9*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/IBuffer.h> 10*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/IDevice.h> 11*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/IPreparedModel.h> 12*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/OperandTypes.h> 13*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/Result.h> 14*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/Types.h> 15*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/Validation.h> 16*89c4ff92SAndroid Build Coastguard Worker 17*89c4ff92SAndroid Build Coastguard Worker #include "ArmnnDevice.hpp" 18*89c4ff92SAndroid Build Coastguard Worker #include "ArmnnDriverImpl.hpp" 19*89c4ff92SAndroid Build Coastguard Worker #include "Converter.hpp" 20*89c4ff92SAndroid Build Coastguard Worker 21*89c4ff92SAndroid Build Coastguard Worker #include "ArmnnDriverImpl.hpp" 22*89c4ff92SAndroid Build Coastguard Worker #include "ModelToINetworkTransformer.hpp" 23*89c4ff92SAndroid Build Coastguard Worker 24*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Version.hpp> 25*89c4ff92SAndroid Build Coastguard Worker #include <log/log.h> 26*89c4ff92SAndroid Build Coastguard Worker namespace armnn_driver 27*89c4ff92SAndroid Build Coastguard Worker { 28*89c4ff92SAndroid Build Coastguard Worker 29*89c4ff92SAndroid Build Coastguard Worker //using namespace android::nn; 30*89c4ff92SAndroid Build Coastguard Worker 31*89c4ff92SAndroid Build Coastguard Worker class ArmnnDriver : public IDevice 32*89c4ff92SAndroid Build Coastguard Worker { 33*89c4ff92SAndroid Build Coastguard Worker private: 34*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ArmnnDevice> m_Device; 35*89c4ff92SAndroid Build Coastguard Worker public: ArmnnDriver(DriverOptions options)36*89c4ff92SAndroid Build Coastguard Worker ArmnnDriver(DriverOptions options) 37*89c4ff92SAndroid Build Coastguard Worker { 38*89c4ff92SAndroid Build Coastguard Worker try 39*89c4ff92SAndroid Build Coastguard Worker { 40*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriver::ArmnnDriver()"; 41*89c4ff92SAndroid Build Coastguard Worker m_Device = std::unique_ptr<ArmnnDevice>(new ArmnnDevice(std::move(options))); 42*89c4ff92SAndroid Build Coastguard Worker } 43*89c4ff92SAndroid Build Coastguard Worker catch (armnn::InvalidArgumentException& ex) 44*89c4ff92SAndroid Build Coastguard Worker { 45*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDevice failed to initialise: " << ex.what(); 46*89c4ff92SAndroid Build Coastguard Worker } 47*89c4ff92SAndroid Build Coastguard Worker catch (...) 48*89c4ff92SAndroid Build Coastguard Worker { 49*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDevice failed to initialise with an unknown error"; 50*89c4ff92SAndroid Build Coastguard Worker } 51*89c4ff92SAndroid Build Coastguard Worker } 52*89c4ff92SAndroid Build Coastguard Worker 53*89c4ff92SAndroid Build Coastguard Worker public: 54*89c4ff92SAndroid Build Coastguard Worker getName() const55*89c4ff92SAndroid Build Coastguard Worker const std::string& getName() const override 56*89c4ff92SAndroid Build Coastguard Worker { 57*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriver::getName()"; 58*89c4ff92SAndroid Build Coastguard Worker static const std::string name = "arm-armnn-sl"; 59*89c4ff92SAndroid Build Coastguard Worker return name; 60*89c4ff92SAndroid Build Coastguard Worker } 61*89c4ff92SAndroid Build Coastguard Worker getVersionString() const62*89c4ff92SAndroid Build Coastguard Worker const std::string& getVersionString() const override 63*89c4ff92SAndroid Build Coastguard Worker { 64*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriver::getVersionString()"; 65*89c4ff92SAndroid Build Coastguard Worker static const std::string versionString = ARMNN_VERSION; 66*89c4ff92SAndroid Build Coastguard Worker return versionString; 67*89c4ff92SAndroid Build Coastguard Worker } 68*89c4ff92SAndroid Build Coastguard Worker getFeatureLevel() const69*89c4ff92SAndroid Build Coastguard Worker Version getFeatureLevel() const override 70*89c4ff92SAndroid Build Coastguard Worker { 71*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriver::getFeatureLevel()"; 72*89c4ff92SAndroid Build Coastguard Worker return kVersionFeatureLevel6; 73*89c4ff92SAndroid Build Coastguard Worker } 74*89c4ff92SAndroid Build Coastguard Worker getType() const75*89c4ff92SAndroid Build Coastguard Worker DeviceType getType() const override 76*89c4ff92SAndroid Build Coastguard Worker { 77*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriver::getType()"; 78*89c4ff92SAndroid Build Coastguard Worker return DeviceType::CPU; 79*89c4ff92SAndroid Build Coastguard Worker } 80*89c4ff92SAndroid Build Coastguard Worker getSupportedExtensions() const81*89c4ff92SAndroid Build Coastguard Worker const std::vector<Extension>& getSupportedExtensions() const override 82*89c4ff92SAndroid Build Coastguard Worker { 83*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriver::getSupportedExtensions()"; 84*89c4ff92SAndroid Build Coastguard Worker static const std::vector<Extension> extensions = {}; 85*89c4ff92SAndroid Build Coastguard Worker return extensions; 86*89c4ff92SAndroid Build Coastguard Worker } 87*89c4ff92SAndroid Build Coastguard Worker getCapabilities() const88*89c4ff92SAndroid Build Coastguard Worker const Capabilities& getCapabilities() const override 89*89c4ff92SAndroid Build Coastguard Worker { 90*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriver::GetCapabilities()"; 91*89c4ff92SAndroid Build Coastguard Worker return ArmnnDriverImpl::GetCapabilities(m_Device->m_Runtime); 92*89c4ff92SAndroid Build Coastguard Worker } 93*89c4ff92SAndroid Build Coastguard Worker getNumberOfCacheFilesNeeded() const94*89c4ff92SAndroid Build Coastguard Worker std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const override 95*89c4ff92SAndroid Build Coastguard Worker { 96*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded()"; 97*89c4ff92SAndroid Build Coastguard Worker unsigned int numberOfCachedModelFiles = 0; 98*89c4ff92SAndroid Build Coastguard Worker for (auto& backend : m_Device->m_Options.GetBackends()) 99*89c4ff92SAndroid Build Coastguard Worker { 100*89c4ff92SAndroid Build Coastguard Worker numberOfCachedModelFiles += GetNumberOfCacheFiles(backend); 101*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded() = " 102*89c4ff92SAndroid Build Coastguard Worker << std::to_string(numberOfCachedModelFiles); 103*89c4ff92SAndroid Build Coastguard Worker } 104*89c4ff92SAndroid Build Coastguard Worker return std::make_pair(numberOfCachedModelFiles, 1ul); 105*89c4ff92SAndroid Build Coastguard Worker } 106*89c4ff92SAndroid Build Coastguard Worker wait() const107*89c4ff92SAndroid Build Coastguard Worker GeneralResult<void> wait() const override 108*89c4ff92SAndroid Build Coastguard Worker { 109*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriver::wait()"; 110*89c4ff92SAndroid Build Coastguard Worker return {}; 111*89c4ff92SAndroid Build Coastguard Worker } 112*89c4ff92SAndroid Build Coastguard Worker getSupportedOperations(const Model & model) const113*89c4ff92SAndroid Build Coastguard Worker GeneralResult<std::vector<bool>> getSupportedOperations(const Model& model) const override 114*89c4ff92SAndroid Build Coastguard Worker { 115*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriver::getSupportedOperations()"; 116*89c4ff92SAndroid Build Coastguard Worker if (m_Device.get() == nullptr) 117*89c4ff92SAndroid Build Coastguard Worker { 118*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!"; 119*89c4ff92SAndroid Build Coastguard Worker } 120*89c4ff92SAndroid Build Coastguard Worker 121*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss; 122*89c4ff92SAndroid Build Coastguard Worker ss << "ArmnnDriverImpl::getSupportedOperations()"; 123*89c4ff92SAndroid Build Coastguard Worker std::string fileName; 124*89c4ff92SAndroid Build Coastguard Worker std::string timestamp; 125*89c4ff92SAndroid Build Coastguard Worker if (!m_Device->m_Options.GetRequestInputsAndOutputsDumpDir().empty()) 126*89c4ff92SAndroid Build Coastguard Worker { 127*89c4ff92SAndroid Build Coastguard Worker ss << " : " 128*89c4ff92SAndroid Build Coastguard Worker << m_Device->m_Options.GetRequestInputsAndOutputsDumpDir() 129*89c4ff92SAndroid Build Coastguard Worker << "/" 130*89c4ff92SAndroid Build Coastguard Worker // << GetFileTimestamp() 131*89c4ff92SAndroid Build Coastguard Worker << "_getSupportedOperations.txt"; 132*89c4ff92SAndroid Build Coastguard Worker } 133*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << ss.str().c_str(); 134*89c4ff92SAndroid Build Coastguard Worker 135*89c4ff92SAndroid Build Coastguard Worker if (!m_Device->m_Options.GetRequestInputsAndOutputsDumpDir().empty()) 136*89c4ff92SAndroid Build Coastguard Worker { 137*89c4ff92SAndroid Build Coastguard Worker //dump the marker file 138*89c4ff92SAndroid Build Coastguard Worker std::ofstream fileStream; 139*89c4ff92SAndroid Build Coastguard Worker fileStream.open(fileName, std::ofstream::out | std::ofstream::trunc); 140*89c4ff92SAndroid Build Coastguard Worker if (fileStream.good()) 141*89c4ff92SAndroid Build Coastguard Worker { 142*89c4ff92SAndroid Build Coastguard Worker fileStream << timestamp << std::endl; 143*89c4ff92SAndroid Build Coastguard Worker fileStream << timestamp << std::endl; 144*89c4ff92SAndroid Build Coastguard Worker } 145*89c4ff92SAndroid Build Coastguard Worker fileStream.close(); 146*89c4ff92SAndroid Build Coastguard Worker } 147*89c4ff92SAndroid Build Coastguard Worker 148*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> result; 149*89c4ff92SAndroid Build Coastguard Worker if (!m_Device->m_Runtime) 150*89c4ff92SAndroid Build Coastguard Worker { 151*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!"; 152*89c4ff92SAndroid Build Coastguard Worker } 153*89c4ff92SAndroid Build Coastguard Worker 154*89c4ff92SAndroid Build Coastguard Worker // Run general model validation, if this doesn't pass we shouldn't analyse the model anyway. 155*89c4ff92SAndroid Build Coastguard Worker if (const auto result = validate(model); !result.ok()) 156*89c4ff92SAndroid Build Coastguard Worker { 157*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model!"; 158*89c4ff92SAndroid Build Coastguard Worker } 159*89c4ff92SAndroid Build Coastguard Worker 160*89c4ff92SAndroid Build Coastguard Worker // Attempt to convert the model to an ArmNN input network (INetwork). 161*89c4ff92SAndroid Build Coastguard Worker ModelToINetworkTransformer modelConverter(m_Device->m_Options.GetBackends(), 162*89c4ff92SAndroid Build Coastguard Worker model, 163*89c4ff92SAndroid Build Coastguard Worker m_Device->m_Options.GetForcedUnsupportedOperations()); 164*89c4ff92SAndroid Build Coastguard Worker 165*89c4ff92SAndroid Build Coastguard Worker if (modelConverter.GetConversionResult() != ConversionResult::Success 166*89c4ff92SAndroid Build Coastguard Worker && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature) 167*89c4ff92SAndroid Build Coastguard Worker { 168*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "Conversion Error!"; 169*89c4ff92SAndroid Build Coastguard Worker } 170*89c4ff92SAndroid Build Coastguard Worker 171*89c4ff92SAndroid Build Coastguard Worker // Check each operation if it was converted successfully and copy the flags 172*89c4ff92SAndroid Build Coastguard Worker // into the result (vector<bool>) that we need to return to Android. 173*89c4ff92SAndroid Build Coastguard Worker result.reserve(model.main.operations.size()); 174*89c4ff92SAndroid Build Coastguard Worker for (uint32_t operationIdx = 0; operationIdx < model.main.operations.size(); ++operationIdx) 175*89c4ff92SAndroid Build Coastguard Worker { 176*89c4ff92SAndroid Build Coastguard Worker bool operationSupported = modelConverter.IsOperationSupported(operationIdx); 177*89c4ff92SAndroid Build Coastguard Worker result.push_back(operationSupported); 178*89c4ff92SAndroid Build Coastguard Worker } 179*89c4ff92SAndroid Build Coastguard Worker 180*89c4ff92SAndroid Build Coastguard Worker return result; 181*89c4ff92SAndroid Build Coastguard Worker } 182*89c4ff92SAndroid Build Coastguard Worker prepareModel(const Model & model,ExecutionPreference preference,Priority priority,OptionalTimePoint deadline,const std::vector<SharedHandle> & modelCache,const std::vector<SharedHandle> & dataCache,const CacheToken & token,const std::vector<android::nn::TokenValuePair> & hints,const std::vector<android::nn::ExtensionNameAndPrefix> & extensionNameToPrefix) const183*89c4ff92SAndroid Build Coastguard Worker GeneralResult<SharedPreparedModel> prepareModel(const Model& model, 184*89c4ff92SAndroid Build Coastguard Worker ExecutionPreference preference, 185*89c4ff92SAndroid Build Coastguard Worker Priority priority, 186*89c4ff92SAndroid Build Coastguard Worker OptionalTimePoint deadline, 187*89c4ff92SAndroid Build Coastguard Worker const std::vector<SharedHandle>& modelCache, 188*89c4ff92SAndroid Build Coastguard Worker const std::vector<SharedHandle>& dataCache, 189*89c4ff92SAndroid Build Coastguard Worker const CacheToken& token, 190*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::TokenValuePair>& hints, 191*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override 192*89c4ff92SAndroid Build Coastguard Worker { 193*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriver::prepareModel()"; 194*89c4ff92SAndroid Build Coastguard Worker 195*89c4ff92SAndroid Build Coastguard Worker if (m_Device.get() == nullptr) 196*89c4ff92SAndroid Build Coastguard Worker { 197*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!"; 198*89c4ff92SAndroid Build Coastguard Worker } 199*89c4ff92SAndroid Build Coastguard Worker // Validate arguments. 200*89c4ff92SAndroid Build Coastguard Worker if (const auto result = validate(model); !result.ok()) { 201*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model: " << result.error(); 202*89c4ff92SAndroid Build Coastguard Worker } 203*89c4ff92SAndroid Build Coastguard Worker if (const auto result = validate(preference); !result.ok()) { 204*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) 205*89c4ff92SAndroid Build Coastguard Worker << "Invalid ExecutionPreference: " << result.error(); 206*89c4ff92SAndroid Build Coastguard Worker } 207*89c4ff92SAndroid Build Coastguard Worker if (const auto result = validate(priority); !result.ok()) { 208*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Priority: " << result.error(); 209*89c4ff92SAndroid Build Coastguard Worker } 210*89c4ff92SAndroid Build Coastguard Worker 211*89c4ff92SAndroid Build Coastguard Worker // Check if deadline has passed. 212*89c4ff92SAndroid Build Coastguard Worker if (hasDeadlinePassed(deadline)) { 213*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT); 214*89c4ff92SAndroid Build Coastguard Worker } 215*89c4ff92SAndroid Build Coastguard Worker 216*89c4ff92SAndroid Build Coastguard Worker return ArmnnDriverImpl::PrepareArmnnModel(m_Device->m_Runtime, 217*89c4ff92SAndroid Build Coastguard Worker m_Device->m_ClTunedParameters, 218*89c4ff92SAndroid Build Coastguard Worker m_Device->m_Options, 219*89c4ff92SAndroid Build Coastguard Worker model, 220*89c4ff92SAndroid Build Coastguard Worker modelCache, 221*89c4ff92SAndroid Build Coastguard Worker dataCache, 222*89c4ff92SAndroid Build Coastguard Worker token, 223*89c4ff92SAndroid Build Coastguard Worker model.relaxComputationFloat32toFloat16 && m_Device->m_Options.GetFp16Enabled(), 224*89c4ff92SAndroid Build Coastguard Worker priority); 225*89c4ff92SAndroid Build Coastguard Worker } 226*89c4ff92SAndroid Build Coastguard Worker prepareModelFromCache(OptionalTimePoint deadline,const std::vector<SharedHandle> & modelCache,const std::vector<SharedHandle> & dataCache,const CacheToken & token) const227*89c4ff92SAndroid Build Coastguard Worker GeneralResult<SharedPreparedModel> prepareModelFromCache(OptionalTimePoint deadline, 228*89c4ff92SAndroid Build Coastguard Worker const std::vector<SharedHandle>& modelCache, 229*89c4ff92SAndroid Build Coastguard Worker const std::vector<SharedHandle>& dataCache, 230*89c4ff92SAndroid Build Coastguard Worker const CacheToken& token) const override 231*89c4ff92SAndroid Build Coastguard Worker { 232*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriver::prepareModelFromCache()"; 233*89c4ff92SAndroid Build Coastguard Worker if (m_Device.get() == nullptr) 234*89c4ff92SAndroid Build Coastguard Worker { 235*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!"; 236*89c4ff92SAndroid Build Coastguard Worker } 237*89c4ff92SAndroid Build Coastguard Worker // Check if deadline has passed. 238*89c4ff92SAndroid Build Coastguard Worker if (hasDeadlinePassed(deadline)) { 239*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT); 240*89c4ff92SAndroid Build Coastguard Worker } 241*89c4ff92SAndroid Build Coastguard Worker 242*89c4ff92SAndroid Build Coastguard Worker return ArmnnDriverImpl::PrepareArmnnModelFromCache( 243*89c4ff92SAndroid Build Coastguard Worker m_Device->m_Runtime, 244*89c4ff92SAndroid Build Coastguard Worker m_Device->m_ClTunedParameters, 245*89c4ff92SAndroid Build Coastguard Worker m_Device->m_Options, 246*89c4ff92SAndroid Build Coastguard Worker modelCache, 247*89c4ff92SAndroid Build Coastguard Worker dataCache, 248*89c4ff92SAndroid Build Coastguard Worker token, 249*89c4ff92SAndroid Build Coastguard Worker m_Device->m_Options.GetFp16Enabled()); 250*89c4ff92SAndroid Build Coastguard Worker } 251*89c4ff92SAndroid Build Coastguard Worker allocate(const BufferDesc &,const std::vector<SharedPreparedModel> &,const std::vector<BufferRole> &,const std::vector<BufferRole> &) const252*89c4ff92SAndroid Build Coastguard Worker GeneralResult<SharedBuffer> allocate(const BufferDesc&, 253*89c4ff92SAndroid Build Coastguard Worker const std::vector<SharedPreparedModel>&, 254*89c4ff92SAndroid Build Coastguard Worker const std::vector<BufferRole>&, 255*89c4ff92SAndroid Build Coastguard Worker const std::vector<BufferRole>&) const override 256*89c4ff92SAndroid Build Coastguard Worker { 257*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriver::allocate()"; 258*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "ArmnnDriver::allocate -- does not support allocate."; 259*89c4ff92SAndroid Build Coastguard Worker } 260*89c4ff92SAndroid Build Coastguard Worker }; 261*89c4ff92SAndroid Build Coastguard Worker 262*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn_driver 263