1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include <armnn/Types.hpp> 7 #include <armnn/BackendRegistry.hpp> 8 9 #include <armnn/backends/IBackendInternal.hpp> 10 #include <backendsCommon/memoryOptimizerStrategyLibrary/strategies/ConstantMemoryStrategy.hpp> 11 #include <reference/RefBackend.hpp> 12 13 #include <doctest/doctest.h> 14 15 namespace 16 { 17 18 class SwapRegistryStorage : public armnn::BackendRegistry 19 { 20 public: SwapRegistryStorage()21 SwapRegistryStorage() : armnn::BackendRegistry() 22 { 23 Swap(armnn::BackendRegistryInstance(), m_TempStorage); 24 } 25 ~SwapRegistryStorage()26 ~SwapRegistryStorage() 27 { 28 Swap(armnn::BackendRegistryInstance(),m_TempStorage); 29 } 30 31 private: 32 FactoryStorage m_TempStorage; 33 }; 34 35 } 36 37 TEST_SUITE("BackendRegistryTests") 38 { 39 TEST_CASE("SwapRegistry") 40 { 41 using namespace armnn; 42 auto nFactories = BackendRegistryInstance().Size(); 43 { 44 SwapRegistryStorage helper; 45 CHECK(BackendRegistryInstance().Size() == 0); 46 } 47 CHECK(BackendRegistryInstance().Size() == nFactories); 48 } 49 50 TEST_CASE("TestRegistryHelper") 51 { 52 using namespace armnn; 53 SwapRegistryStorage helper; 54 55 bool called = false; 56 57 BackendRegistry::StaticRegistryInitializer factoryHelper( 58 BackendRegistryInstance(), 59 "HelloWorld", 60 [&called]() __anon0e237d1a0202() 61 { 62 called = true; 63 return armnn::IBackendInternalUniquePtr(nullptr); 64 } 65 ); 66 67 // sanity check: the factory has not been called yet 68 CHECK(called == false); 69 70 auto factoryFunction = BackendRegistryInstance().GetFactory("HelloWorld"); 71 72 // sanity check: the factory still not called 73 CHECK(called == false); 74 75 factoryFunction(); 76 CHECK(called == true); 77 BackendRegistryInstance().Deregister("HelloWorld"); 78 } 79 80 TEST_CASE("TestDirectCallToRegistry") 81 { 82 using namespace armnn; 83 SwapRegistryStorage helper; 84 85 bool called = false; 86 BackendRegistryInstance().Register( 87 "HelloWorld", 88 [&called]() __anon0e237d1a0302() 89 { 90 called = true; 91 return armnn::IBackendInternalUniquePtr(nullptr); 92 } 93 ); 94 95 // sanity check: the factory has not been called yet 96 CHECK(called == false); 97 98 auto factoryFunction = BackendRegistryInstance().GetFactory("HelloWorld"); 99 100 // sanity check: the factory still not called 101 CHECK(called == false); 102 103 factoryFunction(); 104 CHECK(called == true); 105 BackendRegistryInstance().Deregister("HelloWorld"); 106 } 107 108 // Test that backends can throw exceptions during their factory function to prevent loading in an unsuitable 109 // environment. For example Neon Backend loading on armhf device without neon support. 110 // In reality the dynamic backend is loaded in during the LoadDynamicBackends(options.m_DynamicBackendsPath) 111 // step of runtime constructor, then the factory function is called to check if supported, in case 112 // of Neon not being detected the exception is raised and so the backend is not added to the supportedBackends 113 // list 114 115 TEST_CASE("ThrowBackendUnavailableException") 116 { 117 using namespace armnn; 118 119 const BackendId mockBackendId("MockDynamicBackend"); 120 121 const std::string exceptionMessage("Mock error message to test unavailable backend"); 122 123 // Register the mock backend with a factory function lambda that always throws 124 BackendRegistryInstance().Register(mockBackendId, 125 [exceptionMessage]() __anon0e237d1a0402() 126 { 127 throw armnn::BackendUnavailableException(exceptionMessage); 128 return IBackendInternalUniquePtr(); // Satisfy return type 129 }); 130 131 // Get the factory function of the mock backend 132 auto factoryFunc = BackendRegistryInstance().GetFactory(mockBackendId); 133 134 try 135 { 136 // Call the factory function as done during runtime backend registering 137 auto backend = factoryFunc(); 138 FAIL("Expected exception to have been thrown"); 139 } 140 catch (const BackendUnavailableException& e) 141 { 142 // Caught 143 CHECK_EQ(e.what(), exceptionMessage); 144 } 145 // Clean up the registry for the next test. 146 BackendRegistryInstance().Deregister(mockBackendId); 147 } 148 149 #if defined(ARMNNREF_ENABLED) 150 TEST_CASE("RegisterMemoryOptimizerStrategy") 151 { 152 using namespace armnn; 153 154 const BackendId cpuRefBackendId(armnn::Compute::CpuRef); 155 CHECK(BackendRegistryInstance().GetMemoryOptimizerStrategies().empty()); 156 157 // Register the memory optimizer 158 std::shared_ptr<IMemoryOptimizerStrategy> memoryOptimizerStrategy = 159 std::make_shared<ConstantMemoryStrategy>(); 160 BackendRegistryInstance().RegisterMemoryOptimizerStrategy(cpuRefBackendId, memoryOptimizerStrategy); 161 CHECK(!BackendRegistryInstance().GetMemoryOptimizerStrategies().empty()); 162 CHECK(BackendRegistryInstance().GetMemoryOptimizerStrategies().size() == 1); 163 // De-register the memory optimizer 164 BackendRegistryInstance().DeregisterMemoryOptimizerStrategy(cpuRefBackendId); 165 CHECK(BackendRegistryInstance().GetMemoryOptimizerStrategies().empty()); 166 } 167 #endif 168 169 } 170