1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/TensorHandleFactoryRegistry.hpp> 7*89c4ff92SAndroid Build Coastguard Worker #include <neon/NeonBackend.hpp> 8*89c4ff92SAndroid Build Coastguard Worker #include <neon/NeonTensorHandleFactory.hpp> 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 13*89c4ff92SAndroid Build Coastguard Worker 14*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("NeonBackendTests") 15*89c4ff92SAndroid Build Coastguard Worker { 16*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("NeonRegisterTensorHandleFactoriesMatchingImportFactoryId") 17*89c4ff92SAndroid Build Coastguard Worker { 18*89c4ff92SAndroid Build Coastguard Worker auto neonBackend = std::make_unique<NeonBackend>(); 19*89c4ff92SAndroid Build Coastguard Worker TensorHandleFactoryRegistry registry; 20*89c4ff92SAndroid Build Coastguard Worker neonBackend->RegisterTensorHandleFactories(registry); 21*89c4ff92SAndroid Build Coastguard Worker 22*89c4ff92SAndroid Build Coastguard Worker // When calling RegisterTensorHandleFactories, CopyAndImportFactoryPair is registered 23*89c4ff92SAndroid Build Coastguard Worker // Get matching import factory id correctly 24*89c4ff92SAndroid Build Coastguard Worker CHECK((registry.GetMatchingImportFactoryId(NeonTensorHandleFactory::GetIdStatic()) == 25*89c4ff92SAndroid Build Coastguard Worker NeonTensorHandleFactory::GetIdStatic())); 26*89c4ff92SAndroid Build Coastguard Worker } 27*89c4ff92SAndroid Build Coastguard Worker 28*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("NeonCreateWorkloadFactoryMatchingImportFactoryId") 29*89c4ff92SAndroid Build Coastguard Worker { 30*89c4ff92SAndroid Build Coastguard Worker auto neonBackend = std::make_unique<NeonBackend>(); 31*89c4ff92SAndroid Build Coastguard Worker TensorHandleFactoryRegistry registry; 32*89c4ff92SAndroid Build Coastguard Worker neonBackend->CreateWorkloadFactory(registry); 33*89c4ff92SAndroid Build Coastguard Worker 34*89c4ff92SAndroid Build Coastguard Worker // When calling CreateWorkloadFactory, CopyAndImportFactoryPair is registered 35*89c4ff92SAndroid Build Coastguard Worker // Get matching import factory id correctly 36*89c4ff92SAndroid Build Coastguard Worker CHECK((registry.GetMatchingImportFactoryId(NeonTensorHandleFactory::GetIdStatic()) == 37*89c4ff92SAndroid Build Coastguard Worker NeonTensorHandleFactory::GetIdStatic())); 38*89c4ff92SAndroid Build Coastguard Worker } 39*89c4ff92SAndroid Build Coastguard Worker 40*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("NeonCreateWorkloadFactoryWithOptionsMatchingImportFactoryId") 41*89c4ff92SAndroid Build Coastguard Worker { 42*89c4ff92SAndroid Build Coastguard Worker auto neonBackend = std::make_unique<NeonBackend>(); 43*89c4ff92SAndroid Build Coastguard Worker TensorHandleFactoryRegistry registry; 44*89c4ff92SAndroid Build Coastguard Worker ModelOptions modelOptions; 45*89c4ff92SAndroid Build Coastguard Worker neonBackend->CreateWorkloadFactory(registry, modelOptions); 46*89c4ff92SAndroid Build Coastguard Worker 47*89c4ff92SAndroid Build Coastguard Worker // When calling CreateWorkloadFactory with ModelOptions, CopyAndImportFactoryPair is registered 48*89c4ff92SAndroid Build Coastguard Worker // Get matching import factory id correctly 49*89c4ff92SAndroid Build Coastguard Worker CHECK((registry.GetMatchingImportFactoryId(NeonTensorHandleFactory::GetIdStatic()) == 50*89c4ff92SAndroid Build Coastguard Worker NeonTensorHandleFactory::GetIdStatic())); 51*89c4ff92SAndroid Build Coastguard Worker } 52*89c4ff92SAndroid Build Coastguard Worker } 53