1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include <backendsCommon/TensorHandleFactoryRegistry.hpp> 7 #include <cl/ClBackend.hpp> 8 #include <cl/ClTensorHandleFactory.hpp> 9 #include <cl/ClImportTensorHandleFactory.hpp> 10 #include <cl/test/ClContextControlFixture.hpp> 11 12 #include <doctest/doctest.h> 13 14 using namespace armnn; 15 16 TEST_SUITE("ClBackendTests") 17 { 18 TEST_CASE("ClRegisterTensorHandleFactoriesMatchingImportFactoryId") 19 { 20 auto clBackend = std::make_unique<ClBackend>(); 21 TensorHandleFactoryRegistry registry; 22 clBackend->RegisterTensorHandleFactories(registry); 23 24 // When calling RegisterTensorHandleFactories, CopyAndImportFactoryPair is registered 25 // Get ClImportTensorHandleFactory id as the matching import factory id 26 CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) == 27 ClImportTensorHandleFactory::GetIdStatic())); 28 } 29 30 TEST_CASE("ClRegisterTensorHandleFactoriesWithMemorySourceFlagsMatchingImportFactoryId") 31 { 32 auto clBackend = std::make_unique<ClBackend>(); 33 TensorHandleFactoryRegistry registry; 34 clBackend->RegisterTensorHandleFactories(registry, 35 static_cast<MemorySourceFlags>(MemorySource::Malloc), 36 static_cast<MemorySourceFlags>(MemorySource::Malloc)); 37 38 // When calling RegisterTensorHandleFactories with MemorySourceFlags, CopyAndImportFactoryPair is registered 39 // Get ClImportTensorHandleFactory id as the matching import factory id 40 CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) == 41 ClImportTensorHandleFactory::GetIdStatic())); 42 } 43 44 TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryMatchingImportFactoryId") 45 { 46 auto clBackend = std::make_unique<ClBackend>(); 47 TensorHandleFactoryRegistry registry; 48 clBackend->CreateWorkloadFactory(registry); 49 50 // When calling CreateWorkloadFactory, CopyAndImportFactoryPair is registered 51 // Get ClImportTensorHandleFactory id as the matching import factory id 52 CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) == 53 ClImportTensorHandleFactory::GetIdStatic())); 54 } 55 56 TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryWithOptionsMatchingImportFactoryId") 57 { 58 auto clBackend = std::make_unique<ClBackend>(); 59 TensorHandleFactoryRegistry registry; 60 ModelOptions modelOptions; 61 clBackend->CreateWorkloadFactory(registry, modelOptions); 62 63 // When calling CreateWorkloadFactory with ModelOptions, CopyAndImportFactoryPair is registered 64 // Get ClImportTensorHandleFactory id as the matching import factory id 65 CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) == 66 ClImportTensorHandleFactory::GetIdStatic())); 67 } 68 69 TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryWitMemoryFlagsMatchingImportFactoryId") 70 { 71 auto clBackend = std::make_unique<ClBackend>(); 72 TensorHandleFactoryRegistry registry; 73 ModelOptions modelOptions; 74 clBackend->CreateWorkloadFactory(registry, modelOptions, 75 static_cast<MemorySourceFlags>(MemorySource::Malloc), 76 static_cast<MemorySourceFlags>(MemorySource::Malloc)); 77 78 // When calling CreateWorkloadFactory with ModelOptions and MemorySourceFlags, 79 // CopyAndImportFactoryPair is registered 80 // Get ClImportTensorHandleFactory id as the matching import factory id 81 CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) == 82 ClImportTensorHandleFactory::GetIdStatic())); 83 } 84 } 85