1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp> 7*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/test/mockBackend/MockImportBackend.hpp> 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker #include <GraphUtils.hpp> 10*89c4ff92SAndroid Build Coastguard Worker 11*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 12*89c4ff92SAndroid Build Coastguard Worker 13*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("NeonFallback") 14*89c4ff92SAndroid Build Coastguard Worker { 15*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FallbackImportToCpuAcc") 16*89c4ff92SAndroid Build Coastguard Worker { 17*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 18*89c4ff92SAndroid Build Coastguard Worker 19*89c4ff92SAndroid Build Coastguard Worker // Create a mock backend objectN 20*89c4ff92SAndroid Build Coastguard Worker MockImportBackendInitialiser initialiser; // Register the Mock Backend 21*89c4ff92SAndroid Build Coastguard Worker auto backendObjPtr = CreateBackendObject(MockImportBackendId()); 22*89c4ff92SAndroid Build Coastguard Worker CHECK((backendObjPtr != nullptr)); 23*89c4ff92SAndroid Build Coastguard Worker 24*89c4ff92SAndroid Build Coastguard Worker BackendIdSet backendIds = BackendRegistryInstance().GetBackendIds(); 25*89c4ff92SAndroid Build Coastguard Worker if (backendIds.find("MockRef") == backendIds.end()) 26*89c4ff92SAndroid Build Coastguard Worker { 27*89c4ff92SAndroid Build Coastguard Worker std::string message = "Cannot load MockRef"; 28*89c4ff92SAndroid Build Coastguard Worker FAIL(message); 29*89c4ff92SAndroid Build Coastguard Worker } 30*89c4ff92SAndroid Build Coastguard Worker 31*89c4ff92SAndroid Build Coastguard Worker // Create runtime in which test will run and allow fallback to CpuRef. 32*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options; 33*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options)); 34*89c4ff92SAndroid Build Coastguard Worker 35*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network. 36*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create()); 37*89c4ff92SAndroid Build Coastguard Worker 38*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input0 = net->AddInputLayer(0, "input0"); 39*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input1 = net->AddInputLayer(1, "input1"); 40*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input2 = net->AddInputLayer(2, "input2"); 41*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add"); 42*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub"); 43*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output"); 44*89c4ff92SAndroid Build Coastguard Worker 45*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).Connect(add->GetInputSlot(0)); 46*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).Connect(add->GetInputSlot(1)); 47*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0)); 48*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).Connect(sub->GetInputSlot(1)); 49*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).Connect(output->GetInputSlot(0)); 50*89c4ff92SAndroid Build Coastguard Worker 51*89c4ff92SAndroid Build Coastguard Worker TensorInfo info = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32); 52*89c4ff92SAndroid Build Coastguard Worker 53*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).SetTensorInfo(info); 54*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).SetTensorInfo(info); 55*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).SetTensorInfo(info); 56*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).SetTensorInfo(info); 57*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).SetTensorInfo(info); 58*89c4ff92SAndroid Build Coastguard Worker 59*89c4ff92SAndroid Build Coastguard Worker // optimize the network 60*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> backends = { "MockRef", Compute::CpuAcc }; 61*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque optOptions; 62*89c4ff92SAndroid Build Coastguard Worker optOptions.SetImportEnabled(true); 63*89c4ff92SAndroid Build Coastguard Worker optOptions.SetExportEnabled(true); 64*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions); 65*89c4ff92SAndroid Build Coastguard Worker 66*89c4ff92SAndroid Build Coastguard Worker Graph& graph = GetGraphForTesting(optNet.get()); 67*89c4ff92SAndroid Build Coastguard Worker 68*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0"); 69*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1"); 70*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2"); 71*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add"); 72*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]"); 73*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub"); 74*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "output"); 75*89c4ff92SAndroid Build Coastguard Worker 76*89c4ff92SAndroid Build Coastguard Worker // Checks order is valid. 77*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer0, layer1)); 78*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer1, layer2)); 79*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer2, layer3)); 80*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer3, layer4)); 81*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer4, layer5)); 82*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer5, layer6)); 83*89c4ff92SAndroid Build Coastguard Worker 84*89c4ff92SAndroid Build Coastguard Worker // Load it into the runtime. It should pass. 85*89c4ff92SAndroid Build Coastguard Worker NetworkId netId; 86*89c4ff92SAndroid Build Coastguard Worker std::string ignoredErrorMessage; 87*89c4ff92SAndroid Build Coastguard Worker INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); 88*89c4ff92SAndroid Build Coastguard Worker runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); 89*89c4ff92SAndroid Build Coastguard Worker 90*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output 91*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData0 92*89c4ff92SAndroid Build Coastguard Worker { 93*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f 94*89c4ff92SAndroid Build Coastguard Worker }; 95*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData1 96*89c4ff92SAndroid Build Coastguard Worker { 97*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f 98*89c4ff92SAndroid Build Coastguard Worker }; 99*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData2 100*89c4ff92SAndroid Build Coastguard Worker { 101*89c4ff92SAndroid Build Coastguard Worker 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f 102*89c4ff92SAndroid Build Coastguard Worker }; 103*89c4ff92SAndroid Build Coastguard Worker 104*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData(12); 105*89c4ff92SAndroid Build Coastguard Worker 106*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutput 107*89c4ff92SAndroid Build Coastguard Worker { 108*89c4ff92SAndroid Build Coastguard Worker 11.0f, 9.0f, 7.0f, 5.0f, 3.0f, 1.0f, -1.0f, -3.0f, -5.0f, -7.0f, -9.0f, -11.0f 109*89c4ff92SAndroid Build Coastguard Worker }; 110*89c4ff92SAndroid Build Coastguard Worker 111*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0); 112*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1); 113*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo2 = runtime->GetInputTensorInfo(netId, 2); 114*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo0.SetConstant(true); 115*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1.SetConstant(true); 116*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo2.SetConstant(true); 117*89c4ff92SAndroid Build Coastguard Worker 118*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors 119*89c4ff92SAndroid Build Coastguard Worker { 120*89c4ff92SAndroid Build Coastguard Worker { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) }, 121*89c4ff92SAndroid Build Coastguard Worker { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) }, 122*89c4ff92SAndroid Build Coastguard Worker { 2, armnn::ConstTensor(inputTensorInfo2, inputData2.data()) } 123*89c4ff92SAndroid Build Coastguard Worker }; 124*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors 125*89c4ff92SAndroid Build Coastguard Worker { 126*89c4ff92SAndroid Build Coastguard Worker { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) } 127*89c4ff92SAndroid Build Coastguard Worker }; 128*89c4ff92SAndroid Build Coastguard Worker 129*89c4ff92SAndroid Build Coastguard Worker runtime->GetProfiler(netId)->EnableProfiling(true); 130*89c4ff92SAndroid Build Coastguard Worker 131*89c4ff92SAndroid Build Coastguard Worker // Do the inference 132*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(netId, inputTensors, outputTensors); 133*89c4ff92SAndroid Build Coastguard Worker 134*89c4ff92SAndroid Build Coastguard Worker // Retrieve the Profiler.Print() output to get the workload execution 135*89c4ff92SAndroid Build Coastguard Worker ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance(); 136*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss; 137*89c4ff92SAndroid Build Coastguard Worker profilerManager.GetProfiler()->Print(ss);; 138*89c4ff92SAndroid Build Coastguard Worker std::string dump = ss.str(); 139*89c4ff92SAndroid Build Coastguard Worker 140*89c4ff92SAndroid Build Coastguard Worker // Contains ImportMemGeneric 141*89c4ff92SAndroid Build Coastguard Worker std::size_t found = dump.find("ImportMemGeneric"); 142*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 143*89c4ff92SAndroid Build Coastguard Worker 144*89c4ff92SAndroid Build Coastguard Worker // Contains SyncMemGeneric 145*89c4ff92SAndroid Build Coastguard Worker found = dump.find("SyncMemGeneric"); 146*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 147*89c4ff92SAndroid Build Coastguard Worker 148*89c4ff92SAndroid Build Coastguard Worker // Does not contain CopyMemGeneric 149*89c4ff92SAndroid Build Coastguard Worker found = dump.find("CopyMemGeneric"); 150*89c4ff92SAndroid Build Coastguard Worker CHECK(found == std::string::npos); 151*89c4ff92SAndroid Build Coastguard Worker 152*89c4ff92SAndroid Build Coastguard Worker // Use memory import between backends 153*89c4ff92SAndroid Build Coastguard Worker CHECK((layer4->GetType() == LayerType::MemImport)); 154*89c4ff92SAndroid Build Coastguard Worker 155*89c4ff92SAndroid Build Coastguard Worker // Check output is as expected 156*89c4ff92SAndroid Build Coastguard Worker CHECK(outputData == expectedOutput); 157*89c4ff92SAndroid Build Coastguard Worker } 158*89c4ff92SAndroid Build Coastguard Worker 159*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FallbackPaddingCopyToCpuAcc") 160*89c4ff92SAndroid Build Coastguard Worker { 161*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 162*89c4ff92SAndroid Build Coastguard Worker 163*89c4ff92SAndroid Build Coastguard Worker // Create a mock backend object 164*89c4ff92SAndroid Build Coastguard Worker MockImportBackendInitialiser initialiser; // Register the Mock Backend 165*89c4ff92SAndroid Build Coastguard Worker auto backendObjPtr = CreateBackendObject(MockImportBackendId()); 166*89c4ff92SAndroid Build Coastguard Worker CHECK((backendObjPtr != nullptr)); 167*89c4ff92SAndroid Build Coastguard Worker 168*89c4ff92SAndroid Build Coastguard Worker BackendIdSet backendIds = BackendRegistryInstance().GetBackendIds(); 169*89c4ff92SAndroid Build Coastguard Worker if (backendIds.find("MockRef") == backendIds.end()) 170*89c4ff92SAndroid Build Coastguard Worker { 171*89c4ff92SAndroid Build Coastguard Worker std::string message = "Cannot load MockRef"; 172*89c4ff92SAndroid Build Coastguard Worker FAIL(message); 173*89c4ff92SAndroid Build Coastguard Worker } 174*89c4ff92SAndroid Build Coastguard Worker 175*89c4ff92SAndroid Build Coastguard Worker // Create runtime in which test will run and allow fallback to CpuRef. 176*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options; 177*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options)); 178*89c4ff92SAndroid Build Coastguard Worker 179*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network. 180*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create()); 181*89c4ff92SAndroid Build Coastguard Worker 182*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor desc; 183*89c4ff92SAndroid Build Coastguard Worker 184*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input0 = net->AddInputLayer(0, "input0"); 185*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input1 = net->AddInputLayer(1, "input1"); 186*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add"); 187*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* pooling = net->AddPooling2dLayer(desc, "pooling"); 188*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output"); 189*89c4ff92SAndroid Build Coastguard Worker 190*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).Connect(add->GetInputSlot(0)); 191*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).Connect(add->GetInputSlot(1)); 192*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).Connect(pooling->GetInputSlot(0)); 193*89c4ff92SAndroid Build Coastguard Worker pooling->GetOutputSlot(0).Connect(output->GetInputSlot(0)); 194*89c4ff92SAndroid Build Coastguard Worker 195*89c4ff92SAndroid Build Coastguard Worker TensorInfo info = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32); 196*89c4ff92SAndroid Build Coastguard Worker TensorInfo poolingInfo = TensorInfo({ 1, 2, 1, 1 }, DataType::Float32); 197*89c4ff92SAndroid Build Coastguard Worker 198*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).SetTensorInfo(info); 199*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).SetTensorInfo(info); 200*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).SetTensorInfo(info); 201*89c4ff92SAndroid Build Coastguard Worker pooling->GetOutputSlot(0).SetTensorInfo(poolingInfo); 202*89c4ff92SAndroid Build Coastguard Worker 203*89c4ff92SAndroid Build Coastguard Worker // optimize the network 204*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> backends = { "MockRef", Compute::CpuAcc }; 205*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque optOptions; 206*89c4ff92SAndroid Build Coastguard Worker optOptions.SetImportEnabled(true); 207*89c4ff92SAndroid Build Coastguard Worker optOptions.SetExportEnabled(true); 208*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions); 209*89c4ff92SAndroid Build Coastguard Worker 210*89c4ff92SAndroid Build Coastguard Worker Graph& graph = GetGraphForTesting(optNet.get()); 211*89c4ff92SAndroid Build Coastguard Worker 212*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0"); 213*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1"); 214*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "add"); 215*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "[ add (0) -> pooling (0) ]"); 216*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "pooling"); 217*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "output"); 218*89c4ff92SAndroid Build Coastguard Worker 219*89c4ff92SAndroid Build Coastguard Worker // Checks order is valid. 220*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer0, layer1)); 221*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer1, layer2)); 222*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer2, layer3)); 223*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer3, layer4)); 224*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer4, layer5)); 225*89c4ff92SAndroid Build Coastguard Worker 226*89c4ff92SAndroid Build Coastguard Worker // Load it into the runtime. It should pass. 227*89c4ff92SAndroid Build Coastguard Worker NetworkId netId; 228*89c4ff92SAndroid Build Coastguard Worker std::string ignoredErrorMessage; 229*89c4ff92SAndroid Build Coastguard Worker INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); 230*89c4ff92SAndroid Build Coastguard Worker 231*89c4ff92SAndroid Build Coastguard Worker runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); 232*89c4ff92SAndroid Build Coastguard Worker 233*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output 234*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData0 235*89c4ff92SAndroid Build Coastguard Worker { 236*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f 237*89c4ff92SAndroid Build Coastguard Worker }; 238*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData1 239*89c4ff92SAndroid Build Coastguard Worker { 240*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f 241*89c4ff92SAndroid Build Coastguard Worker }; 242*89c4ff92SAndroid Build Coastguard Worker 243*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData(2); 244*89c4ff92SAndroid Build Coastguard Worker 245*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutput 246*89c4ff92SAndroid Build Coastguard Worker { 247*89c4ff92SAndroid Build Coastguard Worker 6.0f, 12.0f 248*89c4ff92SAndroid Build Coastguard Worker }; 249*89c4ff92SAndroid Build Coastguard Worker 250*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0); 251*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1); 252*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo0.SetConstant(true); 253*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1.SetConstant(true); 254*89c4ff92SAndroid Build Coastguard Worker 255*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors 256*89c4ff92SAndroid Build Coastguard Worker { 257*89c4ff92SAndroid Build Coastguard Worker { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) }, 258*89c4ff92SAndroid Build Coastguard Worker { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) } 259*89c4ff92SAndroid Build Coastguard Worker }; 260*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors 261*89c4ff92SAndroid Build Coastguard Worker { 262*89c4ff92SAndroid Build Coastguard Worker { 0, armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) } 263*89c4ff92SAndroid Build Coastguard Worker }; 264*89c4ff92SAndroid Build Coastguard Worker 265*89c4ff92SAndroid Build Coastguard Worker runtime->GetProfiler(netId)->EnableProfiling(true); 266*89c4ff92SAndroid Build Coastguard Worker 267*89c4ff92SAndroid Build Coastguard Worker // Do the inference 268*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(netId, inputTensors, outputTensors); 269*89c4ff92SAndroid Build Coastguard Worker 270*89c4ff92SAndroid Build Coastguard Worker // Retrieve the Profiler.Print() output to get the workload execution 271*89c4ff92SAndroid Build Coastguard Worker ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance(); 272*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss; 273*89c4ff92SAndroid Build Coastguard Worker profilerManager.GetProfiler()->Print(ss);; 274*89c4ff92SAndroid Build Coastguard Worker std::string dump = ss.str(); 275*89c4ff92SAndroid Build Coastguard Worker 276*89c4ff92SAndroid Build Coastguard Worker // Contains CopyMemGeneric between the backends 277*89c4ff92SAndroid Build Coastguard Worker std::size_t found = dump.find("CopyMemGeneric"); 278*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 279*89c4ff92SAndroid Build Coastguard Worker 280*89c4ff92SAndroid Build Coastguard Worker // Contains SyncMemGeneric for the output 281*89c4ff92SAndroid Build Coastguard Worker found = dump.find("SyncMemGeneric"); 282*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 283*89c4ff92SAndroid Build Coastguard Worker 284*89c4ff92SAndroid Build Coastguard Worker // Does not contain ImportMemGeneric 285*89c4ff92SAndroid Build Coastguard Worker found = dump.find("ImportMemGeneric"); 286*89c4ff92SAndroid Build Coastguard Worker CHECK(found == std::string::npos); 287*89c4ff92SAndroid Build Coastguard Worker 288*89c4ff92SAndroid Build Coastguard Worker // Use memory import between backends 289*89c4ff92SAndroid Build Coastguard Worker CHECK((layer3->GetType() == LayerType::MemCopy)); 290*89c4ff92SAndroid Build Coastguard Worker 291*89c4ff92SAndroid Build Coastguard Worker // Check output is as expected 292*89c4ff92SAndroid Build Coastguard Worker CHECK(outputData == expectedOutput); 293*89c4ff92SAndroid Build Coastguard Worker } 294*89c4ff92SAndroid Build Coastguard Worker 295*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FallbackImportFromCpuAcc") 296*89c4ff92SAndroid Build Coastguard Worker { 297*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 298*89c4ff92SAndroid Build Coastguard Worker 299*89c4ff92SAndroid Build Coastguard Worker // Create a mock backend object 300*89c4ff92SAndroid Build Coastguard Worker MockImportBackendInitialiser initialiser; // Register the Mock Backend 301*89c4ff92SAndroid Build Coastguard Worker auto backendObjPtr = CreateBackendObject(MockImportBackendId()); 302*89c4ff92SAndroid Build Coastguard Worker CHECK((backendObjPtr != nullptr)); 303*89c4ff92SAndroid Build Coastguard Worker 304*89c4ff92SAndroid Build Coastguard Worker BackendIdSet backendIds = BackendRegistryInstance().GetBackendIds(); 305*89c4ff92SAndroid Build Coastguard Worker if (backendIds.find("MockRef") == backendIds.end()) 306*89c4ff92SAndroid Build Coastguard Worker { 307*89c4ff92SAndroid Build Coastguard Worker std::string message = "Cannot load MockRef"; 308*89c4ff92SAndroid Build Coastguard Worker FAIL(message); 309*89c4ff92SAndroid Build Coastguard Worker } 310*89c4ff92SAndroid Build Coastguard Worker 311*89c4ff92SAndroid Build Coastguard Worker // Create runtime in which test will run and allow fallback to CpuRef. 312*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options; 313*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options)); 314*89c4ff92SAndroid Build Coastguard Worker 315*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network. 316*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create()); 317*89c4ff92SAndroid Build Coastguard Worker 318*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input0 = net->AddInputLayer(0, "input0"); 319*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input1 = net->AddInputLayer(1, "input1"); 320*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input2 = net->AddInputLayer(2, "input2"); 321*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub"); 322*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add"); 323*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output"); 324*89c4ff92SAndroid Build Coastguard Worker 325*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).Connect(sub->GetInputSlot(0)); 326*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).Connect(sub->GetInputSlot(1)); 327*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).Connect(add->GetInputSlot(0)); 328*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).Connect(add->GetInputSlot(1)); 329*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).Connect(output->GetInputSlot(0)); 330*89c4ff92SAndroid Build Coastguard Worker 331*89c4ff92SAndroid Build Coastguard Worker TensorInfo info = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32); 332*89c4ff92SAndroid Build Coastguard Worker 333*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).SetTensorInfo(info); 334*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).SetTensorInfo(info); 335*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).SetTensorInfo(info); 336*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).SetTensorInfo(info); 337*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).SetTensorInfo(info); 338*89c4ff92SAndroid Build Coastguard Worker 339*89c4ff92SAndroid Build Coastguard Worker // optimize the network 340*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> backends = { "MockRef", Compute::CpuAcc }; 341*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque optOptions; 342*89c4ff92SAndroid Build Coastguard Worker optOptions.SetImportEnabled(true); 343*89c4ff92SAndroid Build Coastguard Worker optOptions.SetExportEnabled(true); 344*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions); 345*89c4ff92SAndroid Build Coastguard Worker 346*89c4ff92SAndroid Build Coastguard Worker Graph& graph = GetGraphForTesting(optNet.get()); 347*89c4ff92SAndroid Build Coastguard Worker 348*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0"); 349*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1"); 350*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2"); 351*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "sub"); 352*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ sub (0) -> add (1) ]"); 353*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "add"); 354*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "output"); 355*89c4ff92SAndroid Build Coastguard Worker 356*89c4ff92SAndroid Build Coastguard Worker // Checks order is valid. 357*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer0, layer1)); 358*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer1, layer2)); 359*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer2, layer3)); 360*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer3, layer4)); 361*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer4, layer5)); 362*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer5, layer6)); 363*89c4ff92SAndroid Build Coastguard Worker 364*89c4ff92SAndroid Build Coastguard Worker // Load it into the runtime. It should pass. 365*89c4ff92SAndroid Build Coastguard Worker NetworkId netId; 366*89c4ff92SAndroid Build Coastguard Worker std::string ignoredErrorMessage; 367*89c4ff92SAndroid Build Coastguard Worker 368*89c4ff92SAndroid Build Coastguard Worker INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); 369*89c4ff92SAndroid Build Coastguard Worker runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); 370*89c4ff92SAndroid Build Coastguard Worker 371*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output 372*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData0 373*89c4ff92SAndroid Build Coastguard Worker { 374*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 0.0f 375*89c4ff92SAndroid Build Coastguard Worker }; 376*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData1 377*89c4ff92SAndroid Build Coastguard Worker { 378*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f 379*89c4ff92SAndroid Build Coastguard Worker }; 380*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData2 381*89c4ff92SAndroid Build Coastguard Worker { 382*89c4ff92SAndroid Build Coastguard Worker 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f 383*89c4ff92SAndroid Build Coastguard Worker }; 384*89c4ff92SAndroid Build Coastguard Worker 385*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData(12); 386*89c4ff92SAndroid Build Coastguard Worker 387*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutput 388*89c4ff92SAndroid Build Coastguard Worker { 389*89c4ff92SAndroid Build Coastguard Worker 13.0f, 11.0f, 11.0f, 9.0f, 7.0f, 7.0f, 7.0f, 5.0f, 5.0f, 3.0f, 3.0f, -5.0f 390*89c4ff92SAndroid Build Coastguard Worker }; 391*89c4ff92SAndroid Build Coastguard Worker 392*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0); 393*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1); 394*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo2 = runtime->GetInputTensorInfo(netId, 2); 395*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo0.SetConstant(true); 396*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1.SetConstant(true); 397*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo2.SetConstant(true); 398*89c4ff92SAndroid Build Coastguard Worker 399*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors 400*89c4ff92SAndroid Build Coastguard Worker { 401*89c4ff92SAndroid Build Coastguard Worker { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) }, 402*89c4ff92SAndroid Build Coastguard Worker { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) }, 403*89c4ff92SAndroid Build Coastguard Worker { 2, armnn::ConstTensor(inputTensorInfo2, inputData2.data()) } 404*89c4ff92SAndroid Build Coastguard Worker }; 405*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors 406*89c4ff92SAndroid Build Coastguard Worker { 407*89c4ff92SAndroid Build Coastguard Worker { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) } 408*89c4ff92SAndroid Build Coastguard Worker }; 409*89c4ff92SAndroid Build Coastguard Worker 410*89c4ff92SAndroid Build Coastguard Worker runtime->GetProfiler(netId)->EnableProfiling(true); 411*89c4ff92SAndroid Build Coastguard Worker 412*89c4ff92SAndroid Build Coastguard Worker // Do the inference 413*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(netId, inputTensors, outputTensors); 414*89c4ff92SAndroid Build Coastguard Worker 415*89c4ff92SAndroid Build Coastguard Worker // Retrieve the Profiler.Print() output to get the workload execution 416*89c4ff92SAndroid Build Coastguard Worker ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance(); 417*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss; 418*89c4ff92SAndroid Build Coastguard Worker profilerManager.GetProfiler()->Print(ss);; 419*89c4ff92SAndroid Build Coastguard Worker std::string dump = ss.str(); 420*89c4ff92SAndroid Build Coastguard Worker 421*89c4ff92SAndroid Build Coastguard Worker // Contains ImportMemGeneric 422*89c4ff92SAndroid Build Coastguard Worker std::size_t found = dump.find("ImportMemGeneric"); 423*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 424*89c4ff92SAndroid Build Coastguard Worker 425*89c4ff92SAndroid Build Coastguard Worker // Contains SyncMemGeneric 426*89c4ff92SAndroid Build Coastguard Worker found = dump.find("SyncMemGeneric"); 427*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 428*89c4ff92SAndroid Build Coastguard Worker 429*89c4ff92SAndroid Build Coastguard Worker // Does not contain CopyMemGeneric 430*89c4ff92SAndroid Build Coastguard Worker found = dump.find("CopyMemGeneric"); 431*89c4ff92SAndroid Build Coastguard Worker CHECK(found == std::string::npos); 432*89c4ff92SAndroid Build Coastguard Worker 433*89c4ff92SAndroid Build Coastguard Worker // Use memory import between backends 434*89c4ff92SAndroid Build Coastguard Worker CHECK((layer4->GetType() == LayerType::MemImport)); 435*89c4ff92SAndroid Build Coastguard Worker 436*89c4ff92SAndroid Build Coastguard Worker // Check output is as expected 437*89c4ff92SAndroid Build Coastguard Worker CHECK(outputData == expectedOutput); 438*89c4ff92SAndroid Build Coastguard Worker } 439*89c4ff92SAndroid Build Coastguard Worker 440*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FallbackPaddingCopyFromCpuAcc") 441*89c4ff92SAndroid Build Coastguard Worker { 442*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 443*89c4ff92SAndroid Build Coastguard Worker 444*89c4ff92SAndroid Build Coastguard Worker // Create a mock backend object 445*89c4ff92SAndroid Build Coastguard Worker MockImportBackendInitialiser initialiser; // Register the Mock Backend 446*89c4ff92SAndroid Build Coastguard Worker auto backendObjPtr = CreateBackendObject(MockImportBackendId()); 447*89c4ff92SAndroid Build Coastguard Worker CHECK((backendObjPtr != nullptr)); 448*89c4ff92SAndroid Build Coastguard Worker 449*89c4ff92SAndroid Build Coastguard Worker BackendIdSet backendIds = BackendRegistryInstance().GetBackendIds(); 450*89c4ff92SAndroid Build Coastguard Worker if (backendIds.find("MockRef") == backendIds.end()) 451*89c4ff92SAndroid Build Coastguard Worker { 452*89c4ff92SAndroid Build Coastguard Worker std::string message = "Cannot load MockRef"; 453*89c4ff92SAndroid Build Coastguard Worker FAIL(message); 454*89c4ff92SAndroid Build Coastguard Worker } 455*89c4ff92SAndroid Build Coastguard Worker 456*89c4ff92SAndroid Build Coastguard Worker // Create runtime in which test will run and allow fallback to CpuRef. 457*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options; 458*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options)); 459*89c4ff92SAndroid Build Coastguard Worker 460*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network. 461*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create()); 462*89c4ff92SAndroid Build Coastguard Worker 463*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor desc; 464*89c4ff92SAndroid Build Coastguard Worker 465*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input0 = net->AddInputLayer(0, "input0"); 466*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input1 = net->AddInputLayer(1, "input1"); 467*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* pooling = net->AddPooling2dLayer(desc, "pooling"); 468*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add"); 469*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output"); 470*89c4ff92SAndroid Build Coastguard Worker 471*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).Connect(pooling->GetInputSlot(0)); 472*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).Connect(add->GetInputSlot(1)); 473*89c4ff92SAndroid Build Coastguard Worker pooling->GetOutputSlot(0).Connect(add->GetInputSlot(0)); 474*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).Connect(output->GetInputSlot(0)); 475*89c4ff92SAndroid Build Coastguard Worker 476*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32); 477*89c4ff92SAndroid Build Coastguard Worker TensorInfo poolingInfo = TensorInfo({ 1, 2, 1, 1 }, DataType::Float32); 478*89c4ff92SAndroid Build Coastguard Worker 479*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).SetTensorInfo(inputInfo); 480*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).SetTensorInfo(poolingInfo); 481*89c4ff92SAndroid Build Coastguard Worker pooling->GetOutputSlot(0).SetTensorInfo(poolingInfo); 482*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).SetTensorInfo(poolingInfo); 483*89c4ff92SAndroid Build Coastguard Worker 484*89c4ff92SAndroid Build Coastguard Worker // optimize the network 485*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> backends = { "MockRef", Compute::CpuAcc }; 486*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque optOptions; 487*89c4ff92SAndroid Build Coastguard Worker optOptions.SetImportEnabled(true); 488*89c4ff92SAndroid Build Coastguard Worker optOptions.SetExportEnabled(true); 489*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions); 490*89c4ff92SAndroid Build Coastguard Worker 491*89c4ff92SAndroid Build Coastguard Worker Graph& graph = GetGraphForTesting(optNet.get()); 492*89c4ff92SAndroid Build Coastguard Worker 493*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0"); 494*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1"); 495*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "pooling"); 496*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "[ pooling (0) -> add (0) ]"); 497*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "add"); 498*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "output"); 499*89c4ff92SAndroid Build Coastguard Worker 500*89c4ff92SAndroid Build Coastguard Worker // Checks order is valid. 501*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer0, layer1)); 502*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer1, layer2)); 503*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer2, layer3)); 504*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer3, layer4)); 505*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer4, layer5)); 506*89c4ff92SAndroid Build Coastguard Worker 507*89c4ff92SAndroid Build Coastguard Worker // Load it into the runtime. It should pass. 508*89c4ff92SAndroid Build Coastguard Worker NetworkId netId; 509*89c4ff92SAndroid Build Coastguard Worker std::string ignoredErrorMessage; 510*89c4ff92SAndroid Build Coastguard Worker INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); 511*89c4ff92SAndroid Build Coastguard Worker 512*89c4ff92SAndroid Build Coastguard Worker runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); 513*89c4ff92SAndroid Build Coastguard Worker 514*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output 515*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData0 516*89c4ff92SAndroid Build Coastguard Worker { 517*89c4ff92SAndroid Build Coastguard Worker 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f 518*89c4ff92SAndroid Build Coastguard Worker }; 519*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData1 520*89c4ff92SAndroid Build Coastguard Worker { 521*89c4ff92SAndroid Build Coastguard Worker -1.0f, 3.0f 522*89c4ff92SAndroid Build Coastguard Worker }; 523*89c4ff92SAndroid Build Coastguard Worker 524*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData(2); 525*89c4ff92SAndroid Build Coastguard Worker 526*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutput 527*89c4ff92SAndroid Build Coastguard Worker { 528*89c4ff92SAndroid Build Coastguard Worker 5.0f, 15.0f 529*89c4ff92SAndroid Build Coastguard Worker }; 530*89c4ff92SAndroid Build Coastguard Worker 531*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0); 532*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1); 533*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo0.SetConstant(true); 534*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1.SetConstant(true); 535*89c4ff92SAndroid Build Coastguard Worker 536*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors 537*89c4ff92SAndroid Build Coastguard Worker { 538*89c4ff92SAndroid Build Coastguard Worker { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) }, 539*89c4ff92SAndroid Build Coastguard Worker { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) } 540*89c4ff92SAndroid Build Coastguard Worker }; 541*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors 542*89c4ff92SAndroid Build Coastguard Worker { 543*89c4ff92SAndroid Build Coastguard Worker { 0, armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) } 544*89c4ff92SAndroid Build Coastguard Worker }; 545*89c4ff92SAndroid Build Coastguard Worker 546*89c4ff92SAndroid Build Coastguard Worker runtime->GetProfiler(netId)->EnableProfiling(true); 547*89c4ff92SAndroid Build Coastguard Worker 548*89c4ff92SAndroid Build Coastguard Worker // Do the inference 549*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(netId, inputTensors, outputTensors); 550*89c4ff92SAndroid Build Coastguard Worker 551*89c4ff92SAndroid Build Coastguard Worker // Retrieve the Profiler.Print() output to get the workload execution 552*89c4ff92SAndroid Build Coastguard Worker ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance(); 553*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss; 554*89c4ff92SAndroid Build Coastguard Worker profilerManager.GetProfiler()->Print(ss);; 555*89c4ff92SAndroid Build Coastguard Worker std::string dump = ss.str(); 556*89c4ff92SAndroid Build Coastguard Worker 557*89c4ff92SAndroid Build Coastguard Worker // Contains CopyMemGeneric between the backends 558*89c4ff92SAndroid Build Coastguard Worker std::size_t found = dump.find("CopyMemGeneric"); 559*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 560*89c4ff92SAndroid Build Coastguard Worker 561*89c4ff92SAndroid Build Coastguard Worker // Contains SyncMemGeneric for the output 562*89c4ff92SAndroid Build Coastguard Worker found = dump.find("SyncMemGeneric"); 563*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 564*89c4ff92SAndroid Build Coastguard Worker 565*89c4ff92SAndroid Build Coastguard Worker // Does not contain ImportMemGeneric 566*89c4ff92SAndroid Build Coastguard Worker found = dump.find("ImportMemGeneric"); 567*89c4ff92SAndroid Build Coastguard Worker CHECK(found == std::string::npos); 568*89c4ff92SAndroid Build Coastguard Worker 569*89c4ff92SAndroid Build Coastguard Worker // Use memory import between backends 570*89c4ff92SAndroid Build Coastguard Worker CHECK((layer3->GetType() == LayerType::MemCopy)); 571*89c4ff92SAndroid Build Coastguard Worker 572*89c4ff92SAndroid Build Coastguard Worker // Check output is as expected 573*89c4ff92SAndroid Build Coastguard Worker CHECK(outputData == expectedOutput); 574*89c4ff92SAndroid Build Coastguard Worker } 575*89c4ff92SAndroid Build Coastguard Worker 576*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FallbackDisableImportFromCpuAcc") 577*89c4ff92SAndroid Build Coastguard Worker { 578*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 579*89c4ff92SAndroid Build Coastguard Worker 580*89c4ff92SAndroid Build Coastguard Worker // Create a mock backend object 581*89c4ff92SAndroid Build Coastguard Worker MockImportBackendInitialiser initialiser; // Register the Mock Backend 582*89c4ff92SAndroid Build Coastguard Worker auto backendObjPtr = CreateBackendObject(MockImportBackendId()); 583*89c4ff92SAndroid Build Coastguard Worker CHECK((backendObjPtr != nullptr)); 584*89c4ff92SAndroid Build Coastguard Worker 585*89c4ff92SAndroid Build Coastguard Worker BackendIdSet backendIds = BackendRegistryInstance().GetBackendIds(); 586*89c4ff92SAndroid Build Coastguard Worker if (backendIds.find("MockRef") == backendIds.end()) 587*89c4ff92SAndroid Build Coastguard Worker { 588*89c4ff92SAndroid Build Coastguard Worker std::string message = "Cannot load MockRef"; 589*89c4ff92SAndroid Build Coastguard Worker FAIL(message); 590*89c4ff92SAndroid Build Coastguard Worker } 591*89c4ff92SAndroid Build Coastguard Worker 592*89c4ff92SAndroid Build Coastguard Worker // Create runtime in which test will run and allow fallback to CpuRef. 593*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options; 594*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options)); 595*89c4ff92SAndroid Build Coastguard Worker 596*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network. 597*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create()); 598*89c4ff92SAndroid Build Coastguard Worker 599*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input0 = net->AddInputLayer(0, "input0"); 600*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input1 = net->AddInputLayer(1, "input1"); 601*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input2 = net->AddInputLayer(2, "input2"); 602*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub"); 603*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add"); 604*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output"); 605*89c4ff92SAndroid Build Coastguard Worker 606*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).Connect(sub->GetInputSlot(0)); 607*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).Connect(sub->GetInputSlot(1)); 608*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).Connect(add->GetInputSlot(0)); 609*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).Connect(add->GetInputSlot(1)); 610*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).Connect(output->GetInputSlot(0)); 611*89c4ff92SAndroid Build Coastguard Worker 612*89c4ff92SAndroid Build Coastguard Worker TensorInfo info = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32); 613*89c4ff92SAndroid Build Coastguard Worker 614*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).SetTensorInfo(info); 615*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).SetTensorInfo(info); 616*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).SetTensorInfo(info); 617*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).SetTensorInfo(info); 618*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).SetTensorInfo(info); 619*89c4ff92SAndroid Build Coastguard Worker 620*89c4ff92SAndroid Build Coastguard Worker // optimize the network 621*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> backends = { "MockRef", Compute::CpuAcc }; 622*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec()); 623*89c4ff92SAndroid Build Coastguard Worker 624*89c4ff92SAndroid Build Coastguard Worker Graph& graph = GetGraphForTesting(optNet.get()); 625*89c4ff92SAndroid Build Coastguard Worker 626*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0"); 627*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1"); 628*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2"); 629*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "sub"); 630*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ sub (0) -> add (1) ]"); 631*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "add"); 632*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "output"); 633*89c4ff92SAndroid Build Coastguard Worker 634*89c4ff92SAndroid Build Coastguard Worker // Checks order is valid. 635*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer0, layer1)); 636*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer1, layer2)); 637*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer2, layer3)); 638*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer3, layer4)); 639*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer4, layer5)); 640*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer5, layer6)); 641*89c4ff92SAndroid Build Coastguard Worker 642*89c4ff92SAndroid Build Coastguard Worker // Load it into the runtime. It should pass. 643*89c4ff92SAndroid Build Coastguard Worker NetworkId netId; 644*89c4ff92SAndroid Build Coastguard Worker std::string ignoredErrorMessage; 645*89c4ff92SAndroid Build Coastguard Worker INetworkProperties networkProperties(false, MemorySource::Undefined, MemorySource::Undefined); 646*89c4ff92SAndroid Build Coastguard Worker 647*89c4ff92SAndroid Build Coastguard Worker runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); 648*89c4ff92SAndroid Build Coastguard Worker 649*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output 650*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData0 651*89c4ff92SAndroid Build Coastguard Worker { 652*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 0.0f 653*89c4ff92SAndroid Build Coastguard Worker }; 654*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData1 655*89c4ff92SAndroid Build Coastguard Worker { 656*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f 657*89c4ff92SAndroid Build Coastguard Worker }; 658*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData2 659*89c4ff92SAndroid Build Coastguard Worker { 660*89c4ff92SAndroid Build Coastguard Worker 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f 661*89c4ff92SAndroid Build Coastguard Worker }; 662*89c4ff92SAndroid Build Coastguard Worker 663*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData(12); 664*89c4ff92SAndroid Build Coastguard Worker 665*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutput 666*89c4ff92SAndroid Build Coastguard Worker { 667*89c4ff92SAndroid Build Coastguard Worker 13.0f, 11.0f, 11.0f, 9.0f, 7.0f, 7.0f, 7.0f, 5.0f, 5.0f, 3.0f, 3.0f, -5.0f 668*89c4ff92SAndroid Build Coastguard Worker }; 669*89c4ff92SAndroid Build Coastguard Worker 670*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0); 671*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1); 672*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo2 = runtime->GetInputTensorInfo(netId, 2); 673*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo0.SetConstant(true); 674*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1.SetConstant(true); 675*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo2.SetConstant(true); 676*89c4ff92SAndroid Build Coastguard Worker 677*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors 678*89c4ff92SAndroid Build Coastguard Worker { 679*89c4ff92SAndroid Build Coastguard Worker { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) }, 680*89c4ff92SAndroid Build Coastguard Worker { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) }, 681*89c4ff92SAndroid Build Coastguard Worker { 2, armnn::ConstTensor(inputTensorInfo2, inputData2.data()) } 682*89c4ff92SAndroid Build Coastguard Worker }; 683*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors 684*89c4ff92SAndroid Build Coastguard Worker { 685*89c4ff92SAndroid Build Coastguard Worker { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) } 686*89c4ff92SAndroid Build Coastguard Worker }; 687*89c4ff92SAndroid Build Coastguard Worker 688*89c4ff92SAndroid Build Coastguard Worker runtime->GetProfiler(netId)->EnableProfiling(true); 689*89c4ff92SAndroid Build Coastguard Worker 690*89c4ff92SAndroid Build Coastguard Worker // Do the inference 691*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(netId, inputTensors, outputTensors); 692*89c4ff92SAndroid Build Coastguard Worker 693*89c4ff92SAndroid Build Coastguard Worker // Retrieve the Profiler.Print() output to get the workload execution 694*89c4ff92SAndroid Build Coastguard Worker ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance(); 695*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss; 696*89c4ff92SAndroid Build Coastguard Worker profilerManager.GetProfiler()->Print(ss);; 697*89c4ff92SAndroid Build Coastguard Worker std::string dump = ss.str(); 698*89c4ff92SAndroid Build Coastguard Worker 699*89c4ff92SAndroid Build Coastguard Worker // Contains CopyMemGeneric between the backends 700*89c4ff92SAndroid Build Coastguard Worker std::size_t found = dump.find("CopyMemGeneric"); 701*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 702*89c4ff92SAndroid Build Coastguard Worker 703*89c4ff92SAndroid Build Coastguard Worker // Does not contain ImportMemGeneric 704*89c4ff92SAndroid Build Coastguard Worker found = dump.find("ImportMemGeneric"); 705*89c4ff92SAndroid Build Coastguard Worker CHECK(found == std::string::npos); 706*89c4ff92SAndroid Build Coastguard Worker 707*89c4ff92SAndroid Build Coastguard Worker // Use memory import between backends 708*89c4ff92SAndroid Build Coastguard Worker CHECK((layer4->GetType() == LayerType::MemCopy)); 709*89c4ff92SAndroid Build Coastguard Worker 710*89c4ff92SAndroid Build Coastguard Worker // Check output is as expected 711*89c4ff92SAndroid Build Coastguard Worker CHECK(outputData == expectedOutput); 712*89c4ff92SAndroid Build Coastguard Worker } 713*89c4ff92SAndroid Build Coastguard Worker 714*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTECL_ENABLED) 715*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("NeonImportEnabledFallbackToCl") 716*89c4ff92SAndroid Build Coastguard Worker { 717*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 718*89c4ff92SAndroid Build Coastguard Worker 719*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options; 720*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options)); 721*89c4ff92SAndroid Build Coastguard Worker 722*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network. 723*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create()); 724*89c4ff92SAndroid Build Coastguard Worker 725*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input0 = net->AddInputLayer(0, "input0"); 726*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input1 = net->AddInputLayer(1, "input1"); 727*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input2 = net->AddInputLayer(2, "input2"); 728*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add"); 729*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub"); 730*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output"); 731*89c4ff92SAndroid Build Coastguard Worker 732*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).Connect(add->GetInputSlot(0)); 733*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).Connect(add->GetInputSlot(1)); 734*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0)); 735*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).Connect(sub->GetInputSlot(1)); 736*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).Connect(output->GetInputSlot(0)); 737*89c4ff92SAndroid Build Coastguard Worker 738*89c4ff92SAndroid Build Coastguard Worker TensorInfo info = TensorInfo({ 1, 2, 4, 2 }, DataType::Float32); 739*89c4ff92SAndroid Build Coastguard Worker 740*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).SetTensorInfo(info); 741*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).SetTensorInfo(info); 742*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).SetTensorInfo(info); 743*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).SetTensorInfo(info); 744*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).SetTensorInfo(info); 745*89c4ff92SAndroid Build Coastguard Worker 746*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> backends = { Compute::CpuAcc, Compute::GpuAcc }; 747*89c4ff92SAndroid Build Coastguard Worker // Use BackendSelectionHint to specify GpuAcc for Subtraction layer 748*89c4ff92SAndroid Build Coastguard Worker sub->BackendSelectionHint(backends[1]); 749*89c4ff92SAndroid Build Coastguard Worker 750*89c4ff92SAndroid Build Coastguard Worker // optimize the network 751*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque optOptions; 752*89c4ff92SAndroid Build Coastguard Worker optOptions.SetImportEnabled(true); 753*89c4ff92SAndroid Build Coastguard Worker optOptions.SetExportEnabled(true); 754*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions); 755*89c4ff92SAndroid Build Coastguard Worker 756*89c4ff92SAndroid Build Coastguard Worker Graph& graph = GetGraphForTesting(optNet.get()); 757*89c4ff92SAndroid Build Coastguard Worker 758*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0"); 759*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1"); 760*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2"); 761*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add"); 762*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]"); 763*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub"); 764*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "output"); 765*89c4ff92SAndroid Build Coastguard Worker 766*89c4ff92SAndroid Build Coastguard Worker // Checks order is valid. 767*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer0, layer1)); 768*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer1, layer2)); 769*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer2, layer3)); 770*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer3, layer4)); 771*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer4, layer5)); 772*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer5, layer6)); 773*89c4ff92SAndroid Build Coastguard Worker 774*89c4ff92SAndroid Build Coastguard Worker // Use memory import between backends 775*89c4ff92SAndroid Build Coastguard Worker CHECK((layer4->GetType() == LayerType::MemCopy)); 776*89c4ff92SAndroid Build Coastguard Worker 777*89c4ff92SAndroid Build Coastguard Worker // Correctly use backend hint 778*89c4ff92SAndroid Build Coastguard Worker CHECK((layer5->GetBackendId() == Compute::GpuAcc )); 779*89c4ff92SAndroid Build Coastguard Worker 780*89c4ff92SAndroid Build Coastguard Worker // Load it into the runtime. It should pass. 781*89c4ff92SAndroid Build Coastguard Worker NetworkId netId; 782*89c4ff92SAndroid Build Coastguard Worker std::string ignoredErrorMessage; 783*89c4ff92SAndroid Build Coastguard Worker 784*89c4ff92SAndroid Build Coastguard Worker INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); 785*89c4ff92SAndroid Build Coastguard Worker 786*89c4ff92SAndroid Build Coastguard Worker runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); 787*89c4ff92SAndroid Build Coastguard Worker 788*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output 789*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData0 790*89c4ff92SAndroid Build Coastguard Worker { 791*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f, 1.0f, 1.0f, 2.0f, 2.0f 792*89c4ff92SAndroid Build Coastguard Worker }; 793*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData1 794*89c4ff92SAndroid Build Coastguard Worker { 795*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 0.0f, 1.0f, 1.0f, 2.0f 796*89c4ff92SAndroid Build Coastguard Worker }; 797*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData2 798*89c4ff92SAndroid Build Coastguard Worker { 799*89c4ff92SAndroid Build Coastguard Worker 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 12.0f, 11.0f, 10.0f, 9.0f 800*89c4ff92SAndroid Build Coastguard Worker }; 801*89c4ff92SAndroid Build Coastguard Worker 802*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData(16); 803*89c4ff92SAndroid Build Coastguard Worker 804*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutput 805*89c4ff92SAndroid Build Coastguard Worker { 806*89c4ff92SAndroid Build Coastguard Worker 11.0f, 9.0f, 7.0f, 5.0f, 3.0f, 1.0f, -1.0f, -3.0f, -5.0f, -7.0f, -9.0f, -11.0f, 11.0f, 9.0f, 7.0f, 5.0f 807*89c4ff92SAndroid Build Coastguard Worker }; 808*89c4ff92SAndroid Build Coastguard Worker 809*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output 810*89c4ff92SAndroid Build Coastguard Worker unsigned int numElements = info.GetNumElements(); 811*89c4ff92SAndroid Build Coastguard Worker size_t totalBytes = numElements * sizeof(float); 812*89c4ff92SAndroid Build Coastguard Worker 813*89c4ff92SAndroid Build Coastguard Worker // Prepare aligned data 814*89c4ff92SAndroid Build Coastguard Worker const size_t alignment = 64; 815*89c4ff92SAndroid Build Coastguard Worker size_t space = totalBytes + alignment + alignment; 816*89c4ff92SAndroid Build Coastguard Worker auto inputData = std::make_unique<uint8_t[]>(space); 817*89c4ff92SAndroid Build Coastguard Worker void* alignedInputPtr = inputData.get(); 818*89c4ff92SAndroid Build Coastguard Worker CHECK(std::align(alignment, totalBytes, alignedInputPtr, space)); 819*89c4ff92SAndroid Build Coastguard Worker 820*89c4ff92SAndroid Build Coastguard Worker auto* intputPtr = reinterpret_cast<float*>(alignedInputPtr); 821*89c4ff92SAndroid Build Coastguard Worker std::copy(inputData2.begin(), inputData2.end(), intputPtr); 822*89c4ff92SAndroid Build Coastguard Worker 823*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0); 824*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1); 825*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo2 = runtime->GetInputTensorInfo(netId, 2); 826*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo0.SetConstant(true); 827*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1.SetConstant(true); 828*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo2.SetConstant(true); 829*89c4ff92SAndroid Build Coastguard Worker 830*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors 831*89c4ff92SAndroid Build Coastguard Worker { 832*89c4ff92SAndroid Build Coastguard Worker { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) }, 833*89c4ff92SAndroid Build Coastguard Worker { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) }, 834*89c4ff92SAndroid Build Coastguard Worker { 2, armnn::ConstTensor(inputTensorInfo2, alignedInputPtr) } 835*89c4ff92SAndroid Build Coastguard Worker }; 836*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors 837*89c4ff92SAndroid Build Coastguard Worker { 838*89c4ff92SAndroid Build Coastguard Worker { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) } 839*89c4ff92SAndroid Build Coastguard Worker }; 840*89c4ff92SAndroid Build Coastguard Worker 841*89c4ff92SAndroid Build Coastguard Worker runtime->GetProfiler(netId)->EnableProfiling(true); 842*89c4ff92SAndroid Build Coastguard Worker 843*89c4ff92SAndroid Build Coastguard Worker // Do the inference 844*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(netId, inputTensors, outputTensors); 845*89c4ff92SAndroid Build Coastguard Worker 846*89c4ff92SAndroid Build Coastguard Worker // Retrieve the Profiler.Print() output to get the workload execution 847*89c4ff92SAndroid Build Coastguard Worker ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance(); 848*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss; 849*89c4ff92SAndroid Build Coastguard Worker profilerManager.GetProfiler()->Print(ss);; 850*89c4ff92SAndroid Build Coastguard Worker std::string dump = ss.str(); 851*89c4ff92SAndroid Build Coastguard Worker 852*89c4ff92SAndroid Build Coastguard Worker // Executed Subtraction using GpuAcc 853*89c4ff92SAndroid Build Coastguard Worker std::size_t found = dump.find("ClSubtractionWorkload_Execute"); 854*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 855*89c4ff92SAndroid Build Coastguard Worker 856*89c4ff92SAndroid Build Coastguard Worker // Contain CopyMemGeneric 857*89c4ff92SAndroid Build Coastguard Worker found = dump.find("CopyMemGeneric"); 858*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 859*89c4ff92SAndroid Build Coastguard Worker 860*89c4ff92SAndroid Build Coastguard Worker // Check output is as expected 861*89c4ff92SAndroid Build Coastguard Worker for(unsigned int i = 0; i < numElements; ++i) 862*89c4ff92SAndroid Build Coastguard Worker { 863*89c4ff92SAndroid Build Coastguard Worker CHECK(outputData[i] == expectedOutput[i]); 864*89c4ff92SAndroid Build Coastguard Worker } 865*89c4ff92SAndroid Build Coastguard Worker runtime->UnloadNetwork(netId); 866*89c4ff92SAndroid Build Coastguard Worker } 867*89c4ff92SAndroid Build Coastguard Worker 868*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("NeonImportDisabledFallbackToCl") 869*89c4ff92SAndroid Build Coastguard Worker { 870*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 871*89c4ff92SAndroid Build Coastguard Worker 872*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options; 873*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options)); 874*89c4ff92SAndroid Build Coastguard Worker 875*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network. 876*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create()); 877*89c4ff92SAndroid Build Coastguard Worker 878*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input0 = net->AddInputLayer(0, "input0"); 879*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input1 = net->AddInputLayer(1, "input1"); 880*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input2 = net->AddInputLayer(2, "input2"); 881*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add"); 882*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub"); 883*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output"); 884*89c4ff92SAndroid Build Coastguard Worker 885*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).Connect(add->GetInputSlot(0)); 886*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).Connect(add->GetInputSlot(1)); 887*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0)); 888*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).Connect(sub->GetInputSlot(1)); 889*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).Connect(output->GetInputSlot(0)); 890*89c4ff92SAndroid Build Coastguard Worker 891*89c4ff92SAndroid Build Coastguard Worker TensorInfo info = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32); 892*89c4ff92SAndroid Build Coastguard Worker 893*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).SetTensorInfo(info); 894*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).SetTensorInfo(info); 895*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).SetTensorInfo(info); 896*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).SetTensorInfo(info); 897*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).SetTensorInfo(info); 898*89c4ff92SAndroid Build Coastguard Worker 899*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> backends = { Compute::CpuAcc, Compute::GpuAcc }; 900*89c4ff92SAndroid Build Coastguard Worker // Use BackendSelectionHint to specify GpuAcc for Subtraction layer 901*89c4ff92SAndroid Build Coastguard Worker sub->BackendSelectionHint(backends[1]); 902*89c4ff92SAndroid Build Coastguard Worker 903*89c4ff92SAndroid Build Coastguard Worker // optimize the network 904*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque optOptions; 905*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions); 906*89c4ff92SAndroid Build Coastguard Worker 907*89c4ff92SAndroid Build Coastguard Worker Graph& graph = GetGraphForTesting(optNet.get()); 908*89c4ff92SAndroid Build Coastguard Worker 909*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0"); 910*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1"); 911*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2"); 912*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add"); 913*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]"); 914*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub"); 915*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "output"); 916*89c4ff92SAndroid Build Coastguard Worker 917*89c4ff92SAndroid Build Coastguard Worker // Checks order is valid. 918*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer0, layer1)); 919*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer1, layer2)); 920*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer2, layer3)); 921*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer3, layer4)); 922*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer4, layer5)); 923*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer5, layer6)); 924*89c4ff92SAndroid Build Coastguard Worker 925*89c4ff92SAndroid Build Coastguard Worker // Use memory import between backends 926*89c4ff92SAndroid Build Coastguard Worker CHECK((layer4->GetType() == LayerType::MemCopy)); 927*89c4ff92SAndroid Build Coastguard Worker 928*89c4ff92SAndroid Build Coastguard Worker // Correctly use backend hint 929*89c4ff92SAndroid Build Coastguard Worker CHECK((layer5->GetBackendId() == Compute::GpuAcc )); 930*89c4ff92SAndroid Build Coastguard Worker 931*89c4ff92SAndroid Build Coastguard Worker // Load it into the runtime. It should pass. 932*89c4ff92SAndroid Build Coastguard Worker NetworkId netId; 933*89c4ff92SAndroid Build Coastguard Worker runtime->LoadNetwork(netId, std::move(optNet)); 934*89c4ff92SAndroid Build Coastguard Worker 935*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output 936*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData0 937*89c4ff92SAndroid Build Coastguard Worker { 938*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f 939*89c4ff92SAndroid Build Coastguard Worker }; 940*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData1 941*89c4ff92SAndroid Build Coastguard Worker { 942*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f 943*89c4ff92SAndroid Build Coastguard Worker }; 944*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData2 945*89c4ff92SAndroid Build Coastguard Worker { 946*89c4ff92SAndroid Build Coastguard Worker 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f 947*89c4ff92SAndroid Build Coastguard Worker }; 948*89c4ff92SAndroid Build Coastguard Worker 949*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData(12); 950*89c4ff92SAndroid Build Coastguard Worker 951*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutput 952*89c4ff92SAndroid Build Coastguard Worker { 953*89c4ff92SAndroid Build Coastguard Worker 11.0f, 9.0f, 7.0f, 5.0f, 3.0f, 1.0f, -1.0f, -3.0f, -5.0f, -7.0f, -9.0f, -11.0f 954*89c4ff92SAndroid Build Coastguard Worker }; 955*89c4ff92SAndroid Build Coastguard Worker 956*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0); 957*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1); 958*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo2 = runtime->GetInputTensorInfo(netId, 2); 959*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo0.SetConstant(true); 960*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1.SetConstant(true); 961*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo2.SetConstant(true); 962*89c4ff92SAndroid Build Coastguard Worker 963*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors 964*89c4ff92SAndroid Build Coastguard Worker { 965*89c4ff92SAndroid Build Coastguard Worker { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) }, 966*89c4ff92SAndroid Build Coastguard Worker { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) }, 967*89c4ff92SAndroid Build Coastguard Worker { 2, armnn::ConstTensor(inputTensorInfo2, inputData2.data()) } 968*89c4ff92SAndroid Build Coastguard Worker }; 969*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors 970*89c4ff92SAndroid Build Coastguard Worker { 971*89c4ff92SAndroid Build Coastguard Worker { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) } 972*89c4ff92SAndroid Build Coastguard Worker }; 973*89c4ff92SAndroid Build Coastguard Worker 974*89c4ff92SAndroid Build Coastguard Worker runtime->GetProfiler(netId)->EnableProfiling(true); 975*89c4ff92SAndroid Build Coastguard Worker 976*89c4ff92SAndroid Build Coastguard Worker // Do the inference 977*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(netId, inputTensors, outputTensors); 978*89c4ff92SAndroid Build Coastguard Worker 979*89c4ff92SAndroid Build Coastguard Worker // Retrieve the Profiler.Print() output to get the workload execution 980*89c4ff92SAndroid Build Coastguard Worker ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance(); 981*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss; 982*89c4ff92SAndroid Build Coastguard Worker profilerManager.GetProfiler()->Print(ss);; 983*89c4ff92SAndroid Build Coastguard Worker std::string dump = ss.str(); 984*89c4ff92SAndroid Build Coastguard Worker 985*89c4ff92SAndroid Build Coastguard Worker // Executed Subtraction using GpuAcc 986*89c4ff92SAndroid Build Coastguard Worker std::size_t found = dump.find("ClSubtractionWorkload_Execute"); 987*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 988*89c4ff92SAndroid Build Coastguard Worker 989*89c4ff92SAndroid Build Coastguard Worker // Contain CopyMemGeneric 990*89c4ff92SAndroid Build Coastguard Worker found = dump.find("CopyMemGeneric"); 991*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 992*89c4ff92SAndroid Build Coastguard Worker 993*89c4ff92SAndroid Build Coastguard Worker // Check output is as expected 994*89c4ff92SAndroid Build Coastguard Worker CHECK(outputData == expectedOutput); 995*89c4ff92SAndroid Build Coastguard Worker } 996*89c4ff92SAndroid Build Coastguard Worker 997*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("NeonImportEnabledFallbackSubgraphToCl") 998*89c4ff92SAndroid Build Coastguard Worker { 999*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 1000*89c4ff92SAndroid Build Coastguard Worker 1001*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options; 1002*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options)); 1003*89c4ff92SAndroid Build Coastguard Worker 1004*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network. 1005*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create()); 1006*89c4ff92SAndroid Build Coastguard Worker 1007*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor desc; 1008*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolWidth = 2; 1009*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolHeight = 2; 1010*89c4ff92SAndroid Build Coastguard Worker desc.m_StrideX = 2; 1011*89c4ff92SAndroid Build Coastguard Worker desc.m_StrideY = 2; 1012*89c4ff92SAndroid Build Coastguard Worker 1013*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input0 = net->AddInputLayer(0, "input0"); 1014*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input1 = net->AddInputLayer(1, "input1"); 1015*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input2 = net->AddInputLayer(2, "input2"); 1016*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add"); 1017*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub"); 1018*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* pooling = net->AddPooling2dLayer(desc, "pooling"); 1019*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output"); 1020*89c4ff92SAndroid Build Coastguard Worker 1021*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).Connect(add->GetInputSlot(0)); 1022*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).Connect(add->GetInputSlot(1)); 1023*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0)); 1024*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).Connect(sub->GetInputSlot(1)); 1025*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).Connect(pooling->GetInputSlot(0)); 1026*89c4ff92SAndroid Build Coastguard Worker pooling->GetOutputSlot(0).Connect(output->GetInputSlot(0)); 1027*89c4ff92SAndroid Build Coastguard Worker 1028*89c4ff92SAndroid Build Coastguard Worker TensorInfo info = TensorInfo({ 1, 2, 4, 2 }, DataType::Float32); 1029*89c4ff92SAndroid Build Coastguard Worker TensorInfo poolingInfo = TensorInfo({ 1, 2, 2, 1 }, DataType::Float32); 1030*89c4ff92SAndroid Build Coastguard Worker 1031*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).SetTensorInfo(info); 1032*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).SetTensorInfo(info); 1033*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).SetTensorInfo(info); 1034*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).SetTensorInfo(info); 1035*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).SetTensorInfo(info); 1036*89c4ff92SAndroid Build Coastguard Worker pooling->GetOutputSlot(0).SetTensorInfo(poolingInfo); 1037*89c4ff92SAndroid Build Coastguard Worker 1038*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> backends = { Compute::CpuAcc, Compute::GpuAcc }; 1039*89c4ff92SAndroid Build Coastguard Worker // Use BackendSelectionHint to specify GpuAcc for Subtraction layer 1040*89c4ff92SAndroid Build Coastguard Worker sub->BackendSelectionHint(backends[1]); 1041*89c4ff92SAndroid Build Coastguard Worker 1042*89c4ff92SAndroid Build Coastguard Worker // optimize the network 1043*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque optOptions; 1044*89c4ff92SAndroid Build Coastguard Worker optOptions.SetImportEnabled(true); 1045*89c4ff92SAndroid Build Coastguard Worker optOptions.SetExportEnabled(true); 1046*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions); 1047*89c4ff92SAndroid Build Coastguard Worker 1048*89c4ff92SAndroid Build Coastguard Worker Graph& graph = GetGraphForTesting(optNet.get()); 1049*89c4ff92SAndroid Build Coastguard Worker 1050*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0"); 1051*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1"); 1052*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2"); 1053*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add"); 1054*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]"); 1055*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub"); 1056*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "[ sub (0) -> pooling (0) ]"); 1057*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer7 = GetFirstLayerWithName(graph, "pooling"); 1058*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer8 = GetFirstLayerWithName(graph, "output"); 1059*89c4ff92SAndroid Build Coastguard Worker 1060*89c4ff92SAndroid Build Coastguard Worker // Checks order is valid. 1061*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer0, layer1)); 1062*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer1, layer2)); 1063*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer2, layer3)); 1064*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer3, layer4)); 1065*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer4, layer5)); 1066*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer5, layer6)); 1067*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer6, layer7)); 1068*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer7, layer8)); 1069*89c4ff92SAndroid Build Coastguard Worker 1070*89c4ff92SAndroid Build Coastguard Worker // Use memory import between backends 1071*89c4ff92SAndroid Build Coastguard Worker CHECK((layer4->GetType() == LayerType::MemCopy)); 1072*89c4ff92SAndroid Build Coastguard Worker CHECK((layer6->GetType() == LayerType::MemCopy)); 1073*89c4ff92SAndroid Build Coastguard Worker 1074*89c4ff92SAndroid Build Coastguard Worker // Correctly use backend hint 1075*89c4ff92SAndroid Build Coastguard Worker CHECK((layer5->GetBackendId() == Compute::GpuAcc )); 1076*89c4ff92SAndroid Build Coastguard Worker 1077*89c4ff92SAndroid Build Coastguard Worker // Load it into the runtime. It should pass. 1078*89c4ff92SAndroid Build Coastguard Worker NetworkId netId; 1079*89c4ff92SAndroid Build Coastguard Worker std::string ignoredErrorMessage; 1080*89c4ff92SAndroid Build Coastguard Worker 1081*89c4ff92SAndroid Build Coastguard Worker INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc); 1082*89c4ff92SAndroid Build Coastguard Worker 1083*89c4ff92SAndroid Build Coastguard Worker runtime->LoadNetwork(netId, std::move(optNet), ignoredErrorMessage, networkProperties); 1084*89c4ff92SAndroid Build Coastguard Worker 1085*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output 1086*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData0 1087*89c4ff92SAndroid Build Coastguard Worker { 1088*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f, 1.0f, 1.0f, 2.0f, 2.0f 1089*89c4ff92SAndroid Build Coastguard Worker }; 1090*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData1 1091*89c4ff92SAndroid Build Coastguard Worker { 1092*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 0.0f, 1.0f, 1.0f, 2.0f 1093*89c4ff92SAndroid Build Coastguard Worker }; 1094*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData2 1095*89c4ff92SAndroid Build Coastguard Worker { 1096*89c4ff92SAndroid Build Coastguard Worker 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 12.0f, 11.0f, 10.0f, 9.0f 1097*89c4ff92SAndroid Build Coastguard Worker }; 1098*89c4ff92SAndroid Build Coastguard Worker 1099*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData(4); 1100*89c4ff92SAndroid Build Coastguard Worker 1101*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutput{ 11.0f, 3.0f, -5.0f, 11.0f }; 1102*89c4ff92SAndroid Build Coastguard Worker 1103*89c4ff92SAndroid Build Coastguard Worker // Prepare aligned data 1104*89c4ff92SAndroid Build Coastguard Worker unsigned int numElements = info.GetNumElements(); 1105*89c4ff92SAndroid Build Coastguard Worker size_t totalBytes = numElements * sizeof(float); 1106*89c4ff92SAndroid Build Coastguard Worker const size_t alignment = 64; 1107*89c4ff92SAndroid Build Coastguard Worker size_t space = totalBytes + alignment + alignment; 1108*89c4ff92SAndroid Build Coastguard Worker auto inputData = std::make_unique<uint8_t[]>(space); 1109*89c4ff92SAndroid Build Coastguard Worker void* alignedInputPtr = inputData.get(); 1110*89c4ff92SAndroid Build Coastguard Worker CHECK(std::align(alignment, totalBytes, alignedInputPtr, space)); 1111*89c4ff92SAndroid Build Coastguard Worker 1112*89c4ff92SAndroid Build Coastguard Worker auto* intputPtr = reinterpret_cast<float*>(alignedInputPtr); 1113*89c4ff92SAndroid Build Coastguard Worker std::copy(inputData2.begin(), inputData2.end(), intputPtr); 1114*89c4ff92SAndroid Build Coastguard Worker 1115*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0); 1116*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1); 1117*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo2 = runtime->GetInputTensorInfo(netId, 2); 1118*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo0.SetConstant(true); 1119*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1.SetConstant(true); 1120*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo2.SetConstant(true); 1121*89c4ff92SAndroid Build Coastguard Worker 1122*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors 1123*89c4ff92SAndroid Build Coastguard Worker { 1124*89c4ff92SAndroid Build Coastguard Worker { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) }, 1125*89c4ff92SAndroid Build Coastguard Worker { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) }, 1126*89c4ff92SAndroid Build Coastguard Worker { 2, armnn::ConstTensor(inputTensorInfo2, alignedInputPtr) } 1127*89c4ff92SAndroid Build Coastguard Worker }; 1128*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors 1129*89c4ff92SAndroid Build Coastguard Worker { 1130*89c4ff92SAndroid Build Coastguard Worker { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) } 1131*89c4ff92SAndroid Build Coastguard Worker }; 1132*89c4ff92SAndroid Build Coastguard Worker 1133*89c4ff92SAndroid Build Coastguard Worker runtime->GetProfiler(netId)->EnableProfiling(true); 1134*89c4ff92SAndroid Build Coastguard Worker 1135*89c4ff92SAndroid Build Coastguard Worker // Do the inference 1136*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(netId, inputTensors, outputTensors); 1137*89c4ff92SAndroid Build Coastguard Worker 1138*89c4ff92SAndroid Build Coastguard Worker // Retrieve the Profiler.Print() output to get the workload execution 1139*89c4ff92SAndroid Build Coastguard Worker ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance(); 1140*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss; 1141*89c4ff92SAndroid Build Coastguard Worker profilerManager.GetProfiler()->Print(ss);; 1142*89c4ff92SAndroid Build Coastguard Worker std::string dump = ss.str(); 1143*89c4ff92SAndroid Build Coastguard Worker 1144*89c4ff92SAndroid Build Coastguard Worker // Executed Subtraction using GpuAcc 1145*89c4ff92SAndroid Build Coastguard Worker std::size_t found = dump.find("ClSubtractionWorkload_Execute"); 1146*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 1147*89c4ff92SAndroid Build Coastguard Worker 1148*89c4ff92SAndroid Build Coastguard Worker // Correctly switch back to CpuAcc 1149*89c4ff92SAndroid Build Coastguard Worker found = dump.find("NeonPooling2dWorkload_Execute"); 1150*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 1151*89c4ff92SAndroid Build Coastguard Worker 1152*89c4ff92SAndroid Build Coastguard Worker // Contain CopyMemGeneric 1153*89c4ff92SAndroid Build Coastguard Worker found = dump.find("CopyMemGeneric"); 1154*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 1155*89c4ff92SAndroid Build Coastguard Worker 1156*89c4ff92SAndroid Build Coastguard Worker // Contains SyncMemGeneric for output 1157*89c4ff92SAndroid Build Coastguard Worker found = dump.find("SyncMemGeneric"); 1158*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 1159*89c4ff92SAndroid Build Coastguard Worker 1160*89c4ff92SAndroid Build Coastguard Worker // Check output is as expected 1161*89c4ff92SAndroid Build Coastguard Worker CHECK(outputData == expectedOutput); 1162*89c4ff92SAndroid Build Coastguard Worker runtime->UnloadNetwork(netId); 1163*89c4ff92SAndroid Build Coastguard Worker } 1164*89c4ff92SAndroid Build Coastguard Worker 1165*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("NeonImportDisableFallbackSubgraphToCl") 1166*89c4ff92SAndroid Build Coastguard Worker { 1167*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 1168*89c4ff92SAndroid Build Coastguard Worker 1169*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options; 1170*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr runtime(IRuntime::Create(options)); 1171*89c4ff92SAndroid Build Coastguard Worker 1172*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network. 1173*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create()); 1174*89c4ff92SAndroid Build Coastguard Worker 1175*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor desc; 1176*89c4ff92SAndroid Build Coastguard Worker 1177*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input0 = net->AddInputLayer(0, "input0"); 1178*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input1 = net->AddInputLayer(1, "input1"); 1179*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input2 = net->AddInputLayer(2, "input2"); 1180*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* add = net->AddElementwiseBinaryLayer(BinaryOperation::Add, "add"); 1181*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* sub = net->AddElementwiseBinaryLayer(BinaryOperation::Sub, "sub"); 1182*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* pooling = net->AddPooling2dLayer(desc, "pooling"); 1183*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output"); 1184*89c4ff92SAndroid Build Coastguard Worker 1185*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).Connect(add->GetInputSlot(0)); 1186*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).Connect(add->GetInputSlot(1)); 1187*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).Connect(sub->GetInputSlot(0)); 1188*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).Connect(sub->GetInputSlot(1)); 1189*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).Connect(pooling->GetInputSlot(0)); 1190*89c4ff92SAndroid Build Coastguard Worker pooling->GetOutputSlot(0).Connect(output->GetInputSlot(0)); 1191*89c4ff92SAndroid Build Coastguard Worker 1192*89c4ff92SAndroid Build Coastguard Worker TensorInfo info = TensorInfo({ 1, 2, 3, 2 }, DataType::Float32); 1193*89c4ff92SAndroid Build Coastguard Worker TensorInfo poolingInfo = TensorInfo({ 1, 2, 1, 1 }, DataType::Float32); 1194*89c4ff92SAndroid Build Coastguard Worker 1195*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot(0).SetTensorInfo(info); 1196*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).SetTensorInfo(info); 1197*89c4ff92SAndroid Build Coastguard Worker input2->GetOutputSlot(0).SetTensorInfo(info); 1198*89c4ff92SAndroid Build Coastguard Worker add->GetOutputSlot(0).SetTensorInfo(info); 1199*89c4ff92SAndroid Build Coastguard Worker sub->GetOutputSlot(0).SetTensorInfo(info); 1200*89c4ff92SAndroid Build Coastguard Worker pooling->GetOutputSlot(0).SetTensorInfo(poolingInfo); 1201*89c4ff92SAndroid Build Coastguard Worker 1202*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> backends = { Compute::CpuAcc, Compute::GpuAcc }; 1203*89c4ff92SAndroid Build Coastguard Worker // Use BackendSelectionHint to specify GpuAcc for Subtraction layer 1204*89c4ff92SAndroid Build Coastguard Worker sub->BackendSelectionHint(backends[1]); 1205*89c4ff92SAndroid Build Coastguard Worker 1206*89c4ff92SAndroid Build Coastguard Worker // optimize the network 1207*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque optOptions; 1208*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec(), optOptions); 1209*89c4ff92SAndroid Build Coastguard Worker 1210*89c4ff92SAndroid Build Coastguard Worker Graph& graph = GetGraphForTesting(optNet.get()); 1211*89c4ff92SAndroid Build Coastguard Worker 1212*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer0 = GetFirstLayerWithName(graph, "input0"); 1213*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer1 = GetFirstLayerWithName(graph, "input1"); 1214*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer2 = GetFirstLayerWithName(graph, "input2"); 1215*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer3 = GetFirstLayerWithName(graph, "add"); 1216*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer4 = GetFirstLayerWithName(graph, "[ add (0) -> sub (1) ]"); 1217*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer5 = GetFirstLayerWithName(graph, "sub"); 1218*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer6 = GetFirstLayerWithName(graph, "[ sub (0) -> pooling (0) ]"); 1219*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer7 = GetFirstLayerWithName(graph, "pooling"); 1220*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* const layer8 = GetFirstLayerWithName(graph, "output"); 1221*89c4ff92SAndroid Build Coastguard Worker 1222*89c4ff92SAndroid Build Coastguard Worker // Checks order is valid. 1223*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer0, layer1)); 1224*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer1, layer2)); 1225*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer2, layer3)); 1226*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer3, layer4)); 1227*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer4, layer5)); 1228*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer5, layer6)); 1229*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer6, layer7)); 1230*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckOrder(graph, layer7, layer8)); 1231*89c4ff92SAndroid Build Coastguard Worker 1232*89c4ff92SAndroid Build Coastguard Worker // Use memory import between backends 1233*89c4ff92SAndroid Build Coastguard Worker CHECK((layer4->GetType() == LayerType::MemCopy)); 1234*89c4ff92SAndroid Build Coastguard Worker CHECK((layer6->GetType() == LayerType::MemCopy)); 1235*89c4ff92SAndroid Build Coastguard Worker 1236*89c4ff92SAndroid Build Coastguard Worker // Correctly use backend hint 1237*89c4ff92SAndroid Build Coastguard Worker CHECK((layer5->GetBackendId() == Compute::GpuAcc )); 1238*89c4ff92SAndroid Build Coastguard Worker 1239*89c4ff92SAndroid Build Coastguard Worker // Load it into the runtime. It should pass. 1240*89c4ff92SAndroid Build Coastguard Worker NetworkId netId; 1241*89c4ff92SAndroid Build Coastguard Worker runtime->LoadNetwork(netId, std::move(optNet)); 1242*89c4ff92SAndroid Build Coastguard Worker 1243*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output 1244*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData0 1245*89c4ff92SAndroid Build Coastguard Worker { 1246*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 6.0f 1247*89c4ff92SAndroid Build Coastguard Worker }; 1248*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData1 1249*89c4ff92SAndroid Build Coastguard Worker { 1250*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f 1251*89c4ff92SAndroid Build Coastguard Worker }; 1252*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData2 1253*89c4ff92SAndroid Build Coastguard Worker { 1254*89c4ff92SAndroid Build Coastguard Worker 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f 1255*89c4ff92SAndroid Build Coastguard Worker }; 1256*89c4ff92SAndroid Build Coastguard Worker 1257*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData(2); 1258*89c4ff92SAndroid Build Coastguard Worker 1259*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutput{ 11.0f, -1.0f }; 1260*89c4ff92SAndroid Build Coastguard Worker 1261*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo0 = runtime->GetInputTensorInfo(netId, 0); 1262*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo1 = runtime->GetInputTensorInfo(netId, 1); 1263*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo2 = runtime->GetInputTensorInfo(netId, 2); 1264*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo0.SetConstant(true); 1265*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1.SetConstant(true); 1266*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo2.SetConstant(true); 1267*89c4ff92SAndroid Build Coastguard Worker 1268*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors 1269*89c4ff92SAndroid Build Coastguard Worker { 1270*89c4ff92SAndroid Build Coastguard Worker { 0, armnn::ConstTensor(inputTensorInfo0, inputData0.data()) }, 1271*89c4ff92SAndroid Build Coastguard Worker { 1, armnn::ConstTensor(inputTensorInfo1, inputData1.data()) }, 1272*89c4ff92SAndroid Build Coastguard Worker { 2, armnn::ConstTensor(inputTensorInfo2, inputData2.data()) } 1273*89c4ff92SAndroid Build Coastguard Worker }; 1274*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors 1275*89c4ff92SAndroid Build Coastguard Worker { 1276*89c4ff92SAndroid Build Coastguard Worker { 0,armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data()) } 1277*89c4ff92SAndroid Build Coastguard Worker }; 1278*89c4ff92SAndroid Build Coastguard Worker 1279*89c4ff92SAndroid Build Coastguard Worker runtime->GetProfiler(netId)->EnableProfiling(true); 1280*89c4ff92SAndroid Build Coastguard Worker 1281*89c4ff92SAndroid Build Coastguard Worker // Do the inference 1282*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(netId, inputTensors, outputTensors); 1283*89c4ff92SAndroid Build Coastguard Worker 1284*89c4ff92SAndroid Build Coastguard Worker // Retrieve the Profiler.Print() output to get the workload execution 1285*89c4ff92SAndroid Build Coastguard Worker ProfilerManager& profilerManager = armnn::ProfilerManager::GetInstance(); 1286*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss; 1287*89c4ff92SAndroid Build Coastguard Worker profilerManager.GetProfiler()->Print(ss);; 1288*89c4ff92SAndroid Build Coastguard Worker std::string dump = ss.str(); 1289*89c4ff92SAndroid Build Coastguard Worker 1290*89c4ff92SAndroid Build Coastguard Worker // Executed Subtraction using GpuAcc 1291*89c4ff92SAndroid Build Coastguard Worker std::size_t found = dump.find("ClSubtractionWorkload_Execute"); 1292*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 1293*89c4ff92SAndroid Build Coastguard Worker 1294*89c4ff92SAndroid Build Coastguard Worker // Correctly switch back to CpuAcc 1295*89c4ff92SAndroid Build Coastguard Worker found = dump.find("NeonPooling2dWorkload_Execute"); 1296*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 1297*89c4ff92SAndroid Build Coastguard Worker 1298*89c4ff92SAndroid Build Coastguard Worker // Contain CopyMemGeneric 1299*89c4ff92SAndroid Build Coastguard Worker found = dump.find("CopyMemGeneric"); 1300*89c4ff92SAndroid Build Coastguard Worker CHECK(found != std::string::npos); 1301*89c4ff92SAndroid Build Coastguard Worker 1302*89c4ff92SAndroid Build Coastguard Worker // Check output is as expected 1303*89c4ff92SAndroid Build Coastguard Worker CHECK(outputData == expectedOutput); 1304*89c4ff92SAndroid Build Coastguard Worker } 1305*89c4ff92SAndroid Build Coastguard Worker #endif 1306*89c4ff92SAndroid Build Coastguard Worker 1307*89c4ff92SAndroid Build Coastguard Worker } 1308